rocket_db_pools/pool.rs
1use rocket::figment::Figment;
2
3#[allow(unused_imports)]
4use {std::time::Duration, crate::{Error, Config}};
5
6/// Generic [`Database`](crate::Database) driver connection pool trait.
7///
8/// This trait provides a generic interface to various database pooling
9/// implementations in the Rust ecosystem. It can be implemented by anyone, but
10/// this crate provides implementations for common drivers.
11///
12/// **Implementations of this trait outside of this crate should be rare. You
13/// _do not_ need to implement this trait or understand its specifics to use
14/// this crate.**
15///
16/// ## Async Trait
17///
18/// [`Pool`] is an _async_ trait. Implementations of `Pool` must be decorated
19/// with an attribute of `#[async_trait]`:
20///
21/// ```rust
22/// # #[macro_use] extern crate rocket;
23/// use rocket::figment::Figment;
24/// use rocket_db_pools::Pool;
25///
26/// # struct MyPool;
27/// # type Connection = ();
28/// # type Error = std::convert::Infallible;
29/// #[rocket::async_trait]
30/// impl Pool for MyPool {
31/// type Connection = Connection;
32///
33/// type Error = Error;
34///
35/// async fn init(figment: &Figment) -> Result<Self, Self::Error> {
36/// todo!("initialize and return an instance of the pool");
37/// }
38///
39/// async fn get(&self) -> Result<Self::Connection, Self::Error> {
40/// todo!("fetch one connection from the pool");
41/// }
42///
43/// async fn close(&self) {
44/// todo!("gracefully shutdown connection pool");
45/// }
46/// }
47/// ```
48///
49/// ## Implementing
50///
51/// Implementations of `Pool` typically trace the following outline:
52///
53/// 1. The `Error` associated type is set to [`Error`].
54///
55/// 2. A [`Config`] is [extracted](Figment::extract()) from the `figment`
56/// passed to init.
57///
58/// 3. The pool is initialized and returned in `init()`, wrapping
59/// initialization errors in [`Error::Init`].
60///
61/// 4. A connection is retrieved in `get()`, wrapping errors in
62/// [`Error::Get`].
63///
64/// Concretely, this looks like:
65///
66/// ```rust
67/// use rocket::figment::Figment;
68/// use rocket_db_pools::{Pool, Config, Error};
69/// #
70/// # type InitError = std::convert::Infallible;
71/// # type GetError = std::convert::Infallible;
72/// # type Connection = ();
73/// #
74/// # struct MyPool(Config);
75/// # impl MyPool {
76/// # fn new(c: Config) -> Result<Self, InitError> {
77/// # Ok(Self(c))
78/// # }
79/// #
80/// # fn acquire(&self) -> Result<Connection, GetError> {
81/// # Ok(())
82/// # }
83/// #
84/// # async fn shutdown(&self) { }
85/// # }
86///
87/// #[rocket::async_trait]
88/// impl Pool for MyPool {
89/// type Connection = Connection;
90///
91/// type Error = Error<InitError, GetError>;
92///
93/// async fn init(figment: &Figment) -> Result<Self, Self::Error> {
94/// // Extract the config from `figment`.
95/// let config: Config = figment.extract()?;
96///
97/// // Read config values, initialize `MyPool`. Map errors of type
98/// // `InitError` to `Error<InitError, _>` with `Error::Init`.
99/// let pool = MyPool::new(config).map_err(Error::Init)?;
100///
101/// // Return the fully initialized pool.
102/// Ok(pool)
103/// }
104///
105/// async fn get(&self) -> Result<Self::Connection, Self::Error> {
106/// // Get one connection from the pool, here via an `acquire()` method.
107/// // Map errors of type `GetError` to `Error<_, GetError>`.
108/// self.acquire().map_err(Error::Get)
109/// }
110///
111/// async fn close(&self) {
112/// self.shutdown().await;
113/// }
114/// }
115/// ```
116#[rocket::async_trait]
117pub trait Pool: Sized + Send + Sync + 'static {
118 /// The connection type managed by this pool, returned by [`Self::get()`].
119 type Connection;
120
121 /// The error type returned by [`Self::init()`] and [`Self::get()`].
122 type Error: std::error::Error;
123
124 /// Constructs a pool from a [Value](rocket::figment::value::Value).
125 ///
126 /// It is up to each implementor of `Pool` to define its accepted
127 /// configuration value(s) via the `Config` associated type. Most
128 /// integrations provided in `rocket_db_pools` use [`Config`], which
129 /// accepts a (required) `url` and an (optional) `pool_size`.
130 ///
131 /// ## Errors
132 ///
133 /// This method returns an error if the configuration is not compatible, or
134 /// if creating a pool failed due to an unavailable database server,
135 /// insufficient resources, or another database-specific error.
136 async fn init(figment: &Figment) -> Result<Self, Self::Error>;
137
138 /// Asynchronously retrieves a connection from the factory or pool.
139 ///
140 /// ## Errors
141 ///
142 /// This method returns an error if a connection could not be retrieved,
143 /// such as a preconfigured timeout elapsing or when the database server is
144 /// unavailable.
145 async fn get(&self) -> Result<Self::Connection, Self::Error>;
146
147 /// Shutdown the connection pool, disallowing any new connections from being
148 /// retrieved and waking up any tasks with active connections.
149 ///
150 /// The returned future may either resolve when all connections are known to
151 /// have closed or at any point prior. Details are implementation specific.
152 async fn close(&self);
153}
154
#[cfg(feature = "deadpool")]
mod deadpool_postgres {
    use deadpool::{Runtime, managed::{Manager, Pool, PoolError, Object}};
    use super::{Duration, Error, Config, Figment};

