rocket_sync_db_pools/
poolable.rs

1#[allow(unused)]
2use std::time::Duration;
3
4use r2d2::ManageConnection;
5use rocket::{Rocket, Build};
6
7#[allow(unused_imports)]
8use crate::{Config, Error};
9
10/// Trait implemented by `r2d2`-based database adapters.
11///
12/// # Provided Implementations
13///
14/// Implementations of `Poolable` are provided for the following types:
15///
16///   * `diesel::MysqlConnection`
17///   * `diesel::PgConnection`
18///   * `diesel::SqliteConnection`
19///   * `postgres::Connection`
20///   * `rusqlite::Connection`
21///
22/// # Implementation Guide
23///
24/// As an r2d2-compatible database (or other resource) adapter provider,
25/// implementing `Poolable` in your own library will enable Rocket users to
26/// consume your adapter with its built-in connection pooling support.
27///
28/// ## Example
29///
30/// Consider a library `foo` with the following types:
31///
32///   * `foo::ConnectionManager`, which implements [`r2d2::ManageConnection`]
33///   * `foo::Connection`, the `Connection` associated type of
34///     `foo::ConnectionManager`
35///   * `foo::Error`, errors resulting from manager instantiation
36///
37/// In order for Rocket to generate the required code to automatically provision
38/// a r2d2 connection pool into application state, the `Poolable` trait needs to
39/// be implemented for the connection type. The following example implements
40/// `Poolable` for `foo::Connection`:
41///
42/// ```rust
43/// # mod foo {
44/// #     use std::fmt;
45/// #     use rocket_sync_db_pools::r2d2;
46/// #     #[derive(Debug)] pub struct Error;
47/// #     impl std::error::Error for Error {  }
48/// #     impl fmt::Display for Error {
49/// #         fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Ok(()) }
50/// #     }
51/// #
52/// #     pub struct Connection;
53/// #     pub struct ConnectionManager;
54/// #
55/// #     type Result<T> = std::result::Result<T, Error>;
56/// #
57/// #     impl ConnectionManager {
58/// #         pub fn new(url: &str) -> Result<Self> { Err(Error) }
59/// #     }
60/// #
61/// #     impl self::r2d2::ManageConnection for ConnectionManager {
62/// #          type Connection = Connection;
63/// #          type Error = Error;
64/// #          fn connect(&self) -> Result<Connection> { panic!() }
65/// #          fn is_valid(&self, _: &mut Connection) -> Result<()> { panic!() }
66/// #          fn has_broken(&self, _: &mut Connection) -> bool { panic!() }
67/// #     }
68/// # }
69/// use std::time::Duration;
70/// use rocket::{Rocket, Build};
71/// use rocket_sync_db_pools::{r2d2, Error, Config, Poolable, PoolResult};
72///
73/// impl Poolable for foo::Connection {
74///     type Manager = foo::ConnectionManager;
75///     type Error = foo::Error;
76///
77///     fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
78///         let config = Config::from(db_name, rocket)?;
79///         let manager = foo::ConnectionManager::new(&config.url).map_err(Error::Custom)?;
80///         Ok(r2d2::Pool::builder()
81///             .max_size(config.pool_size)
82///             .connection_timeout(Duration::from_secs(config.timeout as u64))
83///             .build(manager)?)
84///     }
85/// }
86/// ```
87///
88/// In this example, `ConnectionManager::new()` method returns a `foo::Error` on
89/// failure. The [`Error`] enum consolidates this type, the `r2d2::Error` type
90/// that can result from `r2d2::Pool::builder()`, and the
91/// [`figment::Error`](rocket::figment::Error) type from
92/// `database::Config::from()`.
93///
94/// In the event that a connection manager isn't fallible (as is the case with
95/// Diesel's r2d2 connection manager, for instance), the associated error type
96/// for the `Poolable` implementation should be `std::convert::Infallible`.
97///
98/// For more concrete example, consult Rocket's existing implementations of
99/// [`Poolable`].
100pub trait Poolable: Send + Sized + 'static {
101    /// The associated connection manager for the given connection type.
102    type Manager: ManageConnection<Connection=Self>;
103
104    /// The associated error type in the event that constructing the connection
105    /// manager and/or the connection pool fails.
106    type Error: std::fmt::Debug;
107
108    /// Creates an `r2d2` connection pool for `Manager::Connection`, returning
109    /// the pool on success.
110    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self>;
111}
112
113/// A type alias for the return type of [`Poolable::pool()`].
114#[allow(type_alias_bounds)]
115pub type PoolResult<P: Poolable> = Result<r2d2::Pool<P::Manager>, Error<P::Error>>;
116
#[cfg(feature = "diesel_sqlite_pool")]
impl Poolable for diesel::SqliteConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::SqliteConnection>;
    type Error = std::convert::Infallible;

    /// Builds a Diesel SQLite pool for the database named `db_name`.
    ///
    /// The pool is given a connection customizer. Its `on_acquire` hook runs a
    /// batch of pragmas (`journal_mode = WAL`, `busy_timeout = 1000`,
    /// `foreign_keys = ON`) on each connection it is handed by r2d2.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::connection::SimpleConnection;
        use diesel::r2d2::{ConnectionManager, CustomizeConnection, Pool};
        use diesel::r2d2::Error as R2d2Error;
        use diesel::SqliteConnection;

        // The trailing backslashes join these lines into a single batch string.
        const PRAGMAS: &str = "\
            PRAGMA journal_mode = WAL;\
            PRAGMA busy_timeout = 1000;\
            PRAGMA foreign_keys = ON;\
        ";

        // Customizer that runs `PRAGMAS` on every connection r2d2 passes to
        // `on_acquire`.
        #[derive(Debug)]
        struct SqlitePragmas;

        impl CustomizeConnection<SqliteConnection, R2d2Error> for SqlitePragmas {
            fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), R2d2Error> {
                conn.batch_execute(PRAGMAS).map_err(R2d2Error::QueryError)
            }
        }

        let config = Config::from(db_name, rocket)?;
        let timeout = Duration::from_secs(config.timeout as u64);

        Ok(Pool::builder()
            .connection_customizer(Box::new(SqlitePragmas))
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(ConnectionManager::new(&config.url))?)
    }
}
152
#[cfg(feature = "diesel_postgres_pool")]
impl Poolable for diesel::PgConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::PgConnection>;
    type Error = std::convert::Infallible;

    /// Builds a Diesel PostgreSQL pool for the database named `db_name`.
    ///
    /// Constructing Diesel's connection manager cannot fail (hence
    /// `Error = Infallible`). Errors come only from reading the configuration
    /// and from building the `r2d2` pool.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let Config { url, pool_size, timeout, .. } = Config::from(db_name, rocket)?;

        Ok(r2d2::Pool::builder()
            .max_size(pool_size)
            .connection_timeout(Duration::from_secs(timeout as u64))
            .build(diesel::r2d2::ConnectionManager::new(&url))?)
    }
}
169
#[cfg(feature = "diesel_mysql_pool")]
impl Poolable for diesel::MysqlConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::MysqlConnection>;
    type Error = std::convert::Infallible;

    /// Builds a Diesel MySQL pool for the database named `db_name`, sized and
    /// timed out according to that database's configuration.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let config = Config::from(db_name, rocket)?;

        let connection_timeout = Duration::from_secs(config.timeout as u64);
        let builder = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(connection_timeout);

        let manager = diesel::r2d2::ConnectionManager::new(&config.url);
        Ok(builder.build(manager)?)
    }
}
186
187// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
    type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
    type Error = postgres::Error;

    /// Builds a `postgres` pool for the database named `db_name`.
    ///
    /// The configured URL is parsed into a `postgres` connection config, and a
    /// parse failure is reported as `Error::Custom`. Connections are made
    /// without TLS (`NoTls`).
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use postgres::tls::NoTls;
        use r2d2_postgres::PostgresConnectionManager;

        let config = Config::from(db_name, rocket)?;
        let timeout = Duration::from_secs(config.timeout as u64);
        let pg_config = config.url.parse().map_err(Error::Custom)?;

        Ok(r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(PostgresConnectionManager::new(pg_config, NoTls))?)
    }
}
205
#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
    type Manager = r2d2_sqlite::SqliteConnectionManager;
    type Error = std::convert::Infallible;

    /// Builds a `rusqlite` pool for the database named `db_name`.
    ///
    /// Besides the common configuration keys, the database's configuration may
    /// contain an `open_flags` list (e.g. `["read_only", "shared_cache"]`).
    /// The list defaults to empty.
    ///
    /// The listed flags are applied, in order, on top of
    /// `rusqlite::OpenFlags::default()`. A flag that is mutually exclusive with
    /// one already set clears the conflicting flag first:
    ///
    /// * `read_only` clears `read_write` and `create`;
    /// * `read_write` and `create` clear `read_only`;
    /// * `no_mutex` and `full_mutex` clear each other;
    /// * `shared_cache` and `private_cache` clear each other.
    ///
    /// As a result the requested mode actually takes effect, and the last of
    /// two conflicting entries wins. Without this clearing, `read_only` on top
    /// of the default read-write flags is a combination SQLite refuses to
    /// open, and `full_mutex` would be overridden by the default `no_mutex`.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use rocket::figment::providers::Serialized;
        use rusqlite::OpenFlags;

        // Accepted `open_flags` entries, spelled in snake_case in configuration.
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        #[serde(rename_all = "snake_case")]
        enum OpenFlag {
            ReadOnly,
            ReadWrite,
            Create,
            Uri,
            Memory,
            NoMutex,
            FullMutex,
            SharedCache,
            PrivateCache,
            Nofollow,
        }

        impl OpenFlag {
            // Returns `(set, clear)`. `set` is the SQLite flags this entry turns
            // on; `clear` is the mutually exclusive flags to turn off first.
            fn sqlite_flags(self) -> (OpenFlags, OpenFlags) {
                match self {
                    OpenFlag::ReadOnly => (
                        OpenFlags::SQLITE_OPEN_READ_ONLY,
                        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
                    ),
                    OpenFlag::ReadWrite => (
                        OpenFlags::SQLITE_OPEN_READ_WRITE,
                        OpenFlags::SQLITE_OPEN_READ_ONLY,
                    ),
                    // SQLite only accepts CREATE in combination with READ_WRITE.
                    OpenFlag::Create => (
                        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
                        OpenFlags::SQLITE_OPEN_READ_ONLY,
                    ),
                    OpenFlag::Uri => (OpenFlags::SQLITE_OPEN_URI, OpenFlags::empty()),
                    OpenFlag::Memory => (OpenFlags::SQLITE_OPEN_MEMORY, OpenFlags::empty()),
                    OpenFlag::NoMutex => (
                        OpenFlags::SQLITE_OPEN_NO_MUTEX,
                        OpenFlags::SQLITE_OPEN_FULL_MUTEX,
                    ),
                    OpenFlag::FullMutex => (
                        OpenFlags::SQLITE_OPEN_FULL_MUTEX,
                        OpenFlags::SQLITE_OPEN_NO_MUTEX,
                    ),
                    OpenFlag::SharedCache => (
                        OpenFlags::SQLITE_OPEN_SHARED_CACHE,
                        OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
                    ),
                    OpenFlag::PrivateCache => (
                        OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
                        OpenFlags::SQLITE_OPEN_SHARED_CACHE,
                    ),
                    OpenFlag::Nofollow => (OpenFlags::SQLITE_OPEN_NOFOLLOW, OpenFlags::empty()),
                }
            }
        }

        let figment = Config::figment(db_name, rocket);
        let config: Config = figment.extract()?;
        let open_flags: Vec<OpenFlag> = figment
            .join(Serialized::default("open_flags", <Vec<OpenFlag>>::new()))
            .extract_inner("open_flags")?;

        let mut flags = OpenFlags::default();
        for flag in open_flags {
            let (set, clear) = flag.sqlite_flags();
            // Clear conflicting flags before setting, so that e.g. `read_only`
            // does not end up combined with the default READ_WRITE | CREATE.
            flags.remove(clear);
            flags.insert(set);
        }

        let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url)
            .with_flags(flags);

        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}
264
#[cfg(feature = "memcache_pool")]
impl Poolable for memcache::Client {
    type Manager = r2d2_memcache::MemcacheConnectionManager;
    // Never produced at present. It is declared now so that reporting memcache
    // errors later does not require a breaking change.
    type Error = memcache::MemcacheError;

    /// Builds a `memcache` client pool for the database named `db_name`. The
    /// configured URL is the manager's connection target.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let Config { url, pool_size, timeout, .. } = Config::from(db_name, rocket)?;
        let manager = r2d2_memcache::MemcacheConnectionManager::new(url.as_str());

        let builder = r2d2::Pool::builder()
            .max_size(pool_size)
            .connection_timeout(Duration::from_secs(timeout as u64));

        Ok(builder.build(manager)?)
    }
}