rocket_db_pools/
pool.rs
1use rocket::figment::Figment;
2
3#[allow(unused_imports)]
4use {std::time::Duration, crate::{Error, Config}};
5
6#[rocket::async_trait]
117pub trait Pool: Sized + Send + Sync + 'static {
118 type Connection;
120
121 type Error: std::error::Error;
123
124 async fn init(figment: &Figment) -> Result<Self, Self::Error>;
137
138 async fn get(&self) -> Result<Self::Connection, Self::Error>;
146
147 async fn close(&self);
153}
154
#[cfg(feature = "deadpool")]
mod deadpool_postgres {
    use deadpool::{Runtime, managed::{Manager, Pool, PoolError, Object}};
    use super::{Duration, Error, Config, Figment};

    /// A `deadpool` [`Manager`] that can be constructed from a [`Config`].
    pub trait DeadManager: Manager + Sized + Send + Sync + 'static {
        /// Creates a manager for the database described by `config`.
        fn new(config: &Config) -> Result<Self, Self::Error>;
    }

    #[cfg(feature = "deadpool_postgres")]
    impl DeadManager for deadpool_postgres::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            use deadpool_postgres::tokio_postgres::{Config as PgConfig, NoTls};

            // The URL is parsed as a tokio-postgres connection string; a parse
            // failure is returned as the manager's own error type.
            let pg_config: PgConfig = config.url.parse()?;
            Ok(deadpool_postgres::Manager::new(pg_config, NoTls))
        }
    }

    #[cfg(feature = "deadpool_redis")]
    impl DeadManager for deadpool_redis::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            deadpool_redis::Manager::new(config.url.as_str())
        }
    }

    #[rocket::async_trait]
    impl<M, C> crate::Pool for Pool<M, C>
    where
        M: DeadManager,
        M::Type: Send,
        M::Error: std::error::Error,
        C: From<Object<M>> + Send + Sync + 'static,
    {
        type Error = Error<PoolError<M::Error>>;

        type Connection = C;

        /// Builds a Tokio-backed deadpool from the extracted [`Config`].
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config: Config = figment.extract()?;
            let manager = M::new(&config)
                .map_err(|backend_err| Error::Init(PoolError::from(backend_err)))?;

            // `connect_timeout` bounds both waiting for a free slot and
            // creating a new connection; `idle_timeout` bounds recycling.
            let connect_timeout = Some(Duration::from_secs(config.connect_timeout));
            let recycle_timeout = config.idle_timeout.map(Duration::from_secs);

            Pool::builder(manager)
                .max_size(config.max_connections)
                .wait_timeout(connect_timeout)
                .create_timeout(connect_timeout)
                .recycle_timeout(recycle_timeout)
                .runtime(Runtime::Tokio1)
                .build()
                // Any build error is reported as `NoRuntimeSpecified`; the
                // original build-error value is discarded.
                .map_err(|_| Error::Init(PoolError::NoRuntimeSpecified))
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            // Explicitly call deadpool's inherent `Pool::get`, not this trait method.
            <Pool<M, C>>::get(self).await.map_err(Error::Get)
        }

        async fn close(&self) {
            <Pool<M, C>>::close(self)
        }
    }
}
209
#[cfg(all(feature = "deadpool_09", any(feature = "diesel_postgres", feature = "diesel_mysql")))]
mod deadpool_old {
    use deadpool_09::{managed::{Manager, Pool, PoolError, Object, BuildError}, Runtime};
    use diesel_async::pooled_connection::AsyncDieselConnectionManager;

    use super::{Duration, Error, Config, Figment};

    /// A `deadpool` 0.9 [`Manager`] that can be constructed from a [`Config`].
    pub trait DeadManager: Manager + Sized + Send + Sync + 'static {
        /// Creates a manager for the database described by `config`.
        fn new(config: &Config) -> Result<Self, Self::Error>;
    }

    #[cfg(feature = "diesel_postgres")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncPgConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // Delegates to the inherent constructor, which is infallible here.
            let url = config.url.as_str();
            Ok(Self::new(url))
        }
    }

    #[cfg(feature = "diesel_mysql")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncMysqlConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // Delegates to the inherent constructor, which is infallible here.
            let url = config.url.as_str();
            Ok(Self::new(url))
        }
    }

    #[rocket::async_trait]
    impl<M, C> crate::Pool for Pool<M, C>
    where
        M: DeadManager,
        M::Type: Send,
        M::Error: std::error::Error,
        C: From<Object<M>> + Send + Sync + 'static,
    {
        type Error = Error<BuildError<M::Error>, PoolError<M::Error>>;

        type Connection = C;

        /// Builds a Tokio-backed deadpool 0.9 pool from the extracted [`Config`].
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config: Config = figment.extract()?;
            let manager = M::new(&config)
                .map_err(|backend_err| Error::Init(BuildError::Backend(backend_err)))?;

            // `connect_timeout` bounds both waiting for a free slot and
            // creating a new connection; `idle_timeout` bounds recycling.
            let connect_timeout = Some(Duration::from_secs(config.connect_timeout));
            let recycle_timeout = config.idle_timeout.map(Duration::from_secs);

            Pool::builder(manager)
                .max_size(config.max_connections)
                .wait_timeout(connect_timeout)
                .create_timeout(connect_timeout)
                .recycle_timeout(recycle_timeout)
                .runtime(Runtime::Tokio1)
                .build()
                .map_err(Error::Init)
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            // Explicitly call deadpool's inherent `Pool::get`, not this trait method.
            <Pool<M, C>>::get(self).await.map_err(Error::Get)
        }

        async fn close(&self) {
            <Pool<M, C>>::close(self)
        }
    }
}
267
#[cfg(feature = "sqlx")]
mod sqlx {
    use sqlx::ConnectOptions;
    use super::{Duration, Error, Config, Figment};
    use rocket::config::LogLevel;

