rocket_sync_db_pools/
poolable.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
#[allow(unused)]
use std::time::Duration;

use r2d2::ManageConnection;
use rocket::{Rocket, Build};

#[allow(unused_imports)]
use crate::{Config, Error};

/// Trait implemented by `r2d2`-based database adapters.
///
/// # Provided Implementations
///
/// Implementations of `Poolable` are provided for the following types:
///
///   * [`diesel::MysqlConnection`](diesel::MysqlConnection)
///   * [`diesel::PgConnection`](diesel::PgConnection)
///   * [`diesel::SqliteConnection`](diesel::SqliteConnection)
///   * [`postgres::Client`](postgres::Client)
///   * [`rusqlite::Connection`](rusqlite::Connection)
///   * [`memcache::Client`](memcache::Client)
///
/// # Implementation Guide
///
/// As an r2d2-compatible database (or other resource) adapter provider,
/// implementing `Poolable` in your own library will enable Rocket users to
/// consume your adapter with its built-in connection pooling support.
///
/// ## Example
///
/// Consider a library `foo` with the following types:
///
///   * `foo::ConnectionManager`, which implements [`r2d2::ManageConnection`]
///   * `foo::Connection`, the `Connection` associated type of
///     `foo::ConnectionManager`
///   * `foo::Error`, errors resulting from manager instantiation
///
/// In order for Rocket to generate the required code to automatically provision
/// an r2d2 connection pool into application state, the `Poolable` trait needs to
/// be implemented for the connection type. The following example implements
/// `Poolable` for `foo::Connection`:
///
/// ```rust
/// # mod foo {
/// #     use std::fmt;
/// #     use rocket_sync_db_pools::r2d2;
/// #     #[derive(Debug)] pub struct Error;
/// #     impl std::error::Error for Error {  }
/// #     impl fmt::Display for Error {
/// #         fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Ok(()) }
/// #     }
/// #
/// #     pub struct Connection;
/// #     pub struct ConnectionManager;
/// #
/// #     type Result<T> = std::result::Result<T, Error>;
/// #
/// #     impl ConnectionManager {
/// #         pub fn new(url: &str) -> Result<Self> { Err(Error) }
/// #     }
/// #
/// #     impl self::r2d2::ManageConnection for ConnectionManager {
/// #          type Connection = Connection;
/// #          type Error = Error;
/// #          fn connect(&self) -> Result<Connection> { panic!() }
/// #          fn is_valid(&self, _: &mut Connection) -> Result<()> { panic!() }
/// #          fn has_broken(&self, _: &mut Connection) -> bool { panic!() }
/// #     }
/// # }
/// use std::time::Duration;
/// use rocket::{Rocket, Build};
/// use rocket_sync_db_pools::{r2d2, Error, Config, Poolable, PoolResult};
///
/// impl Poolable for foo::Connection {
///     type Manager = foo::ConnectionManager;
///     type Error = foo::Error;
///
///     fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
///         let config = Config::from(db_name, rocket)?;
///         let manager = foo::ConnectionManager::new(&config.url).map_err(Error::Custom)?;
///         Ok(r2d2::Pool::builder()
///             .max_size(config.pool_size)
///             .connection_timeout(Duration::from_secs(config.timeout as u64))
///             .build(manager)?)
///     }
/// }
/// ```
///
/// In this example, the `ConnectionManager::new()` method returns a
/// `foo::Error` on failure. The [`Error`] enum consolidates this type, the
/// `r2d2::Error` type that can result from building the pool via
/// `r2d2::Pool::builder()`, and the [`figment::Error`](rocket::figment::Error)
/// type from `Config::from()`.
///
/// In the event that a connection manager isn't fallible (as is the case with
/// Diesel's r2d2 connection manager, for instance), the associated error type
/// for the `Poolable` implementation should be `std::convert::Infallible`.
///
/// For more concrete examples, consult Rocket's existing implementations of
/// [`Poolable`].
pub trait Poolable: Send + Sized + 'static {
    /// The associated connection manager for the given connection type.
    ///
    /// The manager's `Connection` associated type must be the implementing
    /// type itself.
    type Manager: ManageConnection<Connection=Self>;

    /// The associated error type in the event that constructing the connection
    /// manager and/or the connection pool fails.
    ///
    /// This is the type parameter of the [`Error`] returned inside
    /// [`PoolResult`].
    type Error: std::fmt::Debug;

    /// Creates an `r2d2` connection pool for `Manager::Connection`, returning
    /// the pool on success.
    ///
    /// The implementations provided in this file pass `db_name` and `rocket`
    /// to `Config::from()` or `Config::figment()` to obtain the pool's
    /// settings (`url`, `pool_size`, `timeout`).
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self>;
}

