rocket_sync_db_pools/
poolable.rs

1#[allow(unused)]
2use std::time::Duration;
3
4use r2d2::ManageConnection;
5use rocket::{Rocket, Build};
6
7#[allow(unused_imports)]
8use crate::{Config, Error};
9
10/// Trait implemented by `r2d2`-based database adapters.
11///
12/// # Provided Implementations
13///
14/// Implementations of `Poolable` are provided for the following types:
15///
16///   * [`diesel::MysqlConnection`](diesel::MysqlConnection)
17///   * [`diesel::PgConnection`](diesel::PgConnection)
18///   * [`diesel::SqliteConnection`](diesel::SqliteConnection)
19///   * [`postgres::Client`](postgres::Client)
20///   * [`rusqlite::Connection`](rusqlite::Connection)
21///   * [`memcache::Client`](memcache::Client)
22///
23/// # Implementation Guide
24///
25/// As an r2d2-compatible database (or other resource) adapter provider,
26/// implementing `Poolable` in your own library will enable Rocket users to
27/// consume your adapter with its built-in connection pooling support.
28///
29/// ## Example
30///
31/// Consider a library `foo` with the following types:
32///
33///   * `foo::ConnectionManager`, which implements [`r2d2::ManageConnection`]
34///   * `foo::Connection`, the `Connection` associated type of
35///     `foo::ConnectionManager`
36///   * `foo::Error`, errors resulting from manager instantiation
37///
38/// In order for Rocket to generate the required code to automatically provision
39/// a r2d2 connection pool into application state, the `Poolable` trait needs to
40/// be implemented for the connection type. The following example implements
41/// `Poolable` for `foo::Connection`:
42///
43/// ```rust
44/// # mod foo {
45/// #     use std::fmt;
46/// #     use rocket_sync_db_pools::r2d2;
47/// #     #[derive(Debug)] pub struct Error;
48/// #     impl std::error::Error for Error {  }
49/// #     impl fmt::Display for Error {
50/// #         fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Ok(()) }
51/// #     }
52/// #
53/// #     pub struct Connection;
54/// #     pub struct ConnectionManager;
55/// #
56/// #     type Result<T> = std::result::Result<T, Error>;
57/// #
58/// #     impl ConnectionManager {
59/// #         pub fn new(url: &str) -> Result<Self> { Err(Error) }
60/// #     }
61/// #
62/// #     impl self::r2d2::ManageConnection for ConnectionManager {
63/// #          type Connection = Connection;
64/// #          type Error = Error;
65/// #          fn connect(&self) -> Result<Connection> { panic!() }
66/// #          fn is_valid(&self, _: &mut Connection) -> Result<()> { panic!() }
67/// #          fn has_broken(&self, _: &mut Connection) -> bool { panic!() }
68/// #     }
69/// # }
70/// use std::time::Duration;
71/// use rocket::{Rocket, Build};
72/// use rocket_sync_db_pools::{r2d2, Error, Config, Poolable, PoolResult};
73///
74/// impl Poolable for foo::Connection {
75///     type Manager = foo::ConnectionManager;
76///     type Error = foo::Error;
77///
78///     fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
79///         let config = Config::from(db_name, rocket)?;
80///         let manager = foo::ConnectionManager::new(&config.url).map_err(Error::Custom)?;
81///         Ok(r2d2::Pool::builder()
82///             .max_size(config.pool_size)
83///             .connection_timeout(Duration::from_secs(config.timeout as u64))
84///             .build(manager)?)
85///     }
86/// }
87/// ```
88///
89/// In this example, `ConnectionManager::new()` method returns a `foo::Error` on
90/// failure. The [`Error`] enum consolidates this type, the `r2d2::Error` type
91/// that can result from `r2d2::Pool::builder()`, and the
92/// [`figment::Error`](rocket::figment::Error) type from
93/// `database::Config::from()`.
94///
95/// In the event that a connection manager isn't fallible (as is the case with
96/// Diesel's r2d2 connection manager, for instance), the associated error type
97/// for the `Poolable` implementation should be `std::convert::Infallible`.
98///
99/// For more concrete example, consult Rocket's existing implementations of
100/// [`Poolable`].
101pub trait Poolable: Send + Sized + 'static {
102    /// The associated connection manager for the given connection type.
103    type Manager: ManageConnection<Connection=Self>;
104
105    /// The associated error type in the event that constructing the connection
106    /// manager and/or the connection pool fails.
107    type Error: std::fmt::Debug;
108
109    /// Creates an `r2d2` connection pool for `Manager::Connection`, returning
110    /// the pool on success.
111    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self>;
112}
113
114/// A type alias for the return type of [`Poolable::pool()`].
115#[allow(type_alias_bounds)]
116pub type PoolResult<P: Poolable> = Result<r2d2::Pool<P::Manager>, Error<P::Error>>;
117
#[cfg(feature = "diesel_sqlite_pool")]
impl Poolable for diesel::SqliteConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::SqliteConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of Diesel SQLite connections for `db_name`.
    ///
    /// Every connection the pool creates is immediately configured with WAL
    /// journaling, a busy timeout of 5000, and foreign-key enforcement. A
    /// connection on which those pragmas fail is reported to `r2d2` as a
    /// query error.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::connection::SimpleConnection;
        use diesel::r2d2::{ConnectionManager, CustomizeConnection, Pool};
        use diesel::SqliteConnection;

        // Pragmas executed as a single batch on each new connection.
        const PRAGMAS: &str = "\
            PRAGMA journal_mode = WAL;\
            PRAGMA busy_timeout = 5000;\
            PRAGMA foreign_keys = ON;\
        ";

        // `r2d2` customizer that applies `PRAGMAS` when a connection is acquired.
        #[derive(Debug)]
        struct SqlitePragmas;

        impl CustomizeConnection<SqliteConnection, diesel::r2d2::Error> for SqlitePragmas {
            fn on_acquire(
                &self,
                conn: &mut SqliteConnection,
            ) -> Result<(), diesel::r2d2::Error> {
                conn.batch_execute(PRAGMAS)
                    .map_err(diesel::r2d2::Error::QueryError)
            }
        }

        let config = Config::from(db_name, rocket)?;
        let timeout = Duration::from_secs(config.timeout as u64);
        let manager = ConnectionManager::new(&config.url);

        Ok(Pool::builder()
            .connection_customizer(Box::new(SqlitePragmas))
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}
153
#[cfg(feature = "diesel_postgres_pool")]
impl Poolable for diesel::PgConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::PgConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of Diesel PostgreSQL connections for `db_name`.
    ///
    /// Diesel's connection manager is constructed infallibly, so the only
    /// possible failures come from reading the configuration or building the
    /// pool.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let Config { url, pool_size, timeout, .. } = Config::from(db_name, rocket)?;
        let manager = diesel::r2d2::ConnectionManager::new(url);

        r2d2::Pool::builder()
            .max_size(pool_size)
            .connection_timeout(Duration::from_secs(timeout as u64))
            .build(manager)
            .map_err(Into::into)
    }
}
170
#[cfg(feature = "diesel_mysql_pool")]
impl Poolable for diesel::MysqlConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::MysqlConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of Diesel MySQL connections for `db_name`.
    ///
    /// The manager is created from the configured URL without validation;
    /// connection problems surface when `r2d2` builds the pool.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let db_config = Config::from(db_name, rocket)?;
        let builder = r2d2::Pool::builder()
            .max_size(db_config.pool_size)
            .connection_timeout(Duration::from_secs(db_config.timeout as u64));

        let manager = diesel::r2d2::ConnectionManager::new(&db_config.url);
        Ok(builder.build(manager)?)
    }
}
187
// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
    type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
    type Error = postgres::Error;

    /// Builds a pool of synchronous `postgres` clients for `db_name`.
    ///
    /// The configured URL is parsed into a `postgres` connection config; a
    /// parse failure is returned as `Error::Custom`. Connections are made
    /// without TLS.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let db_config = Config::from(db_name, rocket)?;
        let pg_config = db_config.url.parse().map_err(Error::Custom)?;
        let manager =
            r2d2_postgres::PostgresConnectionManager::new(pg_config, postgres::tls::NoTls);

        let builder = r2d2::Pool::builder()
            .max_size(db_config.pool_size)
            .connection_timeout(Duration::from_secs(db_config.timeout as u64));

        Ok(builder.build(manager)?)
    }
}
206
#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
    type Manager = r2d2_sqlite::SqliteConnectionManager;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of `rusqlite` connections for database `db_name`.
    ///
    /// Besides the common pool settings, an optional `open_flags` list (for
    /// example `open_flags = ["read_only"]`) is read from the database's
    /// configuration. The listed flags are added to
    /// `rusqlite::OpenFlags::default()`, with two adjustments so the request
    /// is actually honored:
    ///
    ///   * `read_only` removes `SQLITE_OPEN_READ_WRITE` and
    ///     `SQLITE_OPEN_CREATE` from the defaults. SQLite only accepts
    ///     `READ_ONLY`, `READ_WRITE`, or `READ_WRITE | CREATE` as the access
    ///     mode, so leaving the defaults in place made every open fail.
    ///   * `full_mutex` removes `SQLITE_OPEN_NO_MUTEX` from the defaults,
    ///     which SQLite would otherwise give precedence to.
    ///
    /// Removing a bit the defaults do not contain is a no-op. Flags the user
    /// lists explicitly are always kept.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use rocket::figment::providers::Serialized;
        use rusqlite::OpenFlags;

        // Configuration spelling of each supported `rusqlite::OpenFlags` bit.
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        #[serde(rename_all = "snake_case")]
        enum OpenFlag {
            ReadOnly,
            ReadWrite,
            Create,
            Uri,
            Memory,
            NoMutex,
            FullMutex,
            SharedCache,
            PrivateCache,
            Nofollow,
        }

        let figment = Config::figment(db_name, rocket);
        let config: Config = figment.extract()?;
        // `open_flags` is optional: fall back to an empty list when absent.
        let open_flags: Vec<OpenFlag> = figment
            .join(Serialized::default("open_flags", <Vec<OpenFlag>>::new()))
            .extract_inner("open_flags")?;

        let mut requested = OpenFlags::empty();
        for flag in open_flags {
            requested.insert(match flag {
                OpenFlag::ReadOnly => OpenFlags::SQLITE_OPEN_READ_ONLY,
                OpenFlag::ReadWrite => OpenFlags::SQLITE_OPEN_READ_WRITE,
                OpenFlag::Create => OpenFlags::SQLITE_OPEN_CREATE,
                OpenFlag::Uri => OpenFlags::SQLITE_OPEN_URI,
                OpenFlag::Memory => OpenFlags::SQLITE_OPEN_MEMORY,
                OpenFlag::NoMutex => OpenFlags::SQLITE_OPEN_NO_MUTEX,
                OpenFlag::FullMutex => OpenFlags::SQLITE_OPEN_FULL_MUTEX,
                OpenFlag::SharedCache => OpenFlags::SQLITE_OPEN_SHARED_CACHE,
                OpenFlag::PrivateCache => OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
                OpenFlag::Nofollow => OpenFlags::SQLITE_OPEN_NOFOLLOW,
            });
        }

        let mut flags = OpenFlags::default();
        // A read-only open must not also carry the default read-write/create
        // bits: that combination is rejected with SQLITE_MISUSE.
        if requested.contains(OpenFlags::SQLITE_OPEN_READ_ONLY) {
            flags.remove(OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE);
        }
        // NO_MUTEX overrides FULL_MUTEX, so drop the default before honoring
        // an explicit `full_mutex` request.
        if requested.contains(OpenFlags::SQLITE_OPEN_FULL_MUTEX) {
            flags.remove(OpenFlags::SQLITE_OPEN_NO_MUTEX);
        }
        flags.insert(requested);

        let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url)
            .with_flags(flags);

        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}
