// rocket_sync_db_pools/connection.rs

1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use rocket::{Phase, Rocket, Ignite, Sentinel};
5use rocket::fairing::{AdHoc, Fairing};
6use rocket::request::{Request, Outcome, FromRequest};
7use rocket::outcome::IntoOutcome;
8use rocket::http::Status;
9
10use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore, Mutex};
11use rocket::tokio::time::timeout;
12
13use crate::{Config, Poolable, Error};
14
15/// Unstable internal details of generated code for the #[database] attribute.
16///
17/// This type is implemented here instead of in generated code to ensure all
18/// types are properly checked.
19#[doc(hidden)]
20pub struct ConnectionPool<K, C: Poolable> {
21    config: Config,
22    // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
23    pool: Option<r2d2::Pool<C::Manager>>,
24    semaphore: Arc<Semaphore>,
25    _marker: PhantomData<fn() -> K>,
26}
27
28impl<K, C: Poolable> Clone for ConnectionPool<K, C> {
29    fn clone(&self) -> Self {
30        ConnectionPool {
31            config: self.config.clone(),
32            pool: self.pool.clone(),
33            semaphore: self.semaphore.clone(),
34            _marker: PhantomData
35        }
36    }
37}
38
39/// Unstable internal details of generated code for the #[database] attribute.
40///
41/// This type is implemented here instead of in generated code to ensure all
42/// types are properly checked.
43#[doc(hidden)]
44pub struct Connection<K, C: Poolable> {
45    connection: Arc<Mutex<Option<r2d2::PooledConnection<C::Manager>>>>,
46    permit: Option<OwnedSemaphorePermit>,
47    _marker: PhantomData<fn() -> K>,
48}
49
50// A wrapper around spawn_blocking that propagates panics to the calling code.
51async fn run_blocking<F, R>(job: F) -> R
52    where F: FnOnce() -> R + Send + 'static, R: Send + 'static,
53{
54    match tokio::task::spawn_blocking(job).await {
55        Ok(ret) => ret,
56        Err(e) => match e.try_into_panic() {
57            Ok(panic) => std::panic::resume_unwind(panic),
58            Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
59        }
60    }
61}
62
/// Logs a database setup failure for the pool named `$db`, then `return`s
/// `Err($rocket)` from the enclosing function or closure.
///
/// `$what` names the failing stage (for example `"config"`). `$fmt` is the
/// format string used to print `$err` through `error_!`.
macro_rules! dberr {
    ($what:literal, $db:expr, $fmt:literal, $err:expr, $rocket:expr) => {{
        rocket::error!(concat!("database ", $what, " error for pool named `{}`"), $db);
        error_!($fmt, $err);
        return Err($rocket);
    }};
}
70
71impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
72    pub fn fairing(fairing_name: &'static str, db: &'static str) -> impl Fairing {
73        AdHoc::try_on_ignite(fairing_name, move |rocket| async move {
74            run_blocking(move || {
75                let config = match Config::from(db, &rocket) {
76                    Ok(config) => config,
77                    Err(e) => dberr!("config", db, "{}", e, rocket),
78                };
79
80                let pool_size = config.pool_size;
81                match C::pool(db, &rocket) {
82                    Ok(pool) => Ok(rocket.manage(ConnectionPool::<K, C> {
83                        config,
84                        pool: Some(pool),
85                        semaphore: Arc::new(Semaphore::new(pool_size as usize)),
86                        _marker: PhantomData,
87                    })),
88                    Err(Error::Config(e)) => dberr!("config", db, "{}", e, rocket),
89                    Err(Error::Pool(e)) => dberr!("pool init", db, "{}", e, rocket),
90                    Err(Error::Custom(e)) => dberr!("pool manager", db, "{:?}", e, rocket),
91                }
92            }).await
93        })
94    }
95
96    pub async fn get(&self) -> Option<Connection<K, C>> {
97        let duration = std::time::Duration::from_secs(self.config.timeout as u64);
98        let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
99            Ok(p) => p.expect("internal invariant broken: semaphore should not be closed"),
100            Err(_) => {
101                error_!("database connection retrieval timed out");
102                return None;
103            }
104        };
105
106        let pool = self.pool.as_ref().cloned()
107            .expect("internal invariant broken: self.pool is Some");
108
109        match run_blocking(move || pool.get_timeout(duration)).await {
110            Ok(c) => Some(Connection {
111                connection: Arc::new(Mutex::new(Some(c))),
112                permit: Some(permit),
113                _marker: PhantomData,
114            }),
115            Err(e) => {
116                error_!("failed to get a database connection: {}", e);
117                None
118            }
119        }
120    }
121
122    #[inline]
123    pub async fn get_one<P: Phase>(rocket: &Rocket<P>) -> Option<Connection<K, C>> {
124        match Self::pool(rocket) {
125            Some(pool) => match pool.get().await {
126                Some(conn) => Some(conn),
127                None => {
128                    error_!("no connections available for `{}`", std::any::type_name::<K>());
129                    None
130                }
131            },
132            None => {
133                error_!("missing database fairing for `{}`", std::any::type_name::<K>());
134                None
135            }
136        }
137    }
138
139    #[inline]
140    pub fn pool<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
141        rocket.state::<Self>()
142    }
143}
144
145impl<K: 'static, C: Poolable> Connection<K, C> {
146    pub async fn run<F, R>(&self, f: F) -> R
147        where F: FnOnce(&mut C) -> R + Send + 'static,
148              R: Send + 'static,
149    {
150        // It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
151        // derived from it) never be a variable on the stack at an await point,
152        // where Drop might be called at any time. This causes (synchronous)
153        // Drop to be called from asynchronous code, which some database
154        // wrappers do not or can not handle.
155        let connection = self.connection.clone();
156
157        // Since connection can't be on the stack in an async fn during an
158        // await, we have to spawn a new blocking-safe thread...
159        run_blocking(move || {
160            // And then re-enter the runtime to wait on the async mutex, but in
161            // a blocking fashion.
162            let mut connection = tokio::runtime::Handle::current().block_on(async {
163                connection.lock_owned().await
164            });
165
166            let conn = connection.as_mut()
167                .expect("internal invariant broken: self.connection is Some");
168            f(conn)
169        }).await
170    }
171}
172
173impl<K, C: Poolable> Drop for Connection<K, C> {
174    fn drop(&mut self) {
175        let connection = self.connection.clone();
176        let permit = self.permit.take();
177
178        // See same motivation above for this arrangement of spawn_blocking/block_on
179        tokio::task::spawn_blocking(move || {
180            let mut connection = tokio::runtime::Handle::current().block_on(async {
181                connection.lock_owned().await
182            });
183
184            if let Some(conn) = connection.take() {
185                drop(conn);
186            }
187
188            // Explicitly dropping the permit here so that it's only
189            // released after the connection is.
190            drop(permit);
191        });
192    }
193}
194
195impl<K, C: Poolable> Drop for ConnectionPool<K, C> {
196    fn drop(&mut self) {
197        let pool = self.pool.take();
198        // Only use spawn_blocking if the Tokio runtime is still available
199        if let Ok(handle) = tokio::runtime::Handle::try_current() {
200            handle.spawn_blocking(move || drop(pool));
201        }
202        // Otherwise the pool will be dropped on the current thread
203    }
204}
205
206#[rocket::async_trait]
207impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection<K, C> {
208    type Error = ();
209
210    #[inline]
211    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ()> {
212        match request.rocket().state::<ConnectionPool<K, C>>() {
213            Some(c) => c.get().await.or_error((Status::ServiceUnavailable, ())),
214            None => {
215                error_!("Missing database fairing for `{}`", std::any::type_name::<K>());
216                Outcome::Error((Status::InternalServerError, ()))
217            }
218        }
219    }
220}
221
222impl<K: 'static, C: Poolable> Sentinel for Connection<K, C> {
223    fn abort(rocket: &Rocket<Ignite>) -> bool {
224        use rocket::yansi::Paint;
225
226        if rocket.state::<ConnectionPool<K, C>>().is_none() {
227            let conn = std::any::type_name::<K>().primary().bold();
228            error!("requesting `{}` DB connection without attaching `{}{}`.",
229                conn, conn.linger(), "::fairing()".resetting());
230
231            info_!("Attach `{}{}` to use database connection pooling.",
232                conn.linger(), "::fairing()".resetting());
233
234            return true;
235        }
236
237        false
238    }
239}