rocket_db_pools/
database.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

use rocket::{error, Build, Ignite, Phase, Rocket, Sentinel, Orbit};
use rocket::fairing::{self, Fairing, Info, Kind};
use rocket::request::{FromRequest, Outcome, Request};
use rocket::figment::providers::Serialized;
use rocket::http::Status;

use crate::Pool;

/// Derivable trait which ties a database [`Pool`] with a configuration name.
///
/// This trait should rarely, if ever, be implemented manually. Instead, it
/// should be derived:
///
/// ```rust
/// # #[cfg(feature = "deadpool_redis")] mod _inner {
/// # use rocket::launch;
/// use rocket_db_pools::{deadpool_redis, Database};
///
/// #[derive(Database)]
/// #[database("memdb")]
/// struct Db(deadpool_redis::Pool);
///
/// #[launch]
/// fn rocket() -> _ {
///     rocket::build().attach(Db::init())
/// }
/// # }
/// ```
///
/// See the [`Database` derive](derive@crate::Database) for details.
pub trait Database: From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static {
    /// The [`Pool`] type of connections to this database.
    ///
    /// When `Database` is derived, this takes the value of the `Inner` type in
    /// `struct Db(Inner)`.
    type Pool: Pool;

    /// The configuration name for this database.
    ///
    /// When `Database` is derived, this takes the value `"name"` in the
    /// `#[database("name")]` attribute.
    const NAME: &'static str;

    /// Returns a fairing that initializes the database and its connection pool.
    ///
    /// # Example
    ///
    /// ```rust
    /// # #[cfg(feature = "deadpool_postgres")] mod _inner {
    /// # use rocket::launch;
    /// use rocket_db_pools::{deadpool_postgres, Database};
    ///
    /// #[derive(Database)]
    /// #[database("pg_db")]
    /// struct Db(deadpool_postgres::Pool);
    ///
    /// #[launch]
    /// fn rocket() -> _ {
    ///     rocket::build().attach(Db::init())
    /// }
    /// # }
    /// ```
    fn init() -> Initializer<Self> {
        Initializer::new()
    }

    /// Returns a reference to the initialized database in `rocket`. The
    /// initializer fairing returned by `init()` must have already executed for
    /// `Option` to be `Some`. This is guaranteed to be the case if the fairing
    /// is attached and either:
    ///
    ///   * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the
    ///     application is running. This is always the case in request guards
    ///     and liftoff fairings,
    ///   * _or_ Rocket is in the [`Build`](rocket::Build) or
    ///     [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has
    ///     already been run. This is the case in all fairing callbacks
    ///     corresponding to fairings attached _after_ the `Initializer`
    ///     fairing.
    ///
    /// # Example
    ///
    /// Run database migrations in an ignite fairing. It is imperative that the
    /// migration fairing be registered _after_ the `init()` fairing.
    ///
    /// ```rust
    /// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
    /// # use rocket::launch;
    /// use rocket::{Rocket, Build};
    /// use rocket::fairing::{self, AdHoc};
    ///
    /// use rocket_db_pools::{sqlx, Database};
    ///
    /// #[derive(Database)]
    /// #[database("sqlite_db")]
    /// struct Db(sqlx::SqlitePool);
    ///
    /// async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
    ///     if let Some(db) = Db::fetch(&rocket) {
    ///         // run migrations using `db`. get the inner type with &db.0.
    ///         Ok(rocket)
    ///     } else {
    ///         Err(rocket)
    ///     }
    /// }
    ///
    /// #[launch]
    /// fn rocket() -> _ {
    ///     rocket::build()
    ///         .attach(Db::init())
    ///         .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations))
    /// }
    /// # }
    /// ```
    fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
        if let Some(db) = rocket.state() {
            return Some(db);
        }

        let conn = std::any::type_name::<Self>();
        error!("`{conn}::init()` is not attached\n\
            the fairing must be attached to use `{conn}` in routes.");

        None
    }
}

