rocket_db_pools/
pool.rs

1use rocket::figment::Figment;
2
3#[allow(unused_imports)]
4use {std::time::Duration, crate::{Error, Config}};
5
6/// Generic [`Database`](crate::Database) driver connection pool trait.
7///
8/// This trait provides a generic interface to various database pooling
9/// implementations in the Rust ecosystem. It can be implemented by anyone, but
10/// this crate provides implementations for common drivers.
11///
12/// **Implementations of this trait outside of this crate should be rare. You
13/// _do not_ need to implement this trait or understand its specifics to use
14/// this crate.**
15///
16/// ## Async Trait
17///
18/// [`Pool`] is an _async_ trait. Implementations of `Pool` must be decorated
19/// with an attribute of `#[async_trait]`:
20///
21/// ```rust
22/// # #[macro_use] extern crate rocket;
23/// use rocket::figment::Figment;
24/// use rocket_db_pools::Pool;
25///
26/// # struct MyPool;
27/// # type Connection = ();
28/// # type Error = std::convert::Infallible;
29/// #[rocket::async_trait]
30/// impl Pool for MyPool {
31///     type Connection = Connection;
32///
33///     type Error = Error;
34///
35///     async fn init(figment: &Figment) -> Result<Self, Self::Error> {
36///         todo!("initialize and return an instance of the pool");
37///     }
38///
39///     async fn get(&self) -> Result<Self::Connection, Self::Error> {
40///         todo!("fetch one connection from the pool");
41///     }
42///
43///     async fn close(&self) {
44///         todo!("gracefully shutdown connection pool");
45///     }
46/// }
47/// ```
48///
49/// ## Implementing
50///
51/// Implementations of `Pool` typically trace the following outline:
52///
53///   1. The `Error` associated type is set to [`Error`].
54///
55///   2. A [`Config`] is [extracted](Figment::extract()) from the `figment`
56///      passed to init.
57///
58///   3. The pool is initialized and returned in `init()`, wrapping
59///      initialization errors in [`Error::Init`].
60///
61///   4. A connection is retrieved in `get()`, wrapping errors in
62///      [`Error::Get`].
63///
64/// Concretely, this looks like:
65///
66/// ```rust
67/// use rocket::figment::Figment;
68/// use rocket_db_pools::{Pool, Config, Error};
69/// #
70/// # type InitError = std::convert::Infallible;
71/// # type GetError = std::convert::Infallible;
72/// # type Connection = ();
73/// #
74/// # struct MyPool(Config);
75/// # impl MyPool {
76/// #    fn new(c: Config) -> Result<Self, InitError> {
77/// #        Ok(Self(c))
78/// #    }
79/// #
80/// #    fn acquire(&self) -> Result<Connection, GetError> {
81/// #        Ok(())
82/// #    }
83/// #
84/// #   async fn shutdown(&self) { }
85/// # }
86///
87/// #[rocket::async_trait]
88/// impl Pool for MyPool {
89///     type Connection = Connection;
90///
91///     type Error = Error<InitError, GetError>;
92///
93///     async fn init(figment: &Figment) -> Result<Self, Self::Error> {
94///         // Extract the config from `figment`.
95///         let config: Config = figment.extract()?;
96///
97///         // Read config values, initialize `MyPool`. Map errors of type
98///         // `InitError` to `Error<InitError, _>` with `Error::Init`.
99///         let pool = MyPool::new(config).map_err(Error::Init)?;
100///
101///         // Return the fully initialized pool.
102///         Ok(pool)
103///     }
104///
105///     async fn get(&self) -> Result<Self::Connection, Self::Error> {
106///         // Get one connection from the pool, here via an `acquire()` method.
107///         // Map errors of type `GetError` to `Error<_, GetError>`.
108///         self.acquire().map_err(Error::Get)
109///     }
110///
111///     async fn close(&self) {
112///         self.shutdown().await;
113///     }
114/// }
115/// ```
116#[rocket::async_trait]
117pub trait Pool: Sized + Send + Sync + 'static {
118    /// The connection type managed by this pool, returned by [`Self::get()`].
119    type Connection;
120
121    /// The error type returned by [`Self::init()`] and [`Self::get()`].
122    type Error: std::error::Error;
123
124    /// Constructs a pool from a [Value](rocket::figment::value::Value).
125    ///
126    /// It is up to each implementor of `Pool` to define its accepted
127    /// configuration value(s) via the `Config` associated type.  Most
128    /// integrations provided in `rocket_db_pools` use [`Config`], which
129    /// accepts a (required) `url` and an (optional) `pool_size`.
130    ///
131    /// ## Errors
132    ///
133    /// This method returns an error if the configuration is not compatible, or
134    /// if creating a pool failed due to an unavailable database server,
135    /// insufficient resources, or another database-specific error.
136    async fn init(figment: &Figment) -> Result<Self, Self::Error>;
137
138    /// Asynchronously retrieves a connection from the factory or pool.
139    ///
140    /// ## Errors
141    ///
142    /// This method returns an error if a connection could not be retrieved,
143    /// such as a preconfigured timeout elapsing or when the database server is
144    /// unavailable.
145    async fn get(&self) -> Result<Self::Connection, Self::Error>;
146
147    /// Shutdown the connection pool, disallowing any new connections from being
148    /// retrieved and waking up any tasks with active connections.
149    ///
150    /// The returned future may either resolve when all connections are known to
151    /// have closed or at any point prior. Details are implementation specific.
152    async fn close(&self);
153}
154
#[cfg(feature = "deadpool")]
mod deadpool_postgres {
    use deadpool::{Runtime, managed::{Manager, Pool, PoolError, Object}};
    use super::{Duration, Error, Config, Figment};

