rocket_db_pools/
pool.rs

1use rocket::figment::Figment;
2
3#[allow(unused_imports)]
4use {std::time::Duration, crate::{Error, Config}};
5
6/// Generic [`Database`](crate::Database) driver connection pool trait.
7///
8/// This trait provides a generic interface to various database pooling
9/// implementations in the Rust ecosystem. It can be implemented by anyone, but
10/// this crate provides implementations for common drivers.
11///
12/// **Implementations of this trait outside of this crate should be rare. You
13/// _do not_ need to implement this trait or understand its specifics to use
14/// this crate.**
15///
16/// ## Async Trait
17///
18/// [`Pool`] is an _async_ trait. Implementations of `Pool` must be decorated
19/// with an attribute of `#[async_trait]`:
20///
21/// ```rust
22/// # #[macro_use] extern crate rocket;
23/// use rocket::figment::Figment;
24/// use rocket_db_pools::Pool;
25///
26/// # struct MyPool;
27/// # type Connection = ();
28/// # type Error = std::convert::Infallible;
29/// #[rocket::async_trait]
30/// impl Pool for MyPool {
31///     type Connection = Connection;
32///
33///     type Error = Error;
34///
35///     async fn init(figment: &Figment) -> Result<Self, Self::Error> {
36///         todo!("initialize and return an instance of the pool");
37///     }
38///
39///     async fn get(&self) -> Result<Self::Connection, Self::Error> {
40///         todo!("fetch one connection from the pool");
41///     }
42///
43///     async fn close(&self) {
44///         todo!("gracefully shutdown connection pool");
45///     }
46/// }
47/// ```
48///
49/// ## Implementing
50///
51/// Implementations of `Pool` typically trace the following outline:
52///
53///   1. The `Error` associated type is set to [`Error`].
54///
55///   2. A [`Config`] is [extracted](Figment::extract()) from the `figment`
56///      passed to init.
57///
58///   3. The pool is initialized and returned in `init()`, wrapping
59///      initialization errors in [`Error::Init`].
60///
61///   4. A connection is retrieved in `get()`, wrapping errors in
62///      [`Error::Get`].
63///
64/// Concretely, this looks like:
65///
66/// ```rust
67/// use rocket::figment::Figment;
68/// use rocket_db_pools::{Pool, Config, Error};
69/// #
70/// # type InitError = std::convert::Infallible;
71/// # type GetError = std::convert::Infallible;
72/// # type Connection = ();
73/// #
74/// # struct MyPool(Config);
75/// # impl MyPool {
76/// #    fn new(c: Config) -> Result<Self, InitError> {
77/// #        Ok(Self(c))
78/// #    }
79/// #
80/// #    fn acquire(&self) -> Result<Connection, GetError> {
81/// #        Ok(())
82/// #    }
83/// #
84/// #   async fn shutdown(&self) { }
85/// # }
86///
87/// #[rocket::async_trait]
88/// impl Pool for MyPool {
89///     type Connection = Connection;
90///
91///     type Error = Error<InitError, GetError>;
92///
93///     async fn init(figment: &Figment) -> Result<Self, Self::Error> {
94///         // Extract the config from `figment`.
95///         let config: Config = figment.extract()?;
96///
97///         // Read config values, initialize `MyPool`. Map errors of type
98///         // `InitError` to `Error<InitError, _>` with `Error::Init`.
99///         let pool = MyPool::new(config).map_err(Error::Init)?;
100///
101///         // Return the fully initialized pool.
102///         Ok(pool)
103///     }
104///
105///     async fn get(&self) -> Result<Self::Connection, Self::Error> {
106///         // Get one connection from the pool, here via an `acquire()` method.
107///         // Map errors of type `GetError` to `Error<_, GetError>`.
108///         self.acquire().map_err(Error::Get)
109///     }
110///
111///     async fn close(&self) {
112///         self.shutdown().await;
113///     }
114/// }
115/// ```
116#[rocket::async_trait]
117pub trait Pool: Sized + Send + Sync + 'static {
118    /// The connection type managed by this pool, returned by [`Self::get()`].
119    type Connection;
120
121    /// The error type returned by [`Self::init()`] and [`Self::get()`].
122    type Error: std::error::Error;
123
124    /// Constructs a pool from a [Value](rocket::figment::value::Value).
125    ///
126    /// It is up to each implementor of `Pool` to define its accepted
127    /// configuration value(s) via the `Config` associated type.  Most
128    /// integrations provided in `rocket_db_pools` use [`Config`], which
129    /// accepts a (required) `url` and an (optional) `pool_size`.
130    ///
131    /// ## Errors
132    ///
133    /// This method returns an error if the configuration is not compatible, or
134    /// if creating a pool failed due to an unavailable database server,
135    /// insufficient resources, or another database-specific error.
136    async fn init(figment: &Figment) -> Result<Self, Self::Error>;
137
138    /// Asynchronously retrieves a connection from the factory or pool.
139    ///
140    /// ## Errors
141    ///
142    /// This method returns an error if a connection could not be retrieved,
143    /// such as a preconfigured timeout elapsing or when the database server is
144    /// unavailable.
145    async fn get(&self) -> Result<Self::Connection, Self::Error>;
146
147    /// Shutdown the connection pool, disallowing any new connections from being
148    /// retrieved and waking up any tasks with active connections.
149    ///
150    /// The returned future may either resolve when all connections are known to
151    /// have closed or at any point prior. Details are implementation specific.
152    async fn close(&self);
153}
154
#[cfg(feature = "deadpool")]
mod deadpool_postgres {
    use deadpool::{Runtime, managed::{Manager, Pool, PoolError, Object}};
    use super::{Duration, Error, Config, Figment};