    #[cfg(feature = "diesel")]
    use diesel_async::pooled_connection::AsyncDieselConnectionManager;

    /// A `deadpool` [`Manager`] that can be built from a [`Config`].
    ///
    /// The blanket [`crate::Pool`] implementation for `deadpool`'s `Pool` in
    /// this module is written in terms of this trait, so every supported
    /// backend only has to say how to construct its manager.
    pub trait DeadManager: Manager + Sized + Send + Sync + 'static {
        /// Creates the manager from the database settings in `config`.
        fn new(config: &Config) -> Result<Self, Self::Error>;
    }

    #[cfg(feature = "deadpool_postgres")]
    impl DeadManager for deadpool_postgres::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // Parse the URL into a `tokio_postgres` configuration and connect
            // with `NoTls`.
            let pg_config: deadpool_postgres::tokio_postgres::Config = config.url.parse()?;
            let manager = Self::new(pg_config, deadpool_postgres::tokio_postgres::NoTls);
            Ok(manager)
        }
    }

    #[cfg(feature = "deadpool_redis")]
    impl DeadManager for deadpool_redis::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // The inherent constructor is itself fallible with the manager's
            // error type, so its result is passed straight through.
            let url = config.url.as_str();
            Self::new(url)
        }
    }

    #[cfg(feature = "diesel_postgres")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncPgConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            let url = config.url.as_str();
            Ok(Self::new(url))
        }
    }

    #[cfg(feature = "diesel_mysql")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncMysqlConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            let url = config.url.as_str();
            Ok(Self::new(url))
        }
    }

    #[rocket::async_trait]
    impl<M, C> crate::Pool for Pool<M, C>
        where M: DeadManager,
              M::Type: Send,
              M::Error: std::error::Error,
              C: From<Object<M>> + Send + Sync + 'static,
    {
        type Connection = C;

        type Error = Error<PoolError<M::Error>>;

        /// Builds a `deadpool` pool on the Tokio 1 runtime from the [`Config`]
        /// extracted from `figment`.
        ///
        /// `connect_timeout` is used for both the wait and create timeouts;
        /// `idle_timeout`, when present, becomes the recycle timeout.
        /// `config.min_connections` is not passed to the builder.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config: Config = figment.extract()?;

            let manager = match M::new(&config) {
                Ok(manager) => manager,
                Err(e) => return Err(Error::Init(e.into())),
            };

            let connect_timeout = Some(Duration::from_secs(config.connect_timeout));
            let idle_timeout = config.idle_timeout.map(Duration::from_secs);

            let built = Self::builder(manager)
                .max_size(config.max_connections)
                .wait_timeout(connect_timeout)
                .create_timeout(connect_timeout)
                .recycle_timeout(idle_timeout)
                .runtime(Runtime::Tokio1)
                .build();

            // Every build error is reported as `NoRuntimeSpecified`.
            // NOTE(review): presumably the only failure mode, since a runtime
            // is set above; confirm against deadpool's `BuildError`.
            built.map_err(|_| Error::Init(PoolError::NoRuntimeSpecified))
        }

        /// Checks a connection out of the pool via deadpool's inherent `get`.
        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            <Pool<M, C>>::get(self).await.map_err(Error::Get)
        }

        /// Closes the pool via deadpool's inherent, synchronous `close`.
        async fn close(&self) {
            <Pool<M, C>>::close(self)
        }
    }
}
226
#[cfg(feature = "sqlx")]
mod sqlx {
    use sqlx::ConnectOptions;
    use super::{Duration, Error, Config, Figment};
    use rocket::tracing::level_filters::LevelFilter;

    /// The connection-options type of the `sqlx` database driver `D`.
    type Options<D> = <<D as sqlx::Database>::Connection as sqlx::Connection>::Options;