    #[cfg(feature = "diesel")]
    use diesel_async::pooled_connection::AsyncDieselConnectionManager;

    /// A `deadpool` [`Manager`] that can be built from a crate [`Config`].
    ///
    /// The blanket [`crate::Pool`] impl for `deadpool`'s `Pool` below uses this
    /// to create the manager handed to the pool builder.
    pub trait DeadManager: Manager + Sized + Send + Sync + 'static {
        /// Constructs a manager from `config`.
        fn new(config: &Config) -> Result<Self, Self::Error>;
    }

    #[cfg(feature = "deadpool_postgres")]
    impl DeadManager for deadpool_postgres::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // `config.url` is parsed into the driver's own config type; the
            // inner `Self::new` is the inherent constructor (inherent methods
            // win over trait methods), and connections are made with `NoTls`.
            let pg_config = config.url.parse()?;
            Ok(Self::new(pg_config, deadpool_postgres::tokio_postgres::NoTls))
        }
    }

    #[cfg(feature = "deadpool_redis")]
    impl DeadManager for deadpool_redis::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // Delegates to the inherent, fallible redis manager constructor.
            Self::new(config.url.as_str())
        }
    }

    #[cfg(feature = "diesel_postgres")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncPgConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            let url = config.url.as_str();
            Ok(Self::new(url))
        }
    }

    #[cfg(feature = "diesel_mysql")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncMysqlConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            let url = config.url.as_str();
            Ok(Self::new(url))
        }
    }

    /// Any `deadpool` pool whose manager is a [`DeadManager`] is a
    /// [`crate::Pool`]; connections are the pool's wrapper type `C`.
    #[rocket::async_trait]
    impl<M: DeadManager, C: From<Object<M>>> crate::Pool for Pool<M, C>
        where M::Type: Send, C: Send + Sync + 'static, M::Error: std::error::Error
    {
        type Connection = C;

        type Error = Error<PoolError<M::Error>>;

        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let manager = M::new(&config).map_err(|err| Error::Init(err.into()))?;

            // `connect_timeout` bounds both waiting for a free slot and
            // creating a new connection. `config.min_connections` is not read
            // here (NOTE(review): presumably deadpool has no minimum size).
            let connect_timeout = Some(Duration::from_secs(config.connect_timeout));
            let recycle_timeout = config.idle_timeout.map(Duration::from_secs);

            // Any error from `build()` is reported as `NoRuntimeSpecified`;
            // the original build error value is discarded.
            Pool::builder(manager)
                .max_size(config.max_connections)
                .wait_timeout(connect_timeout)
                .create_timeout(connect_timeout)
                .recycle_timeout(recycle_timeout)
                .runtime(Runtime::Tokio1)
                .build()
                .map_err(|_| Error::Init(PoolError::NoRuntimeSpecified))
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            // Explicitly call deadpool's inherent `Pool::get`.
            <Pool<M, C>>::get(self).await.map_err(Error::Get)
        }

        async fn close(&self) {
            // deadpool's inherent `close` is synchronous; nothing to await.
            <Pool<M, C>>::close(self)
        }
    }
}
226
#[cfg(feature = "sqlx")]
mod sqlx {
    use sqlx::ConnectOptions;
    use super::{Duration, Error, Config, Figment};
    use rocket::tracing::level_filters::LevelFilter;

    /// The connection options type of the `sqlx` database driver `D`.
    type Options<D> = <<D as sqlx::Database>::Connection as sqlx::Connection>::Options;

