1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use rocket::{Phase, Rocket, Ignite, Sentinel};
5use rocket::fairing::{AdHoc, Fairing};
6use rocket::request::{Request, Outcome, FromRequest};
7use rocket::outcome::IntoOutcome;
8use rocket::http::Status;
9
10use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore, Mutex};
11use rocket::tokio::time::timeout;
12
13use crate::{Config, Poolable, Error};
14
/// Shared state for one database's connection pool.
///
/// `K` is a marker type only: no value of type `K` is stored, it merely makes
/// `ConnectionPool<K, C>` a distinct type per `K`. `C` is the pooled
/// connection type.
#[doc(hidden)]
pub struct ConnectionPool<K, C: Poolable> {
    /// Database configuration; this file reads `timeout` and `pool_size` from it.
    config: Config,
    /// The underlying `r2d2` pool. It is an `Option` so that the `Drop` impl
    /// can `take()` it and move it to a blocking thread.
    pool: Option<r2d2::Pool<C::Manager>>,
    /// Limits how many connections may be checked out at once; created with
    /// `pool_size` permits in `fairing()`.
    semaphore: Arc<Semaphore>,
    /// Ties the pool to the marker type `K` without owning a `K`.
    _marker: PhantomData<fn() -> K>,
}
27
28impl<K, C: Poolable> Clone for ConnectionPool<K, C> {
29 fn clone(&self) -> Self {
30 ConnectionPool {
31 config: self.config.clone(),
32 pool: self.pool.clone(),
33 semaphore: self.semaphore.clone(),
34 _marker: PhantomData
35 }
36 }
37}
38
/// A single checked-out database connection plus the permit that reserved it.
///
/// `K` is a marker type only; `C` is the pooled connection type.
#[doc(hidden)]
pub struct Connection<K, C: Poolable> {
    /// The pooled connection behind an async mutex. It is shared through an
    /// `Arc` so that `run()` and `Drop` can move a handle onto a blocking
    /// thread. In the code visible here it only becomes `None` when `Drop`
    /// takes the connection out.
    connection: Arc<Mutex<Option<r2d2::PooledConnection<C::Manager>>>>,
    /// Semaphore permit obtained in `ConnectionPool::get()`. It is an `Option`
    /// so that `Drop` can `take()` it and release it explicitly.
    permit: Option<OwnedSemaphorePermit>,
    /// Ties the connection to the marker type `K` without owning a `K`.
    _marker: PhantomData<fn() -> K>,
}
49
50async fn run_blocking<F, R>(job: F) -> R
52 where F: FnOnce() -> R + Send + 'static, R: Send + 'static,
53{
54 match tokio::task::spawn_blocking(job).await {
55 Ok(ret) => ret,
56 Err(e) => match e.try_into_panic() {
57 Ok(panic) => std::panic::resume_unwind(panic),
58 Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
59 }
60 }
61}
62
// Logs a database setup failure and returns `Err($rocket)` from the enclosing
// function or closure.
//
// Arguments, in order: a literal naming the failing stage, the pool name, a
// literal format string for the error, the error value, and the value to hand
// back inside `Err`.
macro_rules! dberr {
    ($stage:literal, $db:expr, $fmt:literal, $err:expr, $rocket:expr) => {{
        rocket::error!(concat!("database ", $stage, " error for pool named `{}`"), $db);
        error_!($fmt, $err);
        return Err($rocket);
    }};
}
70
71impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
72 pub fn fairing(fairing_name: &'static str, db: &'static str) -> impl Fairing {
73 AdHoc::try_on_ignite(fairing_name, move |rocket| async move {
74 run_blocking(move || {
75 let config = match Config::from(db, &rocket) {
76 Ok(config) => config,
77 Err(e) => dberr!("config", db, "{}", e, rocket),
78 };
79
80 let pool_size = config.pool_size;
81 match C::pool(db, &rocket) {
82 Ok(pool) => Ok(rocket.manage(ConnectionPool::<K, C> {
83 config,
84 pool: Some(pool),
85 semaphore: Arc::new(Semaphore::new(pool_size as usize)),
86 _marker: PhantomData,
87 })),
88 Err(Error::Config(e)) => dberr!("config", db, "{}", e, rocket),
89 Err(Error::Pool(e)) => dberr!("pool init", db, "{}", e, rocket),
90 Err(Error::Custom(e)) => dberr!("pool manager", db, "{:?}", e, rocket),
91 }
92 }).await
93 })
94 }
95
96 pub async fn get(&self) -> Option<Connection<K, C>> {
97 let duration = std::time::Duration::from_secs(self.config.timeout as u64);
98 let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
99 Ok(p) => p.expect("internal invariant broken: semaphore should not be closed"),
100 Err(_) => {
101 error_!("database connection retrieval timed out");
102 return None;
103 }
104 };
105
106 let pool = self.pool.as_ref().cloned()
107 .expect("internal invariant broken: self.pool is Some");
108
109 match run_blocking(move || pool.get_timeout(duration)).await {
110 Ok(c) => Some(Connection {
111 connection: Arc::new(Mutex::new(Some(c))),
112 permit: Some(permit),
113 _marker: PhantomData,
114 }),
115 Err(e) => {
116 error_!("failed to get a database connection: {}", e);
117 None
118 }
119 }
120 }
121
122 #[inline]
123 pub async fn get_one<P: Phase>(rocket: &Rocket<P>) -> Option<Connection<K, C>> {
124 match Self::pool(rocket) {
125 Some(pool) => match pool.get().await {
126 Some(conn) => Some(conn),
127 None => {
128 error_!("no connections available for `{}`", std::any::type_name::<K>());
129 None
130 }
131 },
132 None => {
133 error_!("missing database fairing for `{}`", std::any::type_name::<K>());
134 None
135 }
136 }
137 }
138
139 #[inline]
140 pub fn pool<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
141 rocket.state::<Self>()
142 }
143}
144
145impl<K: 'static, C: Poolable> Connection<K, C> {
146 pub async fn run<F, R>(&self, f: F) -> R
147 where F: FnOnce(&mut C) -> R + Send + 'static,
148 R: Send + 'static,
149 {
150 let connection = self.connection.clone();
156
157 run_blocking(move || {
160 let mut connection = tokio::runtime::Handle::current().block_on(async {
163 connection.lock_owned().await
164 });
165
166 let conn = connection.as_mut()
167 .expect("internal invariant broken: self.connection is Some");
168 f(conn)
169 }).await
170 }
171}
172
173impl<K, C: Poolable> Drop for Connection<K, C> {
174 fn drop(&mut self) {
175 let connection = self.connection.clone();
176 let permit = self.permit.take();
177
178 tokio::task::spawn_blocking(move || {
180 let mut connection = tokio::runtime::Handle::current().block_on(async {
181 connection.lock_owned().await
182 });
183
184 if let Some(conn) = connection.take() {
185 drop(conn);
186 }
187
188 drop(permit);
191 });
192 }
193}
194
195impl<K, C: Poolable> Drop for ConnectionPool<K, C> {
196 fn drop(&mut self) {
197 let pool = self.pool.take();
198 if let Ok(handle) = tokio::runtime::Handle::try_current() {
200 handle.spawn_blocking(move || drop(pool));
201 }
202 }
204}
205
206#[rocket::async_trait]
207impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection<K, C> {
208 type Error = ();
209
210 #[inline]
211 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ()> {
212 match request.rocket().state::<ConnectionPool<K, C>>() {
213 Some(c) => c.get().await.or_error((Status::ServiceUnavailable, ())),
214 None => {
215 error_!("Missing database fairing for `{}`", std::any::type_name::<K>());
216 Outcome::Error((Status::InternalServerError, ()))
217 }
218 }
219 }
220}
221
222impl<K: 'static, C: Poolable> Sentinel for Connection<K, C> {
223 fn abort(rocket: &Rocket<Ignite>) -> bool {
224 use rocket::yansi::Paint;
225
226 if rocket.state::<ConnectionPool<K, C>>().is_none() {
227 let conn = std::any::type_name::<K>().primary().bold();
228 error!("requesting `{}` DB connection without attaching `{}{}`.",
229 conn, conn.linger(), "::fairing()".resetting());
230
231 info_!("Attach `{}{}` to use database connection pooling.",
232 conn.linger(), "::fairing()".resetting());
233
234 return true;
235 }
236
237 false
238 }
239}