    /// A `deadpool` [`Manager`] that knows how to build itself from a [`Config`].
    pub trait DeadManager: Manager + Sized + Send + Sync + 'static {
        /// Creates the manager from the configured database settings.
        fn new(config: &Config) -> Result<Self, Self::Error>;
    }

    #[cfg(feature = "deadpool_postgres")]
    impl DeadManager for deadpool_postgres::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // The URL is parsed into a postgres config; connections use `NoTls`.
            Ok(Self::new(config.url.parse()?, deadpool_postgres::tokio_postgres::NoTls))
        }
    }

    #[cfg(feature = "deadpool_redis")]
    impl DeadManager for deadpool_redis::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            Self::new(config.url.as_str())
        }
    }

    #[rocket::async_trait]
    impl<M, C> crate::Pool for Pool<M, C>
    where
        M: DeadManager,
        M::Type: Send,
        M::Error: std::error::Error,
        C: From<Object<M>> + Send + Sync + 'static,
    {
        type Error = Error<PoolError<M::Error>>;

        type Connection = C;

        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let manager = M::new(&config).map_err(|e| Error::Init(PoolError::Backend(e)))?;

            // The same timeout bounds both waiting for and creating a connection.
            let connect_timeout = Duration::from_secs(config.connect_timeout);
            let idle_timeout = config.idle_timeout.map(Duration::from_secs);

            let builder = Pool::builder(manager)
                .max_size(config.max_connections)
                .wait_timeout(Some(connect_timeout))
                .create_timeout(Some(connect_timeout))
                .recycle_timeout(idle_timeout)
                .runtime(Runtime::Tokio1);

            // A runtime is always set above; any build failure is reported as
            // `NoRuntimeSpecified` and the original build error is dropped.
            builder.build().map_err(|_| Error::Init(PoolError::NoRuntimeSpecified))
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            // Explicitly dispatch to deadpool's inherent `get`, not this trait's.
            <Pool<M, C>>::get(self).await.map_err(Error::Get)
        }

        async fn close(&self) {
            Pool::<M, C>::close(self)
        }
    }
}
209
// TODO: Remove when new release of diesel-async with deadpool 0.10 is out.
#[cfg(all(feature = "deadpool_09", any(feature = "diesel_postgres", feature = "diesel_mysql")))]
mod deadpool_old {
    use deadpool_09::{managed::{Manager, Pool, PoolError, Object, BuildError}, Runtime};
    use diesel_async::pooled_connection::AsyncDieselConnectionManager;

    use super::{Duration, Error, Config, Figment};

    /// A `deadpool` 0.9 [`Manager`] that knows how to build itself from a [`Config`].
    pub trait DeadManager: Manager + Sized + Send + Sync + 'static {
        /// Creates the manager from the configured database settings.
        fn new(config: &Config) -> Result<Self, Self::Error>;
    }

    #[cfg(feature = "diesel_postgres")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncPgConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // Never returns `Err`: only the URL string is handed to diesel-async.
            let url = config.url.as_str();
            Ok(Self::new(url))
        }
    }

    #[cfg(feature = "diesel_mysql")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncMysqlConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // Never returns `Err`: only the URL string is handed to diesel-async.
            let url = config.url.as_str();
            Ok(Self::new(url))
        }
    }

    #[rocket::async_trait]
    impl<M, C> crate::Pool for Pool<M, C>
    where
        M: DeadManager,
        M::Type: Send,
        M::Error: std::error::Error,
        C: From<Object<M>> + Send + Sync + 'static,
    {
        type Error = Error<BuildError<M::Error>, PoolError<M::Error>>;

        type Connection = C;

        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let manager = match M::new(&config) {
                Ok(manager) => manager,
                Err(e) => return Err(Error::Init(BuildError::Backend(e))),
            };

            // The same timeout bounds both waiting for and creating a connection.
            let timeout = Duration::from_secs(config.connect_timeout);

            Pool::builder(manager)
                .max_size(config.max_connections)
                .wait_timeout(Some(timeout))
                .create_timeout(Some(timeout))
                .recycle_timeout(config.idle_timeout.map(Duration::from_secs))
                .runtime(Runtime::Tokio1)
                .build()
                .map_err(Error::Init)
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            // Explicitly dispatch to deadpool's inherent `get`, not this trait's.
            <Pool<M, C>>::get(self).await.map_err(Error::Get)
        }

        async fn close(&self) {
            Pool::<M, C>::close(self)
        }
    }
}
267
#[cfg(feature = "sqlx")]
mod sqlx {
    use sqlx::ConnectOptions;
    use super::{Duration, Error, Config, Figment};
    use rocket::config::LogLevel;

