rocket_db_pools/database.rs
1use std::marker::PhantomData;
2use std::ops::{Deref, DerefMut};
3
4use rocket::{error, Build, Ignite, Phase, Rocket, Sentinel, Orbit};
5use rocket::fairing::{self, Fairing, Info, Kind};
6use rocket::request::{FromRequest, Outcome, Request};
7use rocket::figment::providers::Serialized;
8use rocket::http::Status;
9
10use crate::Pool;
11
12/// Derivable trait which ties a database [`Pool`] with a configuration name.
13///
14/// This trait should rarely, if ever, be implemented manually. Instead, it
15/// should be derived:
16///
17/// ```rust
18/// # #[cfg(feature = "deadpool_redis")] mod _inner {
19/// # use rocket::launch;
20/// use rocket_db_pools::{deadpool_redis, Database};
21///
22/// #[derive(Database)]
23/// #[database("memdb")]
24/// struct Db(deadpool_redis::Pool);
25///
26/// #[launch]
27/// fn rocket() -> _ {
28/// rocket::build().attach(Db::init())
29/// }
30/// # }
31/// ```
32///
33/// See the [`Database` derive](derive@crate::Database) for details.
34pub trait Database: From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static {
35 /// The [`Pool`] type of connections to this database.
36 ///
37 /// When `Database` is derived, this takes the value of the `Inner` type in
38 /// `struct Db(Inner)`.
39 type Pool: Pool;
40
41 /// The configuration name for this database.
42 ///
43 /// When `Database` is derived, this takes the value `"name"` in the
44 /// `#[database("name")]` attribute.
45 const NAME: &'static str;
46
47 /// Returns a fairing that initializes the database and its connection pool.
48 ///
49 /// # Example
50 ///
51 /// ```rust
52 /// # #[cfg(feature = "deadpool_postgres")] mod _inner {
53 /// # use rocket::launch;
54 /// use rocket_db_pools::{deadpool_postgres, Database};
55 ///
56 /// #[derive(Database)]
57 /// #[database("pg_db")]
58 /// struct Db(deadpool_postgres::Pool);
59 ///
60 /// #[launch]
61 /// fn rocket() -> _ {
62 /// rocket::build().attach(Db::init())
63 /// }
64 /// # }
65 /// ```
66 fn init() -> Initializer<Self> {
67 Initializer::new()
68 }
69
70 /// Returns a reference to the initialized database in `rocket`. The
71 /// initializer fairing returned by `init()` must have already executed for
72 /// `Option` to be `Some`. This is guaranteed to be the case if the fairing
73 /// is attached and either:
74 ///
75 /// * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the
76 /// application is running. This is always the case in request guards
77 /// and liftoff fairings,
78 /// * _or_ Rocket is in the [`Build`](rocket::Build) or
79 /// [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has
80 /// already been run. This is the case in all fairing callbacks
81 /// corresponding to fairings attached _after_ the `Initializer`
82 /// fairing.
83 ///
84 /// # Example
85 ///
86 /// Run database migrations in an ignite fairing. It is imperative that the
87 /// migration fairing be registered _after_ the `init()` fairing.
88 ///
89 /// ```rust
90 /// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
91 /// # use rocket::launch;
92 /// use rocket::{Rocket, Build};
93 /// use rocket::fairing::{self, AdHoc};
94 ///
95 /// use rocket_db_pools::{sqlx, Database};
96 ///
97 /// #[derive(Database)]
98 /// #[database("sqlite_db")]
99 /// struct Db(sqlx::SqlitePool);
100 ///
101 /// async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
102 /// if let Some(db) = Db::fetch(&rocket) {
103 /// // run migrations using `db`. get the inner type with &db.0.
104 /// Ok(rocket)
105 /// } else {
106 /// Err(rocket)
107 /// }
108 /// }
109 ///
110 /// #[launch]
111 /// fn rocket() -> _ {
112 /// rocket::build()
113 /// .attach(Db::init())
114 /// .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations))
115 /// }
116 /// # }
117 /// ```
118 fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
119 if let Some(db) = rocket.state() {
120 return Some(db);
121 }
122
123 let conn = std::any::type_name::<Self>();
124 error!("`{conn}::init()` is not attached\n\
125 the fairing must be attached to use `{conn}` in routes.");
126
127 None
128 }
129}
130
/// A [`Fairing`] which initializes a [`Database`] and its connection pool.
///
/// A value of this type can be created for any type `D` that implements
/// [`Database`] via the [`Database::init()`] method on the type. Normally, a
/// value of this type _never_ needs to be constructed directly. This
/// documentation exists purely as a reference.
///
/// On ignite, this fairing:
///
/// 1. Reads the configuration at `databases.db_name`, where `db_name` is
///    [`Database::NAME`].
///
/// 2. Sets [`Config`](crate::Config) defaults on the configuration figment:
///    `max_connections` defaults to four times the configured worker count,
///    and `connect_timeout` defaults to `5`.
///
/// 3. Calls [`Pool::init()`].
///
/// 4. Stores the database instance in managed storage, retrievable via
///    [`Database::fetch()`].
///
/// On shutdown, it closes the pool of the stored database, if present.
///
/// The fairing is named by the string explicitly provided via
/// [`Self::with_name()`]. If no name was given, it uses the string returned
/// by [`std::any::type_name`] for `Initializer<D>`.
pub struct Initializer<D: Database>(Option<&'static str>, PhantomData<fn() -> D>);
154
/// A request guard which retrieves a single connection to a [`Database`].
///
/// For a database type of `Db`, a request guard of `Connection<Db>` retrieves a
/// single connection to `Db`.
///
/// The request guard succeeds if the database was initialized by the
/// [`Initializer`] fairing and a connection is available within
/// [`connect_timeout`](crate::Config::connect_timeout) seconds. Otherwise:
///
/// * If the `Initializer` fairing was _not_ attached, the guard _fails_ with
///   status `InternalServerError` and a `None` error. A [`Sentinel`] checks
///   for this condition before launch, so this failure is unlikely to occur.
/// * If a connection is not available within `connect_timeout` seconds or
///   another error occurs, the guard _fails_ with status `ServiceUnavailable`
///   and the pool's error is returned in `Some`.
///
/// ## Deref
///
/// A type of `Connection<Db>` dereferences, mutably and immutably, to the
/// native database connection type. The [driver table](crate#supported-drivers)
/// lists the concrete native `Deref` types.
///
/// # Example
///
/// ```rust
/// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
/// # use rocket::get;
/// # type Pool = rocket_db_pools::sqlx::SqlitePool;
/// use rocket_db_pools::{Database, Connection};
///
/// #[derive(Database)]
/// #[database("db")]
/// struct Db(Pool);
///
/// #[get("/")]
/// async fn db_op(db: Connection<Db>) {
///     // use `&*db` to get an immutable borrow to the native connection type
///     // use `&mut *db` to get a mutable borrow to the native connection type
/// }
/// # }
/// ```
pub struct Connection<D: Database>(<D::Pool as Pool>::Connection);
196
197impl<D: Database> Initializer<D> {
198 /// Returns a database initializer fairing for `D`.
199 ///
200 /// This method should never need to be called manually. See the [crate
201 /// docs](crate) for usage information.
202 pub fn new() -> Self {
203 Self(None, std::marker::PhantomData)
204 }
205
206 /// Returns a database initializer fairing for `D` with name `name`.
207 ///
208 /// This method should never need to be called manually. See the [crate
209 /// docs](crate) for usage information.
210 pub fn with_name(name: &'static str) -> Self {
211 Self(Some(name), std::marker::PhantomData)
212 }
213}
214
215impl<D: Database> Connection<D> {
216 /// Returns the internal connection value. See the [`Connection` Deref
217 /// column](crate#supported-drivers) for the expected type of this value.
218 ///
219 /// Note that `Connection<D>` derefs to the internal connection type, so
220 /// using this method is likely unnecessary. See [deref](Connection#deref)
221 /// for examples.
222 ///
223 /// # Example
224 ///
225 /// ```rust
226 /// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
227 /// # use rocket::get;
228 /// # type Pool = rocket_db_pools::sqlx::SqlitePool;
229 /// use rocket_db_pools::{Database, Connection};
230 ///
231 /// #[derive(Database)]
232 /// #[database("db")]
233 /// struct Db(Pool);
234 ///
235 /// #[get("/")]
236 /// async fn db_op(db: Connection<Db>) {
237 /// let inner = db.into_inner();
238 /// }
239 /// # }
240 /// ```
241 pub fn into_inner(self) -> <D::Pool as Pool>::Connection {
242 self.0
243 }
244}
245
246#[rocket::async_trait]
247impl<D: Database> Fairing for Initializer<D> {
248 fn info(&self) -> Info {
249 Info {
250 name: self.0.unwrap_or(std::any::type_name::<Self>()),
251 kind: Kind::Ignite | Kind::Shutdown,
252 }
253 }
254
255 async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
256 let workers: usize = rocket.figment()
257 .extract_inner(rocket::Config::WORKERS)
258 .unwrap_or_else(|_| rocket::Config::default().workers);
259
260 let figment = rocket.figment()
261 .focus(&format!("databases.{}", D::NAME))
262 .join(Serialized::default("max_connections", workers * 4))
263 .join(Serialized::default("connect_timeout", 5));
264
265 match <D::Pool>::init(&figment).await {
266 Ok(pool) => Ok(rocket.manage(D::from(pool))),
267 Err(e) => {
268 error!("database initialization failed: {e}");
269 Err(rocket)
270 }
271 }
272 }
273
274 async fn on_shutdown(&self, rocket: &Rocket<Orbit>) {
275 if let Some(db) = D::fetch(rocket) {
276 db.close().await;
277 }
278 }
279}
280
281#[rocket::async_trait]
282impl<'r, D: Database> FromRequest<'r> for Connection<D> {
283 type Error = Option<<D::Pool as Pool>::Error>;
284
285 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
286 match D::fetch(req.rocket()) {
287 Some(db) => match db.get().await {
288 Ok(conn) => Outcome::Success(Connection(conn)),
289 Err(e) => Outcome::Error((Status::ServiceUnavailable, Some(e))),
290 },
291 None => Outcome::Error((Status::InternalServerError, None)),
292 }
293 }
294}
295
296impl<D: Database> Sentinel for Connection<D> {
297 fn abort(rocket: &Rocket<Ignite>) -> bool {
298 D::fetch(rocket).is_none()
299 }
300}
301
302impl<D: Database> Deref for Connection<D> {
303 type Target = <D::Pool as Pool>::Connection;
304
305 fn deref(&self) -> &Self::Target {
306 &self.0
307 }
308}
309
310impl<D: Database> DerefMut for Connection<D> {
311 fn deref_mut(&mut self) -> &mut Self::Target {
312 &mut self.0
313 }
314}