/// A type alias for the return type of [`Poolable::pool()`].
///
/// On success this is an `r2d2::Pool` over `P::Manager`; on failure it is this
/// crate's [`Error`] parameterized by `P::Error`.
// NOTE(review): the `P: Poolable` bound is presumably what allows the
// `P::Manager` / `P::Error` shorthand here, with `type_alias_bounds` allowed
// because the lint would otherwise flag a bound on a type alias; confirm.
#[allow(type_alias_bounds)]
pub type PoolResult<P: Poolable> = Result<r2d2::Pool<P::Manager>, Error<P::Error>>;

#[cfg(feature = "diesel_sqlite_pool")]
impl Poolable for diesel::SqliteConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::SqliteConnection>;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of Diesel SQLite connections for the database
    /// configured under `db_name`.
    ///
    /// The pool is given a connection customizer that executes a fixed batch
    /// of `PRAGMA` statements (WAL journaling, a busy timeout, and foreign key
    /// enforcement) in its `on_acquire` hook.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::SqliteConnection;
        use diesel::connection::SimpleConnection;
        use diesel::r2d2::{ConnectionManager, CustomizeConnection, Pool};

        /// Applies the per-connection `PRAGMA` settings.
        #[derive(Debug)]
        struct Customizer;

        impl CustomizeConnection<SqliteConnection, diesel::r2d2::Error> for Customizer {
            fn on_acquire(
                &self,
                connection: &mut SqliteConnection,
            ) -> Result<(), diesel::r2d2::Error> {
                // A failed batch is reported as a query error to r2d2.
                connection
                    .batch_execute("\
                        PRAGMA journal_mode = WAL;\
                        PRAGMA busy_timeout = 5000;\
                        PRAGMA foreign_keys = ON;\
                    ")
                    .map_err(diesel::r2d2::Error::QueryError)
            }
        }

        let settings = Config::from(db_name, rocket)?;
        let manager = ConnectionManager::<SqliteConnection>::new(&settings.url);
        let acquire_timeout = Duration::from_secs(settings.timeout as u64);

        Ok(Pool::builder()
            .connection_customizer(Box::new(Customizer))
            .max_size(settings.pool_size)
            .connection_timeout(acquire_timeout)
            .build(manager)?)
    }
}

#[cfg(feature = "diesel_postgres_pool")]
impl Poolable for diesel::PgConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::PgConnection>;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of Diesel PostgreSQL connections using the
    /// `url`, `pool_size`, and `timeout` configured for `db_name`.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::r2d2::ConnectionManager;

        let settings = Config::from(db_name, rocket)?;
        let manager = ConnectionManager::<diesel::PgConnection>::new(&settings.url);
        let acquire_timeout = Duration::from_secs(settings.timeout as u64);

        Ok(r2d2::Pool::builder()
            .max_size(settings.pool_size)
            .connection_timeout(acquire_timeout)
            .build(manager)?)
    }
}

#[cfg(feature = "diesel_mysql_pool")]
impl Poolable for diesel::MysqlConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::MysqlConnection>;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of Diesel MySQL connections using the `url`,
    /// `pool_size`, and `timeout` configured for `db_name`.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::r2d2::ConnectionManager;

        let settings = Config::from(db_name, rocket)?;
        let manager = ConnectionManager::<diesel::MysqlConnection>::new(&settings.url);
        let acquire_timeout = Duration::from_secs(settings.timeout as u64);

        Ok(r2d2::Pool::builder()
            .max_size(settings.pool_size)
            .connection_timeout(acquire_timeout)
            .build(manager)?)
    }
}

// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
    type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
    type Error = postgres::Error;

    /// Builds an `r2d2` pool of `postgres` clients for the database configured
    /// under `db_name`.
    ///
    /// The configured `url` is parsed into the manager's connection
    /// configuration; a parse failure is returned as `Error::Custom`.
    /// Connections are made without TLS.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use postgres::tls::NoTls;
        use r2d2_postgres::PostgresConnectionManager;

        let settings = Config::from(db_name, rocket)?;
        let manager = PostgresConnectionManager::new(
            settings.url.parse().map_err(Error::Custom)?,
            NoTls,
        );

        let acquire_timeout = Duration::from_secs(settings.timeout as u64);
        Ok(r2d2::Pool::builder()
            .max_size(settings.pool_size)
            .connection_timeout(acquire_timeout)
            .build(manager)?)
    }
}

