1use std::sync::Arc;
2use std::marker::PhantomData;
3
4use rocket::{Phase, Rocket, Ignite, Sentinel};
5use rocket::fairing::{AdHoc, Fairing};
6use rocket::request::{Request, Outcome, FromRequest};
7use rocket::outcome::IntoOutcome;
8use rocket::http::Status;
9use rocket::trace::Trace;
10
11use rocket::tokio::time::timeout;
12use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore, Mutex};
13
14use crate::{Config, Poolable, Error};
15
16#[doc(hidden)]
21pub struct ConnectionPool<K, C: Poolable> {
22 config: Config,
23 pool: Option<r2d2::Pool<C::Manager>>,
25 semaphore: Arc<Semaphore>,
26 _marker: PhantomData<fn() -> K>,
27}
28
29impl<K, C: Poolable> Clone for ConnectionPool<K, C> {
30 fn clone(&self) -> Self {
31 ConnectionPool {
32 config: self.config.clone(),
33 pool: self.pool.clone(),
34 semaphore: self.semaphore.clone(),
35 _marker: PhantomData
36 }
37 }
38}
39
40#[doc(hidden)]
45pub struct Connection<K, C: Poolable> {
46 connection: Arc<Mutex<Option<r2d2::PooledConnection<C::Manager>>>>,
47 permit: Option<OwnedSemaphorePermit>,
48 _marker: PhantomData<fn() -> K>,
49}
50
51async fn run_blocking<F, R>(job: F) -> R
53 where F: FnOnce() -> R + Send + 'static, R: Send + 'static,
54{
55 match tokio::task::spawn_blocking(job).await {
56 Ok(ret) => ret,
57 Err(e) => match e.try_into_panic() {
58 Ok(panic) => std::panic::resume_unwind(panic),
59 Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
60 }
61 }
62}
63
64impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
65 pub fn fairing(fairing_name: &'static str, database: &'static str) -> impl Fairing {
66 AdHoc::try_on_ignite(fairing_name, move |rocket| async move {
67 run_blocking(move || {
68 let config = match Config::from(database, &rocket) {
69 Ok(config) => config,
70 Err(e) => {
71 span_error!("database configuration error", database => e.trace_error());
72 return Err(rocket);
73 }
74 };
75
76 let pool_size = config.pool_size;
77 match C::pool(database, &rocket) {
78 Ok(pool) => Ok(rocket.manage(ConnectionPool::<K, C> {
79 config,
80 pool: Some(pool),
81 semaphore: Arc::new(Semaphore::new(pool_size as usize)),
82 _marker: PhantomData,
83 })),
84 Err(Error::Config(e)) => {
85 span_error!("database configuration error", database => e.trace_error());
86 Err(rocket)
87 }
88 Err(Error::Pool(reason)) => {
89 error!(database, %reason, "database pool initialization failed");
90 Err(rocket)
91 }
92 Err(Error::Custom(reason)) => {
93 error!(database, ?reason, "database pool failure");
94 Err(rocket)
95 }
96 }
97 }).await
98 })
99 }
100
101 pub async fn get(&self) -> Option<Connection<K, C>> {
102 let type_name = std::any::type_name::<K>();
103 let duration = std::time::Duration::from_secs(self.config.timeout as u64);
104 let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
105 Ok(p) => p.expect("internal invariant broken: semaphore should not be closed"),
106 Err(_) => {
107 error!(type_name, "database connection retrieval timed out");
108 return None;
109 }
110 };
111
112 let pool = self.pool.as_ref().cloned()
113 .expect("internal invariant broken: self.pool is Some");
114
115 match run_blocking(move || pool.get_timeout(duration)).await {
116 Ok(c) => Some(Connection {
117 connection: Arc::new(Mutex::new(Some(c))),
118 permit: Some(permit),
119 _marker: PhantomData,
120 }),
121 Err(e) => {
122 error!(type_name, "failed to get a database connection: {}", e);
123 None
124 }
125 }
126 }
127
128 #[inline]
129 pub async fn get_one<P: Phase>(rocket: &Rocket<P>) -> Option<Connection<K, C>> {
130 match Self::pool(rocket) {
131 Some(pool) => match pool.get().await {
132 Some(conn) => Some(conn),
133 None => {
134 error!("no connections available for `{}`", std::any::type_name::<K>());
135 None
136 }
137 },
138 None => {
139 error!("missing database fairing for `{}`", std::any::type_name::<K>());
140 None
141 }
142 }
143 }
144
145 #[inline]
146 pub fn pool<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
147 rocket.state::<Self>()
148 }
149}
150
151impl<K: 'static, C: Poolable> Connection<K, C> {
152 pub async fn run<F, R>(&self, f: F) -> R
153 where F: FnOnce(&mut C) -> R + Send + 'static,
154 R: Send + 'static,
155 {
156 let connection = self.connection.clone();
162
163 run_blocking(move || {
166 let mut connection = tokio::runtime::Handle::current().block_on(async {
169 connection.lock_owned().await
170 });
171
172 let conn = connection.as_mut()
173 .expect("internal invariant broken: self.connection is Some");
174
175 f(conn)
176 }).await
177 }
178}
179
180impl<K, C: Poolable> Drop for Connection<K, C> {
181 fn drop(&mut self) {
182 let connection = self.connection.clone();
183 let permit = self.permit.take();
184
185 if let Ok(handle) = tokio::runtime::Handle::try_current() {
187 handle.spawn_blocking(move || {
189 let mut connection = tokio::runtime::Handle::current()
190 .block_on(async { connection.lock_owned().await });
191
192 if let Some(conn) = connection.take() {
193 drop(conn);
194 }
195 });
196 } else {
197 warn!(type_name = std::any::type_name::<K>(),
198 "database connection is being dropped outside of an async context\n\
199 this means you have stored a connection beyond a request's lifetime\n\
200 this is not recommended: connections are not valid indefinitely\n\
201 instead, store a connection pool and get connections as needed");
202
203 if let Some(conn) = connection.blocking_lock().take() {
204 drop(conn);
205 }
206 }
207
208 drop(permit);
210 }
211}
212
213impl<K, C: Poolable> Drop for ConnectionPool<K, C> {
214 fn drop(&mut self) {
215 let pool = self.pool.take();
218 if let Ok(handle) = tokio::runtime::Handle::try_current() {
219 handle.spawn_blocking(move || drop(pool));
220 }
221 }
222}
223
224#[rocket::async_trait]
225impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection<K, C> {
226 type Error = ();
227
228 #[inline]
229 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ()> {
230 match request.rocket().state::<ConnectionPool<K, C>>() {
231 Some(c) => c.get().await.or_error((Status::ServiceUnavailable, ())),
232 None => {
233 let conn = std::any::type_name::<K>();
234 error!("`{conn}::fairing()` is not attached\n\
235 the fairing must be attached to use `{conn} in routes.");
236 Outcome::Error((Status::InternalServerError, ()))
237 }
238 }
239 }
240}
241
242impl<K: 'static, C: Poolable> Sentinel for Connection<K, C> {
243 fn abort(rocket: &Rocket<Ignite>) -> bool {
244 if rocket.state::<ConnectionPool<K, C>>().is_none() {
245 let conn = std::any::type_name::<K>();
246 error!("`{conn}::fairing()` is not attached\n\
247 the fairing must be attached to use `{conn} in routes.");
248
249 return true;
250 }
251
252 false
253 }
254}