#[allow(unused)]
use std::time::Duration;

use r2d2::ManageConnection;
use rocket::{Rocket, Build};

#[allow(unused_imports)]
use crate::{Config, Error};

/// Trait implemented by `r2d2`-based database adapters.
///
/// # Provided Implementations
///
/// Implementations of `Poolable` are provided for the following types:
///
/// * [`diesel::MysqlConnection`](diesel::MysqlConnection)
/// * [`diesel::PgConnection`](diesel::PgConnection)
/// * [`diesel::SqliteConnection`](diesel::SqliteConnection)
/// * [`postgres::Client`](postgres::Client)
/// * [`rusqlite::Connection`](rusqlite::Connection)
/// * [`memcache::Client`](memcache::Client)
///
/// # Implementation Guide
///
/// If you provide an r2d2-compatible adapter for a database (or another
/// resource), implementing `Poolable` in your library lets Rocket users consume
/// that adapter through Rocket's built-in connection pooling support.
///
/// ## Example
///
/// Consider a library `foo` with the following types:
///
/// * `foo::ConnectionManager`, which implements [`r2d2::ManageConnection`]
/// * `foo::Connection`, the `Connection` associated type of
///   `foo::ConnectionManager`
/// * `foo::Error`, errors resulting from manager instantiation
///
/// In order for Rocket to generate the required code to automatically provision
/// an r2d2 connection pool into application state, the `Poolable` trait needs to
/// be implemented for the connection type. The following example implements
/// `Poolable` for `foo::Connection`:
///
/// ```rust
/// # mod foo {
/// #     use std::fmt;
/// #     use rocket_sync_db_pools::r2d2;
/// #     #[derive(Debug)] pub struct Error;
/// #     impl std::error::Error for Error { }
/// #     impl fmt::Display for Error {
/// #         fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Ok(()) }
/// #     }
/// #
/// #     pub struct Connection;
/// #     pub struct ConnectionManager;
/// #
/// #     type Result<T> = std::result::Result<T, Error>;
/// #
/// #     impl ConnectionManager {
/// #         pub fn new(url: &str) -> Result<Self> { Err(Error) }
/// #     }
/// #
/// #     impl self::r2d2::ManageConnection for ConnectionManager {
/// #         type Connection = Connection;
/// #         type Error = Error;
/// #         fn connect(&self) -> Result<Connection> { panic!() }
/// #         fn is_valid(&self, _: &mut Connection) -> Result<()> { panic!() }
/// #         fn has_broken(&self, _: &mut Connection) -> bool { panic!() }
/// #     }
/// # }
/// use std::time::Duration;
/// use rocket::{Rocket, Build};
/// use rocket_sync_db_pools::{r2d2, Error, Config, Poolable, PoolResult};
///
/// impl Poolable for foo::Connection {
///     type Manager = foo::ConnectionManager;
///     type Error = foo::Error;
///
///     fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
///         let config = Config::from(db_name, rocket)?;
///         let manager = foo::ConnectionManager::new(&config.url).map_err(Error::Custom)?;
///         Ok(r2d2::Pool::builder()
///             .max_size(config.pool_size)
///             .connection_timeout(Duration::from_secs(config.timeout as u64))
///             .build(manager)?)
///     }
/// }
/// ```
///
/// In this example, the `foo::ConnectionManager::new()` method returns a
/// `foo::Error` on failure. The [`Error`] enum consolidates this type, the
/// `r2d2::Error` type that can result from building the pool with
/// `r2d2::Builder::build()`, and the
/// [`figment::Error`](rocket::figment::Error) type from `Config::from()`.
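///
/// The consolidation can be sketched with only the standard library. This is a
/// minimal illustration, not this crate's code: `ConfigError`, `PoolError`,
/// `ManagerError`, `SketchError`, `build()`, and `build_pool()` are
/// hypothetical stand-ins for `figment::Error`, `r2d2::Error`, `foo::Error`,
/// [`Error`], `r2d2::Builder::build()`, and `pool()`. As in the example above,
/// `From` conversions let `?` handle configuration and pool errors, while the
/// manager's own error is wrapped explicitly with `map_err(...::Custom)`.
///
/// ```rust
/// // Hypothetical stand-ins for `figment::Error`, `r2d2::Error`, `foo::Error`.
/// #[derive(Debug, PartialEq)]
/// struct ConfigError(String);
/// #[derive(Debug, PartialEq)]
/// struct PoolError(String);
/// #[derive(Debug, PartialEq)]
/// struct ManagerError(String);
///
/// // Simplified mirror of this crate's `Error<T>`: one variant per source.
/// #[derive(Debug, PartialEq)]
/// enum SketchError<T> {
///     Config(ConfigError),
///     Pool(PoolError),
///     Custom(T),
/// }
///
/// // These `From` impls are what let `?` convert config and pool errors.
/// impl<T> From<ConfigError> for SketchError<T> {
///     fn from(e: ConfigError) -> Self { SketchError::Config(e) }
/// }
///
/// impl<T> From<PoolError> for SketchError<T> {
///     fn from(e: PoolError) -> Self { SketchError::Pool(e) }
/// }
///
/// fn read_url(raw: Option<&str>) -> Result<String, ConfigError> {
///     raw.map(str::to_string).ok_or_else(|| ConfigError("missing url".into()))
/// }
///
/// fn new_manager(url: &str) -> Result<String, ManagerError> {
///     if url.is_empty() {
///         return Err(ManagerError("empty url".into()));
///     }
///     Ok(url.to_string())
/// }
///
/// fn build(manager: String, size: usize) -> Result<Vec<String>, PoolError> {
///     if size == 0 {
///         return Err(PoolError("max_size must be non-zero".into()));
///     }
///     Ok(vec![manager; size])
/// }
///
/// // Shaped like `pool()`: three error types, one return type.
/// fn build_pool(url: Option<&str>, size: usize) -> Result<Vec<String>, SketchError<ManagerError>> {
///     let url = read_url(url)?;                                      // ConfigError -> Config
///     let manager = new_manager(&url).map_err(SketchError::Custom)?; // ManagerError -> Custom
///     Ok(build(manager, size)?)                                      // PoolError -> Pool
/// }
/// ```
///
/// A caller can then match on a single enum to tell the three failure sources
/// apart.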
///
/// In the event that a connection manager isn't fallible (as is the case with
/// Diesel's r2d2 connection manager, for instance), the associated error type
/// for the `Poolable` implementation should be `std::convert::Infallible`.
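///
/// A minimal, std-only sketch of why `Infallible` fits here. It uses a
/// hypothetical two-variant `SketchError` in place of this crate's [`Error`],
/// plus hypothetical `make_manager()`, `pool()`, and `describe()` functions.
/// Because `Infallible` has no values, the `Custom` variant can never be
/// constructed, so code that matches on the error can discharge that arm with
/// an empty `match`.
///
/// ```rust
/// use std::convert::Infallible;
///
/// // Hypothetical, simplified stand-in for this crate's `Error<T>`.
/// #[derive(Debug)]
/// enum SketchError<T> {
///     Config(String),
///     Custom(T),
/// }
///
/// // Cannot fail, like Diesel's `ConnectionManager::new()`, which returns the
/// // manager directly. `Result<_, Infallible>` only fits the generic shape.
/// fn make_manager(url: &str) -> Result<String, Infallible> {
///     Ok(format!("manager for {}", url))
/// }
///
/// fn pool(url: Option<&str>) -> Result<String, SketchError<Infallible>> {
///     let url = match url {
///         Some(url) => url,
///         None => return Err(SketchError::Config("missing url".to_string())),
///     };
///
///     // Type-checks, but `SketchError::Custom` is never actually built.
///     make_manager(url).map_err(SketchError::Custom)
/// }
///
/// // The impossible `Custom` arm needs no real handling.
/// fn describe(result: Result<String, SketchError<Infallible>>) -> String {
///     match result {
///         Ok(manager) => manager,
///         Err(SketchError::Config(msg)) => format!("config error: {}", msg),
///         Err(SketchError::Custom(never)) => match never {},
///     }
/// }
/// ```
///
/// Choosing `Infallible` records in the type system that the `Custom` case
/// cannot occur for that implementation.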
///
/// For more concrete examples, consult Rocket's existing implementations of
/// [`Poolable`].
pub trait Poolable: Send + Sized + 'static {
    /// The associated connection manager for the given connection type.
    type Manager: ManageConnection<Connection=Self>;

    /// The associated error type in the event that constructing the connection
    /// manager and/or the connection pool fails.
    type Error: std::fmt::Debug;

    /// Creates an `r2d2` connection pool for `Manager::Connection`, returning
    /// the pool on success.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self>;
}

/// A type alias for the return type of [`Poolable::pool()`].
#[allow(type_alias_bounds)]
pub type PoolResult<P: Poolable> = Result<r2d2::Pool<P::Manager>, Error<P::Error>>;

#[cfg(feature = "diesel_sqlite_pool")]
impl Poolable for diesel::SqliteConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::SqliteConnection>;
    type Error = std::convert::Infallible;

    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::{SqliteConnection, connection::SimpleConnection};
        use diesel::r2d2::{CustomizeConnection, ConnectionManager, Error, Pool};

        // Configures each connection once, right after r2d2 establishes it.
        #[derive(Debug)]
        struct Customizer;

        impl CustomizeConnection<SqliteConnection, Error> for Customizer {
            fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), Error> {
                // Enable write-ahead logging, wait up to 5000 ms for a locked
                // database instead of failing immediately, and enforce
                // foreign key constraints.
                conn.batch_execute("\
                    PRAGMA journal_mode = WAL;\
                    PRAGMA busy_timeout = 5000;\
                    PRAGMA foreign_keys = ON;\
                ").map_err(Error::QueryError)?;

                Ok(())
            }
        }

        let config = Config::from(db_name, rocket)?;
        let manager = ConnectionManager::new(&config.url);
        let pool = Pool::builder()
            .connection_customizer(Box::new(Customizer))
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}

#[cfg(feature = "diesel_postgres_pool")]
impl Poolable for diesel::PgConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::PgConnection>;
    type Error = std::convert::Infallible;

    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let config = Config::from(db_name, rocket)?;
        let manager = diesel::r2d2::ConnectionManager::new(&config.url);
        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}

#[cfg(feature = "diesel_mysql_pool")]
impl Poolable for diesel::MysqlConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::MysqlConnection>;
    type Error = std::convert::Infallible;

    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let config = Config::from(db_name, rocket)?;
        let manager = diesel::r2d2::ConnectionManager::new(&config.url);
        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}

// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
    type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
    type Error = postgres::Error;

    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let config = Config::from(db_name, rocket)?;

        // Parse the URL into a `postgres::Config`; parse errors become `Error::Custom`.
        let url = config.url.parse().map_err(Error::Custom)?;
        let manager = r2d2_postgres::PostgresConnectionManager::new(url, postgres::tls::NoTls);
        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}

#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
    type Manager = r2d2_sqlite::SqliteConnectionManager;
    type Error = std::convert::Infallible;

    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use rocket::figment::providers::Serialized;

        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        #[serde(rename_all = "snake_case")]
        enum OpenFlag {
            ReadOnly,
            ReadWrite,
            Create,
            Uri,
            Memory,
            NoMutex,
            FullMutex,
            SharedCache,
            PrivateCache,
            Nofollow,
        }

        let figment = Config::figment(db_name, rocket);
        let config: Config = figment.extract()?;
        let open_flags: Vec<OpenFlag> = figment
            .join(Serialized::default("open_flags", <Vec<OpenFlag>>::new()))
            .extract_inner("open_flags")?;

        let mut flags = rusqlite::OpenFlags::default();
        for flag in open_flags {
            let sql_flag = match flag {
                OpenFlag::ReadOnly => rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
                OpenFlag::ReadWrite => rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE,
                OpenFlag::Create => rusqlite::OpenFlags::SQLITE_OPEN_CREATE,
                OpenFlag::Uri => rusqlite::OpenFlags::SQLITE_OPEN_URI,
                OpenFlag::Memory => rusqlite::OpenFlags::SQLITE_OPEN_MEMORY,
                OpenFlag::NoMutex => rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX,
                OpenFlag::FullMutex => rusqlite::OpenFlags::SQLITE_OPEN_FULL_MUTEX,
                OpenFlag::SharedCache => rusqlite::OpenFlags::SQLITE_OPEN_SHARED_CACHE,
                OpenFlag::PrivateCache => rusqlite::OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
                OpenFlag::Nofollow => rusqlite::OpenFlags::SQLITE_OPEN_NOFOLLOW,
            };

            flags.insert(sql_flag);
        }

        let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url)
            .with_flags(flags);

        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}

#[cfg(feature = "memcache_pool")]
mod memcache_pool {
    use memcache::{Client, Connectable, MemcacheError};

    use super::*;

    // An r2d2 connection manager for memcache: it stores the server URLs and
    // opens a new `Client` for them whenever the pool needs a connection.
    #[derive(Debug)]
    pub struct ConnectionManager {
        urls: Vec<String>,
    }

    impl ConnectionManager {
        pub fn new<C: Connectable>(target: C) -> Self {
            Self { urls: target.get_urls() }
        }
    }

    impl r2d2::ManageConnection for ConnectionManager {
        type Connection = Client;
        type Error = MemcacheError;

        fn connect(&self) -> Result<Client, MemcacheError> {
            Client::connect(self.urls.clone())
        }

        // Validate with a round trip to the server: request its version.
        fn is_valid(&self, connection: &mut Client) -> Result<(), MemcacheError> {
            connection.version().map(|_| ())
        }

        // Never flag a connection as broken here; liveness is checked by `is_valid()`.
        fn has_broken(&self, _connection: &mut Client) -> bool {
            false
        }
    }

    impl super::Poolable for memcache::Client {
        type Manager = ConnectionManager;
        type Error = MemcacheError;

        fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
            let config = Config::from(db_name, rocket)?;
            let manager = ConnectionManager::new(&*config.url);
            let pool = r2d2::Pool::builder()
                .max_size(config.pool_size)
                .connection_timeout(Duration::from_secs(config.timeout as u64))
                .build(manager)?;

            Ok(pool)
        }
    }
}