265
#[cfg(feature = "memcache_pool")]
mod memcache_pool {
    use memcache::{Client, Connectable, MemcacheError};

    use super::*;

    /// `r2d2` connection manager for `memcache` clients.
    ///
    /// It records the server URLs it was created from and opens a new
    /// [`Client`] against those URLs for every pooled connection.
    #[derive(Debug)]
    pub struct ConnectionManager {
        urls: Vec<String>,
    }

    impl ConnectionManager {
        /// Captures the URL list described by `target` for later connects.
        pub fn new<C: Connectable>(target: C) -> Self {
            let urls = target.get_urls();
            ConnectionManager { urls }
        }
    }

    impl r2d2::ManageConnection for ConnectionManager {
        type Connection = Client;
        type Error = MemcacheError;

        fn connect(&self) -> Result<Client, MemcacheError> {
            Client::connect(self.urls.to_vec())
        }

        /// Health check: a connection is valid if it can answer a `version`
        /// request; the reported versions themselves are discarded.
        fn is_valid(&self, client: &mut Client) -> Result<(), MemcacheError> {
            client.version()?;
            Ok(())
        }

        /// Never flags a connection as broken; liveness is left to `is_valid`.
        fn has_broken(&self, _client: &mut Client) -> bool {
            false
        }
    }

    impl Poolable for Client {
        type Manager = ConnectionManager;
        type Error = MemcacheError;

        fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
            let db_config = Config::from(db_name, rocket)?;
            let builder = r2d2::Pool::builder()
                .max_size(db_config.pool_size)
                .connection_timeout(Duration::from_secs(db_config.timeout as u64));

            Ok(builder.build(ConnectionManager::new(db_config.url.as_str()))?)
        }
    }
}