rocket_db_pools/
database.rs

1use std::marker::PhantomData;
2use std::ops::{Deref, DerefMut};
3
4use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel, Orbit};
5use rocket::fairing::{self, Fairing, Info, Kind};
6use rocket::request::{FromRequest, Outcome, Request};
7use rocket::http::Status;
8
9use rocket::yansi::Paint;
10use rocket::figment::providers::Serialized;
11
12use crate::Pool;
13
14/// Derivable trait which ties a database [`Pool`] with a configuration name.
15///
16/// This trait should rarely, if ever, be implemented manually. Instead, it
17/// should be derived:
18///
19/// ```rust
20/// # #[cfg(feature = "deadpool_redis")] mod _inner {
21/// # use rocket::launch;
22/// use rocket_db_pools::{deadpool_redis, Database};
23///
24/// #[derive(Database)]
25/// #[database("memdb")]
26/// struct Db(deadpool_redis::Pool);
27///
28/// #[launch]
29/// fn rocket() -> _ {
30///     rocket::build().attach(Db::init())
31/// }
32/// # }
33/// ```
34///
35/// See the [`Database` derive](derive@crate::Database) for details.
36pub trait Database: From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static {
37    /// The [`Pool`] type of connections to this database.
38    ///
39    /// When `Database` is derived, this takes the value of the `Inner` type in
40    /// `struct Db(Inner)`.
41    type Pool: Pool;
42
43    /// The configuration name for this database.
44    ///
45    /// When `Database` is derived, this takes the value `"name"` in the
46    /// `#[database("name")]` attribute.
47    const NAME: &'static str;
48
49    /// Returns a fairing that initializes the database and its connection pool.
50    ///
51    /// # Example
52    ///
53    /// ```rust
54    /// # #[cfg(feature = "deadpool_postgres")] mod _inner {
55    /// # use rocket::launch;
56    /// use rocket_db_pools::{deadpool_postgres, Database};
57    ///
58    /// #[derive(Database)]
59    /// #[database("pg_db")]
60    /// struct Db(deadpool_postgres::Pool);
61    ///
62    /// #[launch]
63    /// fn rocket() -> _ {
64    ///     rocket::build().attach(Db::init())
65    /// }
66    /// # }
67    /// ```
68    fn init() -> Initializer<Self> {
69        Initializer::new()
70    }
71
72    /// Returns a reference to the initialized database in `rocket`. The
73    /// initializer fairing returned by `init()` must have already executed for
74    /// `Option` to be `Some`. This is guaranteed to be the case if the fairing
75    /// is attached and either:
76    ///
77    ///   * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the
78    ///     application is running. This is always the case in request guards
79    ///     and liftoff fairings,
80    ///   * _or_ Rocket is in the [`Build`](rocket::Build) or
81    ///     [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has
82    ///     already been run. This is the case in all fairing callbacks
83    ///     corresponding to fairings attached _after_ the `Initializer`
84    ///     fairing.
85    ///
86    /// # Example
87    ///
88    /// Run database migrations in an ignite fairing. It is imperative that the
89    /// migration fairing be registered _after_ the `init()` fairing.
90    ///
91    /// ```rust
92    /// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
93    /// # use rocket::launch;
94    /// use rocket::{Rocket, Build};
95    /// use rocket::fairing::{self, AdHoc};
96    ///
97    /// use rocket_db_pools::{sqlx, Database};
98    ///
99    /// #[derive(Database)]
100    /// #[database("sqlite_db")]
101    /// struct Db(sqlx::SqlitePool);
102    ///
103    /// async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
104    ///     if let Some(db) = Db::fetch(&rocket) {
105    ///         // run migrations using `db`. get the inner type with &db.0.
106    ///         Ok(rocket)
107    ///     } else {
108    ///         Err(rocket)
109    ///     }
110    /// }
111    ///
112    /// #[launch]
113    /// fn rocket() -> _ {
114    ///     rocket::build()
115    ///         .attach(Db::init())
116    ///         .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations))
117    /// }
118    /// # }
119    /// ```
120    fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
121        if let Some(db) = rocket.state() {
122            return Some(db);
123        }
124
125        let dbtype = std::any::type_name::<Self>().bold().primary();
126        error!("Attempted to fetch unattached database `{}`.", dbtype);
127        info_!("`{}{}` fairing must be attached prior to using this database.",
128            dbtype.linger(), "::init()".resetting());
129        None
130    }
131}
132
133/// A [`Fairing`] which initializes a [`Database`] and its connection pool.
134///
135/// A value of this type can be created for any type `D` that implements
136/// [`Database`] via the [`Database::init()`] method on the type. Normally, a
137/// value of this type _never_ needs to be constructed directly. This
138/// documentation exists purely as a reference.
139///
140/// This fairing initializes a database pool. Specifically, it:
141///
142///   1. Reads the configuration at `database.db_name`, where `db_name` is
143///      [`Database::NAME`].
144///
145///   2. Sets [`Config`](crate::Config) defaults on the configuration figment.
146///
147///   3. Calls [`Pool::init()`].
148///
149///   4. Stores the database instance in managed storage, retrievable via
150///      [`Database::fetch()`].
151///
152/// The name of the fairing itself is `Initializer<D>`, with `D` replaced with
153/// the type name `D` unless a name is explicitly provided via
154/// [`Self::with_name()`].
155pub struct Initializer<D: Database>(Option<&'static str>, PhantomData<fn() -> D>);
156
157/// A request guard which retrieves a single connection to a [`Database`].
158///
159/// For a database type of `Db`, a request guard of `Connection<Db>` retrieves a
160/// single connection to `Db`.
161///
162/// The request guard succeeds if the database was initialized by the
163/// [`Initializer`] fairing and a connection is available within
164/// [`connect_timeout`](crate::Config::connect_timeout) seconds.
165///   * If the `Initializer` fairing was _not_ attached, the guard _fails_ with
166///   status `InternalServerError`. A [`Sentinel`] guards this condition, and so
167///   this type of error is unlikely to occur. A `None` error is returned.
168///   * If a connection is not available within `connect_timeout` seconds or
169///   another error occurs, the guard _fails_ with status `ServiceUnavailable`
170///   and the error is returned in `Some`.
171///
172/// ## Deref
173///
174/// A type of `Connection<Db>` dereferences, mutably and immutably, to the
175/// native database connection type. The [driver table](crate#supported-drivers)
176/// lists the concrete native `Deref` types.
177///
178/// # Example
179///
180/// ```rust
181/// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
182/// # use rocket::get;
183/// # type Pool = rocket_db_pools::sqlx::SqlitePool;
184/// use rocket_db_pools::{Database, Connection};
185///
186/// #[derive(Database)]
187/// #[database("db")]
188/// struct Db(Pool);
189///
190/// #[get("/")]
191/// async fn db_op(db: Connection<Db>) {
192///     // use `&*db` to get an immutable borrow to the native connection type
193///     // use `&mut *db` to get a mutable borrow to the native connection type
194/// }
195/// # }
196/// ```
197pub struct Connection<D: Database>(<D::Pool as Pool>::Connection);
198
199impl<D: Database> Initializer<D> {
200    /// Returns a database initializer fairing for `D`.
201    ///
202    /// This method should never need to be called manually. See the [crate
203    /// docs](crate) for usage information.
204    pub fn new() -> Self {
205        Self(None, std::marker::PhantomData)
206    }
207
208    /// Returns a database initializer fairing for `D` with name `name`.
209    ///
210    /// This method should never need to be called manually. See the [crate
211    /// docs](crate) for usage information.
212    pub fn with_name(name: &'static str) -> Self {
213        Self(Some(name), std::marker::PhantomData)
214    }
215}
216
217impl<D: Database> Connection<D> {
218    /// Returns the internal connection value. See the [`Connection` Deref
219    /// column](crate#supported-drivers) for the expected type of this value.
220    ///
221    /// Note that `Connection<D>` derefs to the internal connection type, so
222    /// using this method is likely unnecessary. See [deref](Connection#deref)
223    /// for examples.
224    ///
225    /// # Example
226    ///
227    /// ```rust
228    /// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
229    /// # use rocket::get;
230    /// # type Pool = rocket_db_pools::sqlx::SqlitePool;
231    /// use rocket_db_pools::{Database, Connection};
232    ///
233    /// #[derive(Database)]
234    /// #[database("db")]
235    /// struct Db(Pool);
236    ///
237    /// #[get("/")]
238    /// async fn db_op(db: Connection<Db>) {
239    ///     let inner = db.into_inner();
240    /// }
241    /// # }
242    /// ```
243    pub fn into_inner(self) -> <D::Pool as Pool>::Connection {
244        self.0
245    }
246}
247
248#[rocket::async_trait]
249impl<D: Database> Fairing for Initializer<D> {
250    fn info(&self) -> Info {
251        Info {
252            name: self.0.unwrap_or(std::any::type_name::<Self>()),
253            kind: Kind::Ignite | Kind::Shutdown,
254        }
255    }
256
257    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
258        let workers: usize = rocket.figment()
259            .extract_inner(rocket::Config::WORKERS)
260            .unwrap_or_else(|_| rocket::Config::default().workers);
261
262        let figment = rocket.figment()
263            .focus(&format!("databases.{}", D::NAME))
264            .join(Serialized::default("max_connections", workers * 4))
265            .join(Serialized::default("connect_timeout", 5));
266
267        match <D::Pool>::init(&figment).await {
268            Ok(pool) => Ok(rocket.manage(D::from(pool))),
269            Err(e) => {
270                error!("failed to initialize database: {}", e);
271                Err(rocket)
272            }
273        }
274    }
275
276    async fn on_shutdown(&self, rocket: &Rocket<Orbit>) {
277        if let Some(db) = D::fetch(rocket) {
278            db.close().await;
279        }
280    }
281}
282
283#[rocket::async_trait]
284impl<'r, D: Database> FromRequest<'r> for Connection<D> {
285    type Error = Option<<D::Pool as Pool>::Error>;
286
287    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
288        match D::fetch(req.rocket()) {
289            Some(db) => match db.get().await {
290                Ok(conn) => Outcome::Success(Connection(conn)),
291                Err(e) => Outcome::Error((Status::ServiceUnavailable, Some(e))),
292            },
293            None => Outcome::Error((Status::InternalServerError, None)),
294        }
295    }
296}
297
298impl<D: Database> Sentinel for Connection<D> {
299    fn abort(rocket: &Rocket<Ignite>) -> bool {
300        D::fetch(rocket).is_none()
301    }
302}
303
304impl<D: Database> Deref for Connection<D> {
305    type Target = <D::Pool as Pool>::Connection;
306
307    fn deref(&self) -> &Self::Target {
308        &self.0
309    }
310}
311
312impl<D: Database> DerefMut for Connection<D> {
313    fn deref_mut(&mut self) -> &mut Self::Target {
314        &mut self.0
315    }
316}