rocket_sync_db_pools/
connection.rs

1use std::sync::Arc;
2use std::marker::PhantomData;
3
4use rocket::{Phase, Rocket, Ignite, Sentinel};
5use rocket::fairing::{AdHoc, Fairing};
6use rocket::request::{Request, Outcome, FromRequest};
7use rocket::outcome::IntoOutcome;
8use rocket::http::Status;
9use rocket::trace::Trace;
10
11use rocket::tokio::time::timeout;
12use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore, Mutex};
13
14use crate::{Config, Poolable, Error};
15
/// Unstable internal details of generated code for the #[database] attribute.
///
/// This type is implemented here instead of in generated code to ensure all
/// types are properly checked.
///
/// The fairing built by `ConnectionPool::fairing` places one instance in
/// Rocket's managed state; `get` checks `Connection`s out of it.
#[doc(hidden)]
pub struct ConnectionPool<K, C: Poolable> {
    // Read by the fairing; `get` uses `timeout` (in seconds) both for the
    // semaphore wait and for r2d2's `get_timeout`.
    config: Config,
    // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
    pool: Option<r2d2::Pool<C::Manager>>,
    // Created with `config.pool_size` permits. Every `Connection` handed out
    // by `get` holds one permit, which bounds the number of live checkouts.
    semaphore: Arc<Semaphore>,
    // Associates the pool with the user's database type `K` without storing
    // a `K`.
    // NOTE: fields are dropped in declaration order after `Drop::drop` runs;
    // keep that in mind before reordering them.
    _marker: PhantomData<fn() -> K>,
}
28
29impl<K, C: Poolable> Clone for ConnectionPool<K, C> {
30    fn clone(&self) -> Self {
31        ConnectionPool {
32            config: self.config.clone(),
33            pool: self.pool.clone(),
34            semaphore: self.semaphore.clone(),
35            _marker: PhantomData
36        }
37    }
38}
39
/// Unstable internal details of generated code for the #[database] attribute.
///
/// This type is implemented here instead of in generated code to ensure all
/// types are properly checked.
///
/// One checked-out database connection, together with the semaphore permit
/// that `ConnectionPool::get` acquired for it.
#[doc(hidden)]
pub struct Connection<K, C: Poolable> {
    // Cloned into the blocking tasks spawned by `run` and by `Drop`. The
    // `Option` lets `Drop` move the pooled connection out of the slot.
    connection: Arc<Mutex<Option<r2d2::PooledConnection<C::Manager>>>>,
    // `Some` until `Drop` takes it; dropping it frees one slot in the pool's
    // semaphore.
    permit: Option<OwnedSemaphorePermit>,
    // Associates the connection with the user's database type `K`.
    // NOTE: fields are dropped in declaration order after `Drop::drop` runs.
    _marker: PhantomData<fn() -> K>,
}
50
51// A wrapper around spawn_blocking that propagates panics to the calling code.
52async fn run_blocking<F, R>(job: F) -> R
53    where F: FnOnce() -> R + Send + 'static, R: Send + 'static,
54{
55    match tokio::task::spawn_blocking(job).await {
56        Ok(ret) => ret,
57        Err(e) => match e.try_into_panic() {
58            Ok(panic) => std::panic::resume_unwind(panic),
59            Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
60        }
61    }
62}
63
64impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
65    pub fn fairing(fairing_name: &'static str, database: &'static str) -> impl Fairing {
66        AdHoc::try_on_ignite(fairing_name, move |rocket| async move {
67            run_blocking(move || {
68                let config = match Config::from(database, &rocket) {
69                    Ok(config) => config,
70                    Err(e) => {
71                        span_error!("database configuration error", database => e.trace_error());
72                        return Err(rocket);
73                    }
74                };
75
76                let pool_size = config.pool_size;
77                match C::pool(database, &rocket) {
78                    Ok(pool) => Ok(rocket.manage(ConnectionPool::<K, C> {
79                        config,
80                        pool: Some(pool),
81                        semaphore: Arc::new(Semaphore::new(pool_size as usize)),
82                        _marker: PhantomData,
83                    })),
84                    Err(Error::Config(e)) => {
85                        span_error!("database configuration error", database => e.trace_error());
86                        Err(rocket)
87                    }
88                    Err(Error::Pool(reason)) => {
89                        error!(database, %reason, "database pool initialization failed");
90                        Err(rocket)
91                    }
92                    Err(Error::Custom(reason)) => {
93                        error!(database, ?reason, "database pool failure");
94                        Err(rocket)
95                    }
96                }
97            }).await
98        })
99    }
100
101    pub async fn get(&self) -> Option<Connection<K, C>> {
102        let type_name = std::any::type_name::<K>();
103        let duration = std::time::Duration::from_secs(self.config.timeout as u64);
104        let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
105            Ok(p) => p.expect("internal invariant broken: semaphore should not be closed"),
106            Err(_) => {
107                error!(type_name, "database connection retrieval timed out");
108                return None;
109            }
110        };
111
112        let pool = self.pool.as_ref().cloned()
113            .expect("internal invariant broken: self.pool is Some");
114
115        match run_blocking(move || pool.get_timeout(duration)).await {
116            Ok(c) => Some(Connection {
117                connection: Arc::new(Mutex::new(Some(c))),
118                permit: Some(permit),
119                _marker: PhantomData,
120            }),
121            Err(e) => {
122                error!(type_name, "failed to get a database connection: {}", e);
123                None
124            }
125        }
126    }
127
128    #[inline]
129    pub async fn get_one<P: Phase>(rocket: &Rocket<P>) -> Option<Connection<K, C>> {
130        match Self::pool(rocket) {
131            Some(pool) => match pool.get().await {
132                Some(conn) => Some(conn),
133                None => {
134                    error!("no connections available for `{}`", std::any::type_name::<K>());
135                    None
136                }
137            },
138            None => {
139                error!("missing database fairing for `{}`", std::any::type_name::<K>());
140                None
141            }
142        }
143    }
144
145    #[inline]
146    pub fn pool<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
147        rocket.state::<Self>()
148    }
149}
150
151impl<K: 'static, C: Poolable> Connection<K, C> {
152    pub async fn run<F, R>(&self, f: F) -> R
153        where F: FnOnce(&mut C) -> R + Send + 'static,
154              R: Send + 'static,
155    {
156        // It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
157        // derived from it) never be a variable on the stack at an await point,
158        // where Drop might be called at any time. This causes (synchronous)
159        // Drop to be called from asynchronous code, which some database
160        // wrappers do not or can not handle.
161        let connection = self.connection.clone();
162
163        // Since connection can't be on the stack in an async fn during an
164        // await, we have to spawn a new blocking-safe thread...
165        run_blocking(move || {
166            // And then re-enter the runtime to wait on the async mutex, but in
167            // a blocking fashion.
168            let mut connection = tokio::runtime::Handle::current().block_on(async {
169                connection.lock_owned().await
170            });
171
172            let conn = connection.as_mut()
173                .expect("internal invariant broken: self.connection is Some");
174
175            f(conn)
176        }).await
177    }
178}
179
180impl<K, C: Poolable> Drop for Connection<K, C> {
181    fn drop(&mut self) {
182        let connection = self.connection.clone();
183        let permit = self.permit.take();
184
185        // Only use spawn_blocking if the Tokio runtime is still available
186        if let Ok(handle) = tokio::runtime::Handle::try_current() {
187            // See above for motivation of this arrangement of spawn_blocking/block_on
188            handle.spawn_blocking(move || {
189                let mut connection = tokio::runtime::Handle::current()
190                    .block_on(async { connection.lock_owned().await });
191
192                if let Some(conn) = connection.take() {
193                    drop(conn);
194                }
195            });
196        } else {
197            warn!(type_name = std::any::type_name::<K>(),
198                "database connection is being dropped outside of an async context\n\
199                this means you have stored a connection beyond a request's lifetime\n\
200                this is not recommended: connections are not valid indefinitely\n\
201                instead, store a connection pool and get connections as needed");
202
203            if let Some(conn) = connection.blocking_lock().take() {
204                drop(conn);
205            }
206        }
207
208        // Explicitly drop permit here to release only after dropping connection.
209        drop(permit);
210    }
211}
212
213impl<K, C: Poolable> Drop for ConnectionPool<K, C> {
214    fn drop(&mut self) {
215        // Use spawn_blocking if the Tokio runtime is still available. Otherwise
216        // the pool will be dropped on the current thread.
217        let pool = self.pool.take();
218        if let Ok(handle) = tokio::runtime::Handle::try_current() {
219            handle.spawn_blocking(move || drop(pool));
220        }
221    }
222}
223
224#[rocket::async_trait]
225impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection<K, C> {
226    type Error = ();
227
228    #[inline]
229    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ()> {
230        match request.rocket().state::<ConnectionPool<K, C>>() {
231            Some(c) => c.get().await.or_error((Status::ServiceUnavailable, ())),
232            None => {
233                let conn = std::any::type_name::<K>();
234                error!("`{conn}::fairing()` is not attached\n\
235                    the fairing must be attached to use `{conn} in routes.");
236                Outcome::Error((Status::InternalServerError, ()))
237            }
238        }
239    }
240}
241
242impl<K: 'static, C: Poolable> Sentinel for Connection<K, C> {
243    fn abort(rocket: &Rocket<Ignite>) -> bool {
244        if rocket.state::<ConnectionPool<K, C>>().is_none() {
245            let conn = std::any::type_name::<K>();
246            error!("`{conn}::fairing()` is not attached\n\
247                the fairing must be attached to use `{conn} in routes.");
248
249            return true;
250        }
251
252        false
253    }
254}