    /// The connection-options type of the sqlx database driver `D`.
    type Options<D> = <<D as sqlx::Database>::Connection as sqlx::Connection>::Options;

    /// Applies database-specific configuration to `__options`.
    ///
    /// The parameters are `__`-prefixed so no unused-variable warnings fire
    /// when no database-specific feature is enabled.
    fn specialize(__options: &mut dyn std::any::Any, __config: &Config) {
        #[cfg(feature = "sqlx_sqlite")]
        if let Some(o) = __options.downcast_mut::<sqlx::sqlite::SqliteConnectOptions>() {
            // SQLite: reuse the connect timeout as the busy timeout and create
            // the database file if it does not already exist.
            *o = std::mem::take(o)
                .busy_timeout(Duration::from_secs(__config.connect_timeout))
                .create_if_missing(true);

            if let Some(ref exts) = __config.extensions {
                for ext in exts {
                    *o = std::mem::take(o).extension(ext.clone());
                }
            }
        }
    }

    #[rocket::async_trait]
    impl<D: sqlx::Database> crate::Pool for sqlx::Pool<D> {
        type Error = Error<sqlx::Error>;

        type Connection = sqlx::pool::PoolConnection<D>;

        /// Builds an sqlx pool from the [`Config`] extracted from `figment`.
        ///
        /// Statement logging is disabled unless Rocket's log level is something
        /// other than `Normal` or `Off`.
        ///
        /// # Errors
        ///
        /// Returns an error if the config cannot be extracted, the URL cannot
        /// be parsed into connection options, or connecting fails.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = config.url.parse::<Options<D>>().map_err(Error::Init)?;
            specialize(&mut opts, &config);

            opts = opts.disable_statement_logging();
            if let Ok(level) = figment.extract_inner::<LogLevel>(rocket::Config::LOG_LEVEL) {
                if !matches!(level, LogLevel::Normal | LogLevel::Off) {
                    opts = opts.log_statements(level.into())
                        .log_slow_statements(level.into(), Duration::default());
                }
            }

            // `max_connections` is a `usize`; saturate instead of letting an
            // `as` cast silently wrap oversized values (e.g. 2^32 -> 0).
            let max_connections = u32::try_from(config.max_connections).unwrap_or(u32::MAX);

            sqlx::pool::PoolOptions::new()
                .max_connections(max_connections)
                .acquire_timeout(Duration::from_secs(config.connect_timeout))
                .idle_timeout(config.idle_timeout.map(Duration::from_secs))
                .min_connections(config.min_connections.unwrap_or_default())
                .connect_with(opts)
                .await
                .map_err(Error::Init)
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            self.acquire().await.map_err(Error::Get)
        }

        async fn close(&self) {
            <sqlx::Pool<D>>::close(self).await;
        }
    }
}
330
#[cfg(feature = "mongodb")]
mod mongodb {
    use mongodb::{Client, options::ClientOptions};
    use super::{Duration, Error, Config, Figment};

    #[rocket::async_trait]
    impl crate::Pool for Client {
        type Error = Error<mongodb::error::Error, std::convert::Infallible>;

        type Connection = Client;

        /// Builds a MongoDB [`Client`] from the [`Config`] extracted from `figment`.
        ///
        /// `connect_timeout` is applied both as the connect timeout and as the
        /// server-selection timeout.
        ///
        /// # Errors
        ///
        /// Returns an error if the config cannot be extracted, the URL cannot be
        /// parsed into client options, or the client cannot be constructed.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = ClientOptions::parse(&config.url).await.map_err(Error::Init)?;

            // `max_connections` is a `usize`; saturate instead of letting an
            // `as` cast silently wrap oversized values (e.g. 2^32 -> 0).
            let max_pool_size = u32::try_from(config.max_connections).unwrap_or(u32::MAX);
            let connect_timeout = Duration::from_secs(config.connect_timeout);

            opts.min_pool_size = config.min_connections;
            opts.max_pool_size = Some(max_pool_size);
            opts.max_idle_time = config.idle_timeout.map(Duration::from_secs);
            opts.connect_timeout = Some(connect_timeout);
            opts.server_selection_timeout = Some(connect_timeout);
            Client::with_options(opts).map_err(Error::Init)
        }

        /// Returns a clone of the client; this never fails.
        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            Ok(self.clone())
        }

        async fn close(&self) {
            // nothing to do for mongodb
        }
    }
}