rocket_sync_db_pools/
poolable.rs
1#[allow(unused)]
2use std::time::Duration;
3
4use r2d2::ManageConnection;
5use rocket::{Rocket, Build};
6
7#[allow(unused_imports)]
8use crate::{Config, Error};
9
10pub trait Poolable: Send + Sized + 'static {
102 type Manager: ManageConnection<Connection=Self>;
104
105 type Error: std::fmt::Debug;
108
109 fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self>;
112}
113
114#[allow(type_alias_bounds)]
116pub type PoolResult<P: Poolable> = Result<r2d2::Pool<P::Manager>, Error<P::Error>>;
117
#[cfg(feature = "diesel_sqlite_pool")]
impl Poolable for diesel::SqliteConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::SqliteConnection>;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of Diesel SQLite connections for `db_name`.
    ///
    /// Each new connection is customized with three pragmas:
    /// - a 5000 ms busy timeout,
    /// - WAL journaling,
    /// - foreign-key enforcement.
    ///
    /// The pool's size and connection timeout come from the database's
    /// configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration cannot be read or the pool cannot
    /// be built.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::{SqliteConnection, connection::SimpleConnection};
        use diesel::r2d2::{CustomizeConnection, ConnectionManager, Error, Pool};

        /// Applies the per-connection SQLite pragmas.
        #[derive(Debug)]
        struct Customizer;

        impl CustomizeConnection<SqliteConnection, Error> for Customizer {
            fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), Error> {
                // `busy_timeout` must be installed first. Switching to WAL may
                // need to lock the database, and with SQLite's default busy
                // timeout of 0 that pragma fails immediately with SQLITE_BUSY
                // while another pooled connection is being set up concurrently.
                conn.batch_execute("\
                    PRAGMA busy_timeout = 5000;\
                    PRAGMA journal_mode = WAL;\
                    PRAGMA foreign_keys = ON;\
                ").map_err(Error::QueryError)?;

                Ok(())
            }
        }

        let config = Config::from(db_name, rocket)?;
        let manager = ConnectionManager::new(&config.url);
        let pool = Pool::builder()
            .connection_customizer(Box::new(Customizer))
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}
153
#[cfg(feature = "diesel_postgres_pool")]
impl Poolable for diesel::PgConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::PgConnection>;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of Diesel PostgreSQL connections for `db_name`.
    /// The pool's size and connection timeout come from that database's
    /// configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration cannot be read or the pool cannot
    /// be built.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let Config { url, pool_size, timeout, .. } = Config::from(db_name, rocket)?;

        let manager = diesel::r2d2::ConnectionManager::new(&url);
        let wait_limit = Duration::from_secs(timeout as u64);

        let builder = r2d2::Pool::builder()
            .max_size(pool_size)
            .connection_timeout(wait_limit);

        let pool = builder.build(manager)?;
        Ok(pool)
    }
}
170
#[cfg(feature = "diesel_mysql_pool")]
impl Poolable for diesel::MysqlConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::MysqlConnection>;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of Diesel MySQL connections for `db_name`.
    /// The pool's size and connection timeout come from that database's
    /// configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration cannot be read or the pool cannot
    /// be built.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let Config { url, pool_size, timeout, .. } = Config::from(db_name, rocket)?;

        let manager = diesel::r2d2::ConnectionManager::new(&url);
        let wait_limit = Duration::from_secs(timeout as u64);

        let builder = r2d2::Pool::builder()
            .max_size(pool_size)
            .connection_timeout(wait_limit);

        let pool = builder.build(manager)?;
        Ok(pool)
    }
}
187
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
    type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
    type Error = postgres::Error;

    /// Builds an `r2d2` pool of unencrypted (`NoTls`) `postgres` clients for
    /// `db_name`.
    ///
    /// # Errors
    ///
    /// - Returns `Error::Custom` if the configured URL cannot be parsed as a
    ///   PostgreSQL connection configuration.
    /// - Returns another error if the configuration cannot be read or the pool
    ///   cannot be built.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use postgres::tls::NoTls;

        let config = Config::from(db_name, rocket)?;
        let pg_config = config.url.parse().map_err(Error::Custom)?;
        let wait_limit = Duration::from_secs(config.timeout as u64);
        let manager = r2d2_postgres::PostgresConnectionManager::new(pg_config, NoTls);

        let builder = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(wait_limit);

        let pool = builder.build(manager)?;
        Ok(pool)
    }
}
206
#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
    type Manager = r2d2_sqlite::SqliteConnectionManager;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of `rusqlite` connections for `db_name`.
    ///
    /// Besides the common keys, this reads an optional `open_flags` array from
    /// the database's configuration. The array holds snake_case names of
    /// SQLite open flags, e.g. `["read_only", "shared_cache"]`.
    ///
    /// The requested flags are added to `rusqlite::OpenFlags::default()`.
    /// Defaults that SQLite treats as alternatives to a requested flag are
    /// dropped first, so `read_only` and `full_mutex` actually take effect.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration (including `open_flags`) cannot
    /// be extracted or the pool cannot be built.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use rocket::figment::providers::Serialized;
        use rusqlite::OpenFlags;

        /// Configuration names of the supported `SQLITE_OPEN_*` flags.
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        #[serde(rename_all = "snake_case")]
        enum OpenFlag {
            ReadOnly,
            ReadWrite,
            Create,
            Uri,
            Memory,
            NoMutex,
            FullMutex,
            SharedCache,
            PrivateCache,
            Nofollow,
        }

        /// Maps a configured flag name to its `rusqlite` flag.
        fn sql_flag(flag: OpenFlag) -> OpenFlags {
            match flag {
                OpenFlag::ReadOnly => OpenFlags::SQLITE_OPEN_READ_ONLY,
                OpenFlag::ReadWrite => OpenFlags::SQLITE_OPEN_READ_WRITE,
                OpenFlag::Create => OpenFlags::SQLITE_OPEN_CREATE,
                OpenFlag::Uri => OpenFlags::SQLITE_OPEN_URI,
                OpenFlag::Memory => OpenFlags::SQLITE_OPEN_MEMORY,
                OpenFlag::NoMutex => OpenFlags::SQLITE_OPEN_NO_MUTEX,
                OpenFlag::FullMutex => OpenFlags::SQLITE_OPEN_FULL_MUTEX,
                OpenFlag::SharedCache => OpenFlags::SQLITE_OPEN_SHARED_CACHE,
                OpenFlag::PrivateCache => OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
                OpenFlag::Nofollow => OpenFlags::SQLITE_OPEN_NOFOLLOW,
            }
        }

        /// Adds `requested` to `OpenFlags::default()`, after removing defaults
        /// that would conflict with a requested flag.
        fn merge_with_defaults(requested: OpenFlags) -> OpenFlags {
            let mut defaults = OpenFlags::default();

            // SQLite accepts only READ_ONLY, READ_WRITE, or READ_WRITE|CREATE
            // as the access mode. READ_ONLY combined with a default
            // READ_WRITE/CREATE is rejected with SQLITE_MISUSE, which would
            // make every connection fail to open.
            if requested.contains(OpenFlags::SQLITE_OPEN_READ_ONLY) {
                defaults.remove(OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE);
            }

            // SQLite gives NO_MUTEX precedence over FULL_MUTEX. If the defaults
            // contain NO_MUTEX, a requested `full_mutex` would be silently
            // ignored.
            if requested.contains(OpenFlags::SQLITE_OPEN_FULL_MUTEX) {
                defaults.remove(OpenFlags::SQLITE_OPEN_NO_MUTEX);
            }

            defaults | requested
        }

        let figment = Config::figment(db_name, rocket);
        let config: Config = figment.extract()?;
        let open_flags: Vec<OpenFlag> = figment
            .join(Serialized::default("open_flags", <Vec<OpenFlag>>::new()))
            .extract_inner("open_flags")?;

        let requested = open_flags
            .into_iter()
            .map(sql_flag)
            .fold(OpenFlags::empty(), |acc, flag| acc | flag);

        let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url)
            .with_flags(merge_with_defaults(requested));

        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}
