rocket/fairing/
ad_hoc.rs

1use parking_lot::Mutex;
2use futures::future::{Future, BoxFuture, FutureExt};
3
4use crate::{Rocket, Request, Response, Data, Build, Orbit};
5use crate::fairing::{Fairing, Kind, Info, Result};
6use crate::route::RouteUri;
7use crate::trace::Trace;
8
9/// A ad-hoc fairing that can be created from a function or closure.
10///
11/// This enum can be used to create a fairing from a simple function or closure
12/// without creating a new structure or implementing `Fairing` directly.
13///
14/// # Usage
15///
16/// Use [`AdHoc::on_ignite`], [`AdHoc::on_liftoff`], [`AdHoc::on_request()`], or
17/// [`AdHoc::on_response()`] to create an `AdHoc` structure from a function or
18/// closure. Then, simply attach the structure to the `Rocket` instance.
19///
20/// # Example
21///
22/// The following snippet creates a `Rocket` instance with two ad-hoc fairings.
23/// The first, a liftoff fairing named "Liftoff Printer", simply prints a message
24/// indicating that Rocket has launched. The second named "Put Rewriter", a
25/// request fairing, rewrites the method of all requests to be `PUT`.
26///
27/// ```rust
28/// use rocket::fairing::AdHoc;
29/// use rocket::http::Method;
30///
31/// rocket::build()
32///     .attach(AdHoc::on_liftoff("Liftoff Printer", |_| Box::pin(async move {
33///         println!("...annnddd we have liftoff!");
34///     })))
35///     .attach(AdHoc::on_request("Put Rewriter", |req, _| Box::pin(async move {
36///         req.set_method(Method::Put);
37///     })));
38/// ```
39pub struct AdHoc {
40    name: &'static str,
41    kind: AdHocKind,
42}
43
44struct Once<F: ?Sized>(Mutex<Option<Box<F>>>);
45
46impl<F: ?Sized> Once<F> {
47    fn new(f: Box<F>) -> Self { Once(Mutex::new(Some(f))) }
48
49    #[track_caller]
50    fn take(&self) -> Box<F> {
51        self.0.lock().take().expect("Once::take() called once")
52    }
53}
54
55enum AdHocKind {
56    /// An ad-hoc **ignite** fairing. Called during ignition.
57    Ignite(Once<dyn FnOnce(Rocket<Build>) -> BoxFuture<'static, Result> + Send + 'static>),
58
59    /// An ad-hoc **liftoff** fairing. Called just after Rocket launches.
60    Liftoff(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),
61
62    /// An ad-hoc **request** fairing. Called when a request is received.
63    Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>)
64        -> BoxFuture<'a, ()> + Send + Sync + 'static>),
65
66    /// An ad-hoc **response** fairing. Called when a response is ready to be
67    /// sent to a client.
68    Response(Box<dyn for<'r, 'b> Fn(&'r Request<'_>, &'b mut Response<'r>)
69        -> BoxFuture<'b, ()> + Send + Sync + 'static>),
70
71    /// An ad-hoc **shutdown** fairing. Called on shutdown.
72    Shutdown(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),
73}
74
75impl AdHoc {
76    /// Constructs an `AdHoc` ignite fairing named `name`. The function `f` will
77    /// be called by Rocket during the [`Rocket::ignite()`] phase.
78    ///
79    /// This version of an `AdHoc` ignite fairing cannot abort ignite. For a
80    /// fallible version that can, see [`AdHoc::try_on_ignite()`].
81    ///
82    /// # Example
83    ///
84    /// ```rust
85    /// use rocket::fairing::AdHoc;
86    ///
87    /// // The no-op ignite fairing.
88    /// let fairing = AdHoc::on_ignite("Boom!", |rocket| async move {
89    ///     rocket
90    /// });
91    /// ```
92    pub fn on_ignite<F, Fut>(name: &'static str, f: F) -> AdHoc
93        where F: FnOnce(Rocket<Build>) -> Fut + Send + 'static,
94              Fut: Future<Output = Rocket<Build>> + Send + 'static,
95    {
96        AdHoc::try_on_ignite(name, |rocket| f(rocket).map(Ok))
97    }
98
99    /// Constructs an `AdHoc` ignite fairing named `name`. The function `f` will
100    /// be called by Rocket during the [`Rocket::ignite()`] phase. Returning an
101    /// `Err` aborts ignition and thus launch.
102    ///
103    /// For an infallible version, see [`AdHoc::on_ignite()`].
104    ///
105    /// # Example
106    ///
107    /// ```rust
108    /// use rocket::fairing::AdHoc;
109    ///
110    /// // The no-op try ignite fairing.
111    /// let fairing = AdHoc::try_on_ignite("No-Op", |rocket| async { Ok(rocket) });
112    /// ```
113    pub fn try_on_ignite<F, Fut>(name: &'static str, f: F) -> AdHoc
114        where F: FnOnce(Rocket<Build>) -> Fut + Send + 'static,
115              Fut: Future<Output = Result> + Send + 'static,
116    {
117        AdHoc { name, kind: AdHocKind::Ignite(Once::new(Box::new(|r| f(r).boxed()))) }
118    }
119
120    /// Constructs an `AdHoc` liftoff fairing named `name`. The function `f`
121    /// will be called by Rocket just after [`Rocket::launch()`].
122    ///
123    /// # Example
124    ///
125    /// ```rust
126    /// use rocket::fairing::AdHoc;
127    ///
128    /// // A fairing that prints a message just before launching.
129    /// let fairing = AdHoc::on_liftoff("Boom!", |_| Box::pin(async move {
130    ///     println!("Rocket has lifted off!");
131    /// }));
132    /// ```
133    pub fn on_liftoff<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
134        where F: for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()>
135    {
136        AdHoc { name, kind: AdHocKind::Liftoff(Once::new(Box::new(f))) }
137    }
138
139    /// Constructs an `AdHoc` request fairing named `name`. The function `f`
140    /// will be called and the returned `Future` will be `await`ed by Rocket
141    /// when a new request is received.
142    ///
143    /// # Example
144    ///
145    /// ```rust
146    /// use rocket::fairing::AdHoc;
147    ///
148    /// // The no-op request fairing.
149    /// let fairing = AdHoc::on_request("Dummy", |req, data| {
150    ///     Box::pin(async move {
151    ///         // do something with the request and data...
152    /// #       let (_, _) = (req, data);
153    ///     })
154    /// });
155    /// ```
156    pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
157        where F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()>
158    {
159        AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
160    }
161
162    // FIXME(rustc): We'd like to allow passing `async fn` to these methods...
163    // https://github.com/rust-lang/rust/issues/64552#issuecomment-666084589
164
165    /// Constructs an `AdHoc` response fairing named `name`. The function `f`
166    /// will be called and the returned `Future` will be `await`ed by Rocket
167    /// when a response is ready to be sent.
168    ///
169    /// # Example
170    ///
171    /// ```rust
172    /// use rocket::fairing::AdHoc;
173    ///
174    /// // The no-op response fairing.
175    /// let fairing = AdHoc::on_response("Dummy", |req, resp| {
176    ///     Box::pin(async move {
177    ///         // do something with the request and pending response...
178    /// #       let (_, _) = (req, resp);
179    ///     })
180    /// });
181    /// ```
182    pub fn on_response<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
183        where F: for<'b, 'r> Fn(&'r Request<'_>, &'b mut Response<'r>) -> BoxFuture<'b, ()>
184    {
185        AdHoc { name, kind: AdHocKind::Response(Box::new(f)) }
186    }
187
188    /// Constructs an `AdHoc` shutdown fairing named `name`. The function `f`
189    /// will be called by Rocket when [shutdown is triggered].
190    ///
191    /// [shutdown is triggered]: crate::config::ShutdownConfig#triggers
192    ///
193    /// # Example
194    ///
195    /// ```rust
196    /// use rocket::fairing::AdHoc;
197    ///
198    /// // A fairing that prints a message just before launching.
199    /// let fairing = AdHoc::on_shutdown("Bye!", |_| Box::pin(async move {
200    ///     println!("Rocket is on its way back!");
201    /// }));
202    /// ```
203    pub fn on_shutdown<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
204        where F: for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()>
205    {
206        AdHoc { name, kind: AdHocKind::Shutdown(Once::new(Box::new(f))) }
207    }
208
209    /// Constructs an `AdHoc` launch fairing that extracts a configuration of
210    /// type `T` from the configured provider and stores it in managed state. If
211    /// extractions fails, pretty-prints the error message and aborts launch.
212    ///
213    /// # Example
214    ///
215    /// ```rust
216    /// # use rocket::launch;
217    /// use serde::Deserialize;
218    /// use rocket::fairing::AdHoc;
219    ///
220    /// #[derive(Deserialize)]
221    /// struct Config {
222    ///     field: String,
223    ///     other: usize,
224    ///     /* and so on.. */
225    /// }
226    ///
227    /// #[launch]
228    /// fn rocket() -> _ {
229    ///     rocket::build().attach(AdHoc::config::<Config>())
230    /// }
231    /// ```
232    pub fn config<'de, T>() -> AdHoc
233        where T: serde::Deserialize<'de> + Send + Sync + 'static
234    {
235        AdHoc::try_on_ignite(std::any::type_name::<T>(), |rocket| async {
236            let app_config = match rocket.figment().extract::<T>() {
237                Ok(config) => config,
238                Err(e) => {
239                    e.trace_error();
240                    return Err(rocket);
241                }
242            };
243
244            Ok(rocket.manage(app_config))
245        })
246    }
247
248    /// Constructs an `AdHoc` request fairing that strips trailing slashes from
249    /// all URIs in all incoming requests.
250    ///
251    /// The fairing returned by this method is intended largely for applications
252    /// that migrated from Rocket v0.4 to Rocket v0.5. In Rocket v0.4, requests
253    /// with a trailing slash in the URI were treated as if the trailing slash
254    /// were not present. For example, the request URI `/foo/` would match the
255    /// route `/<a>` with `a = foo`. If the application depended on this
256    /// behavior, say by using URIs with previously innocuous trailing slashes
257    /// in an external application, requests will not be routed as expected.
258    ///
259    /// This fairing resolves this issue by stripping a trailing slash, if any,
260    /// in all incoming URIs. When it does so, it logs a warning. It is
261    /// recommended to use this fairing as a stop-gap measure instead of a
262    /// permanent resolution, if possible.
263    //
264    /// # Example
265    ///
266    /// With the fairing attached, request URIs have a trailing slash stripped:
267    ///
268    /// ```rust
269    /// # #[macro_use] extern crate rocket;
270    /// use rocket::local::blocking::Client;
271    /// use rocket::fairing::AdHoc;
272    ///
273    /// #[get("/<param>")]
274    /// fn foo(param: &str) -> &str {
275    ///     param
276    /// }
277    ///
278    /// #[launch]
279    /// fn rocket() -> _ {
280    ///     rocket::build()
281    ///         .mount("/", routes![foo])
282    ///         .attach(AdHoc::uri_normalizer())
283    /// }
284    ///
285    /// # let client = Client::debug(rocket()).unwrap();
286    /// let response = client.get("/bar/").dispatch();
287    /// assert_eq!(response.into_string().unwrap(), "bar");
288    /// ```
289    ///
290    /// Without it, request URIs are unchanged and routed normally:
291    ///
292    /// ```rust
293    /// # #[macro_use] extern crate rocket;
294    /// use rocket::local::blocking::Client;
295    /// use rocket::fairing::AdHoc;
296    ///
297    /// #[get("/<param>")]
298    /// fn foo(param: &str) -> &str {
299    ///     param
300    /// }
301    ///
302    /// #[launch]
303    /// fn rocket() -> _ {
304    ///     rocket::build().mount("/", routes![foo])
305    /// }
306    ///
307    /// # let client = Client::debug(rocket()).unwrap();
308    /// let response = client.get("/bar/").dispatch();
309    /// assert!(response.status().class().is_client_error());
310    ///
311    /// let response = client.get("/bar").dispatch();
312    /// assert_eq!(response.into_string().unwrap(), "bar");
313    /// ```
314    // #[deprecated(since = "0.7", note = "routing from Rocket 0.6 is now standard")]
315    pub fn uri_normalizer() -> impl Fairing {
316        #[derive(Default)]
317        struct Normalizer {
318            routes: state::InitCell<Vec<crate::Route>>,
319        }
320
321        impl Normalizer {
322            fn routes(&self, rocket: &Rocket<Orbit>) -> &[crate::Route] {
323                self.routes.get_or_init(|| {
324                    rocket.routes()
325                        .filter(|r| r.uri.has_trailing_slash())
326                        .cloned()
327                        .collect()
328                })
329            }
330        }
331
332        #[crate::async_trait]
333        impl Fairing for Normalizer {
334            fn info(&self) -> Info {
335                Info { name: "URI Normalizer", kind: Kind::Ignite | Kind::Liftoff | Kind::Request }
336            }
337
338            async fn on_ignite(&self, rocket: Rocket<Build>) -> Result {
339                // We want a route like `/foo/<bar..>` to match a request for
340                // `/foo` as it would have before. While we could check if a
341                // route is mounted that would cause this match and then rewrite
342                // the request URI as `/foo/`, doing so is expensive and
343                // potentially incorrect due to request guards and ranking.
344                //
345                // Instead, we generate a new route with URI `/foo` with the
346                // same rank and handler as the `/foo/<bar..>` route and mount
347                // it to this instance of `rocket`. This preserves the previous
348                // matching while still checking request guards.
349                let normalized_trailing = rocket.routes()
350                    .filter(|r| r.uri.metadata.dynamic_trail)
351                    .filter(|r| r.uri.path().segments().num() > 1)
352                    .filter_map(|route| {
353                        let path = route.uri.unmounted().path();
354                        let new_path = path.as_str()
355                            .rsplit_once('/')
356                            .map(|(prefix, _)| prefix)
357                            .filter(|path| !path.is_empty())
358                            .unwrap_or("/");
359
360                        let base = route.uri.base().as_str();
361                        let uri = match route.uri.unmounted().query() {
362                            Some(q) => format!("{}?{}", new_path, q),
363                            None => new_path.to_string()
364                        };
365
366                        let mut route = route.clone();
367                        route.uri = RouteUri::try_new(base, &uri).ok()?;
368                        route.name = route.name.map(|r| format!("{} [normalized]", r).into());
369                        Some(route)
370                    })
371                    .collect::<Vec<_>>();
372
373                Ok(rocket.mount("/", normalized_trailing))
374            }
375
376            async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
377                let _ = self.routes(rocket);
378            }
379
380            async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
381                // If the URI has no trailing slash, it routes as before.
382                if req.uri().is_normalized_nontrailing() {
383                    return
384                }
385
386                // Otherwise, check if there's a route that matches the request
387                // with a trailing slash. If there is, leave the request alone.
388                // This allows incremental compatibility updates. Otherwise,
389                // rewrite the request URI to remove the `/`.
390                if !self.routes(req.rocket()).iter().any(|r| r.matches(req)) {
391                    let normalized = req.uri().clone().into_normalized_nontrailing();
392                    warn!(original = %req.uri(), %normalized,
393                        "incoming request URI normalized for compatibility");
394                    req.set_uri(normalized);
395                }
396            }
397        }
398
399        Normalizer::default()
400    }
401}
402
403#[crate::async_trait]
404impl Fairing for AdHoc {
405    fn info(&self) -> Info {
406        let kind = match self.kind {
407            AdHocKind::Ignite(_) => Kind::Ignite,
408            AdHocKind::Liftoff(_) => Kind::Liftoff,
409            AdHocKind::Request(_) => Kind::Request,
410            AdHocKind::Response(_) => Kind::Response,
411            AdHocKind::Shutdown(_) => Kind::Shutdown,
412        };
413
414        Info { name: self.name, kind }
415    }
416
417    async fn on_ignite(&self, rocket: Rocket<Build>) -> Result {
418        match self.kind {
419            AdHocKind::Ignite(ref f) => (f.take())(rocket).await,
420            _ => Ok(rocket)
421        }
422    }
423
424    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
425        if let AdHocKind::Liftoff(ref f) = self.kind {
426            (f.take())(rocket).await
427        }
428    }
429
430    async fn on_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) {
431        if let AdHocKind::Request(ref f) = self.kind {
432            f(req, data).await
433        }
434    }
435
436    async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
437        if let AdHocKind::Response(ref f) = self.kind {
438            f(req, res).await
439        }
440    }
441
442    async fn on_shutdown(&self, rocket: &Rocket<Orbit>) {
443        if let AdHocKind::Shutdown(ref f) = self.kind {
444            (f.take())(rocket).await
445        }
446    }
447}