#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
    type Manager = r2d2_sqlite::SqliteConnectionManager;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of `rusqlite` connections for the database
    /// configured under `db_name`.
    ///
    /// In addition to the usual `url`, `pool_size`, and `timeout` settings, an
    /// `open_flags` list is read from the database's figment. Each listed flag
    /// is added on top of `rusqlite::OpenFlags::default()`.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use rocket::figment::providers::Serialized;
        use rusqlite::OpenFlags;

        /// Configuration-level names (in `snake_case`) for the supported
        /// SQLite open flags.
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        #[serde(rename_all = "snake_case")]
        enum OpenFlag {
            ReadOnly,
            ReadWrite,
            Create,
            Uri,
            Memory,
            NoMutex,
            FullMutex,
            SharedCache,
            PrivateCache,
            Nofollow,
        }

        impl OpenFlag {
            /// Maps a configured flag to its `rusqlite` counterpart.
            fn into_sqlite(self) -> OpenFlags {
                match self {
                    Self::ReadOnly => OpenFlags::SQLITE_OPEN_READ_ONLY,
                    Self::ReadWrite => OpenFlags::SQLITE_OPEN_READ_WRITE,
                    Self::Create => OpenFlags::SQLITE_OPEN_CREATE,
                    Self::Uri => OpenFlags::SQLITE_OPEN_URI,
                    Self::Memory => OpenFlags::SQLITE_OPEN_MEMORY,
                    Self::NoMutex => OpenFlags::SQLITE_OPEN_NO_MUTEX,
                    Self::FullMutex => OpenFlags::SQLITE_OPEN_FULL_MUTEX,
                    Self::SharedCache => OpenFlags::SQLITE_OPEN_SHARED_CACHE,
                    Self::PrivateCache => OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
                    Self::Nofollow => OpenFlags::SQLITE_OPEN_NOFOLLOW,
                }
            }
        }

        let figment = Config::figment(db_name, rocket);
        let settings: Config = figment.extract()?;

        // An empty list is joined in under `open_flags` before extraction.
        let requested: Vec<OpenFlag> = figment
            .join(Serialized::default("open_flags", Vec::<OpenFlag>::new()))
            .extract_inner("open_flags")?;

        // Start from the default flags and set every requested one.
        let flags = requested
            .into_iter()
            .fold(OpenFlags::default(), |acc, flag| acc | flag.into_sqlite());

        let manager = r2d2_sqlite::SqliteConnectionManager::file(&*settings.url)
            .with_flags(flags);

        let acquire_timeout = Duration::from_secs(settings.timeout as u64);
        Ok(r2d2::Pool::builder()
            .max_size(settings.pool_size)
            .connection_timeout(acquire_timeout)
            .build(manager)?)
    }
}

#[cfg(feature = "memcache_pool")]
mod memcache_pool {
    use memcache::{Client, Connectable, MemcacheError};

    use super::*;

    /// An `r2d2` connection manager that hands out `memcache` clients
    /// connected to a fixed list of URLs.
    #[derive(Debug)]
    pub struct ConnectionManager {
        urls: Vec<String>,
    }

    impl ConnectionManager {
        /// Creates a manager for the URLs reported by `target`.
        pub fn new<C: Connectable>(target: C) -> Self {
            let urls = target.get_urls();
            Self { urls }
        }
    }

    impl r2d2::ManageConnection for ConnectionManager {
        type Connection = Client;
        type Error = MemcacheError;

        /// Connects a new client to a copy of the stored URLs.
        fn connect(&self) -> Result<Self::Connection, Self::Error> {
            let targets = self.urls.clone();
            Client::connect(targets)
        }

        /// Treats the connection as valid when a `version` request succeeds;
        /// the returned version data is discarded.
        fn is_valid(&self, connection: &mut Self::Connection) -> Result<(), Self::Error> {
            let _versions = connection.version()?;
            Ok(())
        }

        /// Never reports a connection as broken.
        fn has_broken(&self, _connection: &mut Self::Connection) -> bool {
            false
        }
    }

    impl Poolable for Client {
        type Manager = ConnectionManager;
        type Error = MemcacheError;

        /// Builds an `r2d2` pool of `memcache` clients using the `url`,
        /// `pool_size`, and `timeout` configured for `db_name`.
        fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
            let settings = Config::from(db_name, rocket)?;
            let manager = ConnectionManager::new(&*settings.url);
            let acquire_timeout = Duration::from_secs(settings.timeout as u64);

            Ok(r2d2::Pool::builder()
                .max_size(settings.pool_size)
                .connection_timeout(acquire_timeout)
                .build(manager)?)
        }
    }
}