265
#[cfg(feature = "memcache_pool")]
mod memcache_pool {
    use memcache::{Client, Connectable, MemcacheError};

    use super::*;

    /// An `r2d2` connection manager that produces `memcache::Client`s
    /// connected to a fixed list of server URLs.
    #[derive(Debug)]
    pub struct ConnectionManager {
        /// Server URLs handed to `Client::connect` for every new connection.
        urls: Vec<String>,
    }

    impl ConnectionManager {
        /// Creates a manager for the server URL(s) described by `target`.
        pub fn new<C: Connectable>(target: C) -> Self {
            let urls = target.get_urls();
            ConnectionManager { urls }
        }
    }

    impl r2d2::ManageConnection for ConnectionManager {
        type Connection = Client;
        type Error = MemcacheError;

        /// Opens a new client against the configured URLs.
        fn connect(&self) -> Result<Client, MemcacheError> {
            let urls = self.urls.clone();
            Client::connect(urls)
        }

        /// Checks the connection by requesting the server version(s); the
        /// reported versions are discarded.
        fn is_valid(&self, connection: &mut Client) -> Result<(), MemcacheError> {
            connection.version()?;
            Ok(())
        }

        /// Never reports a connection as broken; this manager only probes
        /// liveness through `is_valid`.
        fn has_broken(&self, _connection: &mut Client) -> bool {
            false
        }
    }

    impl super::Poolable for memcache::Client {
        type Manager = ConnectionManager;
        type Error = MemcacheError;

        /// Builds an `r2d2` pool of memcache clients for `db_name`. The pool's
        /// size and connection timeout come from the database's configuration.
        fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
            let config = Config::from(db_name, rocket)?;
            let manager = ConnectionManager::new(config.url.as_str());
            let wait_limit = Duration::from_secs(config.timeout as u64);

            let builder = r2d2::Pool::builder()
                .max_size(config.pool_size)
                .connection_timeout(wait_limit);

            let pool = builder.build(manager)?;
            Ok(pool)
        }
    }
}