    /// The connect-options type sqlx associates with database driver `D`.
    type Options<D> = <<D as sqlx::Database>::Connection as sqlx::Connection>::Options;

    /// Applies driver-specific settings to the parsed connect options.
    ///
    /// The options arrive as `dyn Any` so each driver branch can downcast to
    /// its concrete options type; if the downcast fails the options are left
    /// untouched. The leading underscores keep the parameters from triggering
    /// unused-variable warnings when no driver branch is compiled in.
    fn specialize(__options: &mut dyn std::any::Any, __config: &Config) {
        // SQLite: reuse `connect_timeout` as the busy timeout, create the
        // database if it is missing, and load any configured extensions.
        #[cfg(feature = "sqlx_sqlite")]
        if let Some(o) = __options.downcast_mut::<sqlx::sqlite::SqliteConnectOptions>() {
            *o = std::mem::take(o)
                .busy_timeout(Duration::from_secs(__config.connect_timeout))
                .create_if_missing(true);

            if let Some(ref exts) = __config.extensions {
                for ext in exts {
                    *o = std::mem::take(o).extension(ext.clone());
                }
            }
        }
    }

    #[rocket::async_trait]
    impl<D: sqlx::Database> crate::Pool for sqlx::Pool<D> {
        type Error = Error<sqlx::Error>;

        type Connection = sqlx::pool::PoolConnection<D>;

        /// Parses `config.url` into driver options, applies driver-specific
        /// tweaks and logging settings, then connects an sqlx pool.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = config.url.parse::<Options<D>>().map_err(Error::Init)?;
            specialize(&mut opts, &config);

            // Statement logging stays disabled unless the configured log level
            // is neither `Normal` nor `Off`. A zero slow-statement threshold
            // presumably reports every statement as slow as well.
            // NOTE(review): this also enables statement logging under
            // `LogLevel::Critical`; confirm that is intended.
            opts = opts.disable_statement_logging();
            if let Ok(level) = figment.extract_inner::<LogLevel>(rocket::Config::LOG_LEVEL) {
                if !matches!(level, LogLevel::Normal | LogLevel::Off) {
                    opts = opts.log_statements(level.into())
                        .log_slow_statements(level.into(), Duration::default());
                }
            }

            // sqlx takes a `u32` limit. An `as` cast would silently wrap
            // oversized values (possibly to 0), so saturate instead.
            let max_connections = u32::try_from(config.max_connections).unwrap_or(u32::MAX);

            sqlx::pool::PoolOptions::new()
                .max_connections(max_connections)
                .acquire_timeout(Duration::from_secs(config.connect_timeout))
                .idle_timeout(config.idle_timeout.map(Duration::from_secs))
                .min_connections(config.min_connections.unwrap_or_default())
                .connect_with(opts)
                .await
                .map_err(Error::Init)
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            self.acquire().await.map_err(Error::Get)
        }

        async fn close(&self) {
            <sqlx::Pool<D>>::close(self).await;
        }
    }
}
330
#[cfg(feature = "mongodb")]
mod mongodb {
    use mongodb::{Client, options::ClientOptions};
    use super::{Duration, Error, Config, Figment};

    #[rocket::async_trait]
    impl crate::Pool for Client {
        type Error = Error<mongodb::error::Error, std::convert::Infallible>;

        /// Connections are handed out as clones of the client itself.
        type Connection = Client;

        /// Parses `config.url` into client options, applies pool limits and
        /// timeouts from [`Config`], and builds the client.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = ClientOptions::parse(&config.url).await.map_err(Error::Init)?;

            // `connect_timeout` bounds both socket connection and server selection.
            let timeout = Duration::from_secs(config.connect_timeout);

            opts.min_pool_size = config.min_connections;
            // The driver takes a `u32` limit. An `as` cast would silently wrap
            // oversized values, so saturate instead.
            opts.max_pool_size = Some(u32::try_from(config.max_connections).unwrap_or(u32::MAX));
            opts.max_idle_time = config.idle_timeout.map(Duration::from_secs);
            opts.connect_timeout = Some(timeout);
            opts.server_selection_timeout = Some(timeout);
            Client::with_options(opts).map_err(Error::Init)
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            // Presumably cheap: the driver's `Client` is documented as sharing
            // its internal connection pool across clones. Confirm against the
            // driver version in use.
            Ok(self.clone())
        }

        async fn close(&self) {
            // No explicit shutdown is performed for MongoDB clients here.
        }
    }
}