    /// Applies driver-specific settings from `__config` to `__options`.
    ///
    /// The options are matched by downcasting; options of a driver without a
    /// case here are left untouched. Both parameters are underscore-prefixed
    /// because they are only used when a driver-specific feature is enabled.
    fn specialize(__options: &mut dyn std::any::Any, __config: &Config) {
        // SQLite: reuse `connect_timeout` as the busy timeout, enable
        // `create_if_missing`, and register every configured extension.
        #[cfg(feature = "sqlx_sqlite")]
        if let Some(o) = __options.downcast_mut::<sqlx::sqlite::SqliteConnectOptions>() {
            *o = std::mem::take(o)
                .busy_timeout(Duration::from_secs(__config.connect_timeout))
                .create_if_missing(true);

            if let Some(ref exts) = __config.extensions {
                for ext in exts {
                    *o = std::mem::take(o).extension(ext.clone());
                }
            }
        }
    }

    /// Converts a `tracing` level filter into the equivalent `log` filter.
    fn to_log_filter(level: LevelFilter) -> log::LevelFilter {
        match level {
            LevelFilter::OFF => log::LevelFilter::Off,
            LevelFilter::ERROR => log::LevelFilter::Error,
            LevelFilter::WARN => log::LevelFilter::Warn,
            LevelFilter::INFO => log::LevelFilter::Info,
            LevelFilter::DEBUG => log::LevelFilter::Debug,
            LevelFilter::TRACE => log::LevelFilter::Trace,
        }
    }

    #[rocket::async_trait]
    impl<D: sqlx::Database> crate::Pool for sqlx::Pool<D> {
        type Error = Error<sqlx::Error>;

        type Connection = sqlx::pool::PoolConnection<D>;

        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = config.url.parse::<Options<D>>().map_err(Error::Init)?;
            specialize(&mut opts, &config);

            // Statement logging stays off unless Rocket's configured log level
            // parses as a level filter; then statements (and, with a zero
            // threshold, "slow" statements) are logged at that level.
            opts = opts.disable_statement_logging();
            if let Ok(value) = figment.find_value(rocket::Config::LOG_LEVEL) {
                if let Some(level) = value.as_str().and_then(|v| v.parse::<LevelFilter>().ok()) {
                    let log_level = to_log_filter(level);
                    opts = opts.log_statements(log_level)
                        .log_slow_statements(log_level, Duration::default());
                }
            }

            // Saturate instead of `as`-casting, which would silently wrap
            // values above `u32::MAX` (e.g. 2^32 would become 0).
            let max_connections = u32::try_from(config.max_connections).unwrap_or(u32::MAX);

            // The pool connects lazily, so `init` does not open a connection.
            Ok(sqlx::pool::PoolOptions::new()
                .max_connections(max_connections)
                .acquire_timeout(Duration::from_secs(config.connect_timeout))
                .idle_timeout(config.idle_timeout.map(Duration::from_secs))
                .min_connections(config.min_connections.unwrap_or_default())
                .connect_lazy_with(opts))
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            self.acquire().await.map_err(Error::Get)
        }

        async fn close(&self) {
            <sqlx::Pool<D>>::close(self).await;
        }
    }
}
296
#[cfg(feature = "mongodb")]
mod mongodb {
    use mongodb::{Client, options::ClientOptions};
    use super::{Duration, Error, Config, Figment};

    /// A mongodb [`Client`] acts as its own pool: `get` hands out clones of it.
    #[rocket::async_trait]
    impl crate::Pool for Client {
        type Error = Error<mongodb::error::Error, std::convert::Infallible>;

        type Connection = Client;

        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = ClientOptions::parse(&config.url).await.map_err(Error::Init)?;

            // `connect_timeout` bounds both connecting and server selection.
            let connect_timeout = Some(Duration::from_secs(config.connect_timeout));

            opts.min_pool_size = config.min_connections;
            // Saturate instead of `as`-casting, which would silently wrap
            // values above `u32::MAX` (e.g. 2^32 would become 0).
            opts.max_pool_size = Some(u32::try_from(config.max_connections).unwrap_or(u32::MAX));
            opts.max_idle_time = config.idle_timeout.map(Duration::from_secs);
            opts.connect_timeout = connect_timeout;
            opts.server_selection_timeout = connect_timeout;
            Client::with_options(opts).map_err(Error::Init)
        }

        /// Returns a clone of the client; this never fails (`Infallible`).
        /// NOTE(review): assumes `Client` clones share the same underlying
        /// connection pool — confirm against the mongodb driver docs.
        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            Ok(self.clone())
        }

        async fn close(&self) {
            // Nothing to do for mongodb: the client is not explicitly shut down.
        }
    }
}
327}