/// A [`Fairing`] which initializes a [`Database`] and its connection pool.
///
/// A value of this type can be created for any type `D` that implements
/// [`Database`] via the [`Database::init()`] method on the type. Normally, a
/// value of this type _never_ needs to be constructed directly. This
/// documentation exists purely as a reference.
///
/// This fairing initializes a database pool. Specifically, it:
///
///   1. Reads the configuration at `database.db_name`, where `db_name` is
///      [`Database::NAME`].
///
///   2. Sets [`Config`](crate::Config) defaults on the configuration figment.
///
///   3. Calls [`Pool::init()`].
///
///   4. Stores the database instance in managed storage, retrievable via
///      [`Database::fetch()`].
///
/// The name of the fairing itself is `Initializer<D>`, with `D` replaced with
/// the type name `D` unless a name is explicitly provided via
/// [`Self::with_name()`].
pub struct Initializer<D: Database>(Option<&'static str>, PhantomData<fn() -> D>);

/// A request guard which retrieves a single connection to a [`Database`].
///
/// For a database type of `Db`, a request guard of `Connection<Db>` retrieves a
/// single connection to `Db`.
///
/// The request guard succeeds if the database was initialized by the
/// [`Initializer`] fairing and a connection is available within
/// [`connect_timeout`](crate::Config::connect_timeout) seconds.
///   * If the `Initializer` fairing was _not_ attached, the guard _fails_ with
///   status `InternalServerError`. A [`Sentinel`] guards this condition, and so
///   this type of error is unlikely to occur. A `None` error is returned.
///   * If a connection is not available within `connect_timeout` seconds or
///   another error occurs, the guard _fails_ with status `ServiceUnavailable`
///   and the error is returned in `Some`.
///
/// ## Deref
///
/// A type of `Connection<Db>` dereferences, mutably and immutably, to the
/// native database connection type. The [driver table](crate#supported-drivers)
/// lists the concrete native `Deref` types.
///
/// # Example
///
/// ```rust
/// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
/// # use rocket::get;
/// # type Pool = rocket_db_pools::sqlx::SqlitePool;
/// use rocket_db_pools::{Database, Connection};
///
/// #[derive(Database)]
/// #[database("db")]
/// struct Db(Pool);
///
/// #[get("/")]
/// async fn db_op(db: Connection<Db>) {
///     // use `&*db` to get an immutable borrow to the native connection type
///     // use `&mut *db` to get a mutable borrow to the native connection type
/// }
/// # }
/// ```
pub struct Connection<D: Database>(<D::Pool as Pool>::Connection);

impl<D: Database> Initializer<D> {
    /// Returns a database initializer fairing for `D`.
    ///
    /// This method should never need to be called manually. See the [crate
    /// docs](crate) for usage information.
    pub fn new() -> Self {
        Self(None, std::marker::PhantomData)
    }

    /// Returns a database initializer fairing for `D` with name `name`.
    ///
    /// This method should never need to be called manually. See the [crate
    /// docs](crate) for usage information.
    pub fn with_name(name: &'static str) -> Self {
        Self(Some(name), std::marker::PhantomData)
    }
}

impl<D: Database> Connection<D> {
    /// Returns the internal connection value. See the [`Connection` Deref
    /// column](crate#supported-drivers) for the expected type of this value.
    ///
    /// Note that `Connection<D>` derefs to the internal connection type, so
    /// using this method is likely unnecessary. See [deref](Connection#deref)
    /// for examples.
    ///
    /// # Example
    ///
    /// ```rust
    /// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
    /// # use rocket::get;
    /// # type Pool = rocket_db_pools::sqlx::SqlitePool;
    /// use rocket_db_pools::{Database, Connection};
    ///
    /// #[derive(Database)]
    /// #[database("db")]
    /// struct Db(Pool);
    ///
    /// #[get("/")]
    /// async fn db_op(db: Connection<Db>) {
    ///     let inner = db.into_inner();
    /// }
    /// # }
    /// ```
    pub fn into_inner(self) -> <D::Pool as Pool>::Connection {
        self.0
    }
}

#[rocket::async_trait]
impl<D: Database> Fairing for Initializer<D> {
    fn info(&self) -> Info {
        Info {
            name: self.0.unwrap_or(std::any::type_name::<Self>()),
            kind: Kind::Ignite | Kind::Shutdown,
        }
    }

    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
        let workers: usize = rocket.figment()
            .extract_inner(rocket::Config::WORKERS)
            .unwrap_or_else(|_| rocket::Config::default().workers);

        let figment = rocket.figment()
            .focus(&format!("databases.{}", D::NAME))
            .join(Serialized::default("max_connections", workers * 4))
            .join(Serialized::default("connect_timeout", 5));

        match <D::Pool>::init(&figment).await {
            Ok(pool) => Ok(rocket.manage(D::from(pool))),
            Err(e) => {
                error!("database initialization failed: {e}");
                Err(rocket)
            }
        }
    }

    async fn on_shutdown(&self, rocket: &Rocket<Orbit>) {
        if let Some(db) = D::fetch(rocket) {
            db.close().await;
        }
    }
}

#[rocket::async_trait]
impl<'r, D: Database> FromRequest<'r> for Connection<D> {
    type Error = Option<<D::Pool as Pool>::Error>;

    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        match D::fetch(req.rocket()) {
            Some(db) => match db.get().await {
                Ok(conn) => Outcome::Success(Connection(conn)),
                Err(e) => Outcome::Error((Status::ServiceUnavailable, Some(e))),
            },
            None => Outcome::Error((Status::InternalServerError, None)),
        }
    }
}

impl<D: Database> Sentinel for Connection<D> {
    fn abort(rocket: &Rocket<Ignite>) -> bool {
        D::fetch(rocket).is_none()
    }
}

impl<D: Database> Deref for Connection<D> {
    type Target = <D::Pool as Pool>::Connection;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<D: Database> DerefMut for Connection<D> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}