    /// Applies driver-specific settings to `__options`.
    ///
    /// `__options` is type-erased so the generic `init()` can call this for any
    /// driver. Each feature-gated branch downcasts to its concrete options type
    /// and leaves other types untouched. The `__` prefixes presumably keep the
    /// parameters from being reported as unused when no driver-specific
    /// feature is enabled.
    fn specialize(__options: &mut dyn std::any::Any, __config: &Config) {
        // SQLite: use the connect timeout as the busy timeout, enable
        // `create_if_missing`, and register each configured extension.
        #[cfg(feature = "sqlx_sqlite")]
        if let Some(o) = __options.downcast_mut::<sqlx::sqlite::SqliteConnectOptions>() {
            *o = std::mem::take(o)
                .busy_timeout(Duration::from_secs(__config.connect_timeout))
                .create_if_missing(true);

            if let Some(ref exts) = __config.extensions {
                for ext in exts {
                    *o = std::mem::take(o).extension(ext.clone());
                }
            }
        }
    }

    /// Maps Rocket's configured log level to the `log` crate's level filter
    /// used by `sqlx` statement logging.
    ///
    /// Returns `None` when `figment` has no value at
    /// `rocket::Config::LOG_LEVEL`, or when that value is not a string that
    /// parses as a tracing [`LevelFilter`].
    fn statement_log_level(figment: &Figment) -> Option<log::LevelFilter> {
        let value = figment.find_value(rocket::Config::LOG_LEVEL).ok()?;
        let level = value.as_str()?.parse::<LevelFilter>().ok()?;
        Some(match level {
            LevelFilter::OFF => log::LevelFilter::Off,
            LevelFilter::ERROR => log::LevelFilter::Error,
            LevelFilter::WARN => log::LevelFilter::Warn,
            LevelFilter::INFO => log::LevelFilter::Info,
            LevelFilter::DEBUG => log::LevelFilter::Debug,
            LevelFilter::TRACE => log::LevelFilter::Trace,
        })
    }

    #[rocket::async_trait]
    impl<D: sqlx::Database> crate::Pool for sqlx::Pool<D> {
        type Error = Error<sqlx::Error>;

        type Connection = sqlx::pool::PoolConnection<D>;

        /// Parses `config.url` into driver options and applies driver-specific
        /// settings and statement logging. It then returns a lazily-connecting
        /// `sqlx` pool sized and timed according to the extracted [`Config`].
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = config.url.parse::<Options<D>>().map_err(Error::Init)?;
            specialize(&mut opts, &config);

            // Statement logging stays disabled unless Rocket's log level maps
            // to a `log` level filter.
            opts = opts.disable_statement_logging();
            if let Some(level) = statement_log_level(figment) {
                // NOTE(review): a zero slow-statement threshold means every
                // statement counts as "slow"; confirm this is intended.
                opts = opts.log_statements(level)
                    .log_slow_statements(level, Duration::default());
            }

            // `max_connections` is a `usize`. Saturate at `u32::MAX` instead
            // of silently truncating with an `as` cast.
            let max_connections = u32::try_from(config.max_connections).unwrap_or(u32::MAX);

            Ok(sqlx::pool::PoolOptions::new()
                .max_connections(max_connections)
                .acquire_timeout(Duration::from_secs(config.connect_timeout))
                .idle_timeout(config.idle_timeout.map(Duration::from_secs))
                .min_connections(config.min_connections.unwrap_or_default())
                .connect_lazy_with(opts))
        }

        /// Acquires one connection from the `sqlx` pool.
        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            self.acquire().await.map_err(Error::Get)
        }

        /// Closes the `sqlx` pool, awaiting its inherent `close`.
        async fn close(&self) {
            <sqlx::Pool<D>>::close(self).await;
        }
    }
}
296
#[cfg(feature = "mongodb")]
mod mongodb {
    use mongodb::{Client, options::ClientOptions};
    use super::{Duration, Error, Config, Figment};

    /// Uses a MongoDB [`Client`] directly as the pool: pool sizing is
    /// configured on the client's options, `get()` hands out clones of the
    /// client, and `close()` does nothing.
    #[rocket::async_trait]
    impl crate::Pool for Client {
        type Error = Error<mongodb::error::Error, std::convert::Infallible>;

        type Connection = Client;

        /// Parses `config.url` into [`ClientOptions`] and overrides the
        /// pool-size, idle and timeout settings from the extracted [`Config`].
        /// `connect_timeout` is used for both the connect and
        /// server-selection timeouts.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = ClientOptions::parse(&config.url).await.map_err(Error::Init)?;
            let connect_timeout = Duration::from_secs(config.connect_timeout);

            opts.min_pool_size = config.min_connections;
            // `max_connections` is a `usize`. Saturate at `u32::MAX` instead
            // of silently truncating with an `as` cast.
            opts.max_pool_size = Some(u32::try_from(config.max_connections).unwrap_or(u32::MAX));
            opts.max_idle_time = config.idle_timeout.map(Duration::from_secs);
            opts.connect_timeout = Some(connect_timeout);
            opts.server_selection_timeout = Some(connect_timeout);
            Client::with_options(opts).map_err(Error::Init)
        }

        /// Returns a clone of the client; this never fails.
        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            Ok(self.clone())
        }

        async fn close(&self) {
            // nothing to do for mongodb
        }
    }
}