rocket/
lifecycle.rs

1use futures::future::{FutureExt, Future};
2
3use crate::trace::Trace;
4use crate::util::Formatter;
5use crate::data::IoHandler;
6use crate::http::{Method, Status, Header};
7use crate::outcome::Outcome;
8use crate::form::Form;
9use crate::{route, catcher, Rocket, Orbit, Request, Response, Data};
10
/// Zero-sized proof value handed out by `preprocess` and demanded by
/// `dispatch`; requiring it forces preprocessing to happen before dispatch.
pub(crate) struct RequestToken;
13
14async fn catch_handle<Fut, T, F>(name: Option<&str>, run: F) -> Option<T>
15    where F: FnOnce() -> Fut, Fut: Future<Output = T>,
16{
17    macro_rules! panic_info {
18        ($name:expr, $e:expr) => {{
19            error!(handler = name.as_ref().map(display),
20                "handler panicked\n\
21                This is an application bug.\n\
22                A panic in Rust must be treated as an exceptional event.\n\
23                Panicking is not a suitable error handling mechanism.\n\
24                Unwinding, the result of a panic, is an expensive operation.\n\
25                Panics will degrade application performance.\n\
26                Instead of panicking, return `Option` and/or `Result`.\n\
27                Values of either type can be returned directly from handlers.\n\
28                A panic is treated as an internal server error.");
29
30            $e
31        }}
32    }
33
34    let run = std::panic::AssertUnwindSafe(run);
35    let fut = std::panic::catch_unwind(run)
36        .map_err(|e| panic_info!(name, e))
37        .ok()?;
38
39    std::panic::AssertUnwindSafe(fut)
40        .catch_unwind()
41        .await
42        .map_err(|e| panic_info!(name, e))
43        .ok()
44}
45
46impl Rocket<Orbit> {
47    /// Preprocess the request for Rocket things. Currently, this means:
48    ///
49    ///   * Rewriting the method in the request if _method form field exists.
50    ///   * Run the request fairings.
51    ///
52    /// This is the only place during lifecycle processing that `Request` is
53    /// mutable. Keep this in-sync with the `FromForm` derive.
54    pub(crate) async fn preprocess(
55        &self,
56        req: &mut Request<'_>,
57        data: &mut Data<'_>
58    ) -> RequestToken {
59        // Check if this is a form and if the form contains the special _method
60        // field which we use to reinterpret the request's method.
61        if req.method() == Method::Post && req.content_type().map_or(false, |v| v.is_form()) {
62            let peek_buffer = data.peek(32).await;
63            let method = std::str::from_utf8(peek_buffer).ok()
64                .and_then(|raw_form| Form::values(raw_form).next())
65                .filter(|field| field.name == "_method")
66                .and_then(|field| field.value.parse().ok());
67
68            if let Some(method) = method {
69                req.set_method(method);
70            }
71        }
72
73        // Run request fairings.
74        self.fairings.handle_request(req, data).await;
75
76        RequestToken
77    }
78
79    /// Dispatches the request to the router and processes the outcome to
80    /// produce a response. If the initial outcome is a *forward* and the
81    /// request was a HEAD request, the request is rewritten and rerouted as a
82    /// GET. This is automatic HEAD handling.
83    ///
84    /// After performing the above, if the outcome is a forward or error, the
85    /// appropriate error catcher is invoked to produce the response. Otherwise,
86    /// the successful response is used directly.
87    ///
88    /// Finally, new cookies in the cookie jar are added to the response,
89    /// Rocket-specific headers are written, and response fairings are run. Note
90    /// that error responses have special cookie handling. See `handle_error`.
91    pub(crate) async fn dispatch<'r, 's: 'r>(
92        &'s self,
93        _token: RequestToken,
94        request: &'r Request<'s>,
95        data: Data<'r>,
96        // io_stream: impl Future<Output = io::Result<IoStream>> + Send,
97    ) -> Response<'r> {
98        // Remember if the request is `HEAD` for later body stripping.
99        let was_head_request = request.method() == Method::Head;
100
101        // Route the request and run the user's handlers.
102        let mut response = match self.route(request, data).await {
103            Outcome::Success(response) => response,
104            Outcome::Forward((data, _)) if request.method() == Method::Head => {
105                tracing::Span::current().record("autohandled", true);
106
107                // Dispatch the request again with Method `GET`.
108                request._set_method(Method::Get);
109                match self.route(request, data).await {
110                    Outcome::Success(response) => response,
111                    Outcome::Error(status) => self.dispatch_error(status, request).await,
112                    Outcome::Forward((_, status)) => self.dispatch_error(status, request).await,
113                }
114            }
115            Outcome::Forward((_, status)) => self.dispatch_error(status, request).await,
116            Outcome::Error(status) => self.dispatch_error(status, request).await,
117        };
118
119        // Set the cookies. Note that error responses will only include cookies
120        // set by the error handler. See `handle_error` for more.
121        let delta_jar = request.cookies().take_delta_jar();
122        for cookie in delta_jar.delta() {
123            response.adjoin_header(cookie);
124        }
125
126        // Add a default 'Server' header if it isn't already there.
127        // TODO: If removing Hyper, write out `Date` header too.
128        if let Some(ident) = request.rocket().config.ident.as_str() {
129            if !response.headers().contains("Server") {
130                response.set_header(Header::new("Server", ident));
131            }
132        }
133
134        // Run the response fairings.
135        self.fairings.handle_response(request, &mut response).await;
136
137        // Strip the body if this is a `HEAD` request or a 304 response.
138        if was_head_request || response.status() == Status::NotModified {
139            response.strip_body();
140        }
141
142        // If the response status is 204, strip the body and its size (no
143        // content-length header). Otherwise, check if the body is sized and use
144        // that size to set the content-length headr appropriately.
145        if response.status() == Status::NoContent {
146            *response.body_mut() = crate::response::Body::unsized_none();
147        } else if let Some(size) = response.body_mut().size().await {
148            response.set_raw_header("Content-Length", size.to_string());
149        }
150
151        if let Some(alt_svc) = request.rocket().alt_svc() {
152            response.set_raw_header("Alt-Svc", alt_svc);
153        }
154
155        // TODO: Should upgrades be handled here? We miss them on local clients.
156        response
157    }
158
159    pub(crate) fn extract_io_handler<'r>(
160        request: &'r Request<'_>,
161        response: &mut Response<'r>,
162        // io_stream: impl Future<Output = io::Result<IoStream>> + Send,
163    ) -> Option<(String, Box<dyn IoHandler + 'r>)> {
164        let upgrades = request.headers().get("upgrade");
165        let Ok(upgrade) = response.search_upgrades(upgrades) else {
166            info!(
167                upgrades = %Formatter(|f| f.debug_list()
168                    .entries(request.headers().get("upgrade"))
169                    .finish()),
170                "request wants upgrade but no i/o handler matched\n\
171                refusing to upgrade request"
172            );
173
174            return None;
175        };
176
177        if let Some((proto, io_handler)) = upgrade {
178            let proto = proto.to_string();
179            response.set_status(Status::SwitchingProtocols);
180            response.set_raw_header("Connection", "Upgrade");
181            response.set_raw_header("Upgrade", proto.clone());
182            return Some((proto, io_handler));
183        }
184
185        None
186    }
187
188    /// Calls the handler for each matching route until one of the handlers
189    /// returns success or error, or there are no additional routes to try, in
190    /// which case a `Forward` with the last forwarding state is returned.
191    #[inline]
192    #[tracing::instrument("routing", skip_all, fields(
193        method = %request.method(),
194        uri = %request.uri(),
195        format = request.format().map(display),
196    ))]
197    async fn route<'s, 'r: 's>(
198        &'s self,
199        request: &'r Request<'s>,
200        mut data: Data<'r>,
201    ) -> route::Outcome<'r> {
202        // Go through all matching routes until we fail or succeed or run out of
203        // routes to try, in which case we forward with the last status.
204        let mut status = Status::NotFound;
205        for route in self.router.route(request) {
206            // Retrieve and set the requests parameters.
207            route.trace_info();
208            request.set_route(route);
209
210            let name = route.name.as_deref();
211            let outcome = catch_handle(name, || route.handler.handle(request, data)).await
212                .unwrap_or(Outcome::Error(Status::InternalServerError));
213
214            // Check if the request processing completed (Some) or if the
215            // request needs to be forwarded. If it does, continue the loop
216            outcome.trace_info();
217            match outcome {
218                o@Outcome::Success(_) | o@Outcome::Error(_) => return o,
219                Outcome::Forward(forwarded) => (data, status) = forwarded,
220            }
221        }
222
223        Outcome::Forward((data, status))
224    }
225
226    // Invokes the catcher for `status`. Returns the response on success.
227    //
228    // Resets the cookie jar delta state to prevent any modifications from
229    // earlier unsuccessful paths from being reflected in the error response.
230    //
231    // On catcher error, the 500 error catcher is attempted. If _that_ errors,
232    // the (infallible) default 500 error cather is used.
233    #[tracing::instrument("catching", skip_all, fields(status = status.code, uri = %req.uri()))]
234    pub(crate) async fn dispatch_error<'r, 's: 'r>(
235        &'s self,
236        mut status: Status,
237        req: &'r Request<'s>
238    ) -> Response<'r> {
239        // We may wish to relax this in the future.
240        req.cookies().reset_delta();
241
242        loop {
243            // Dispatch to the `status` catcher.
244            match self.invoke_catcher(status, req).await {
245                Ok(r) => return r,
246                // If the catcher failed, try `500` catcher, unless this is it.
247                Err(e) if status.code != 500 => {
248                    warn!(status = e.map(|r| r.code), "catcher failed: trying 500 catcher");
249                    status = Status::InternalServerError;
250                }
251                // The 500 catcher failed. There's no recourse. Use default.
252                Err(e) => {
253                    error!(status = e.map(|r| r.code), "500 catcher failed");
254                    return catcher::default_handler(Status::InternalServerError, req);
255                }
256            }
257        }
258    }
259
260    /// Invokes the handler with `req` for catcher with status `status`.
261    ///
262    /// In order of preference, invoked handler is:
263    ///   * the user's registered handler for `status`
264    ///   * the user's registered `default` handler
265    ///   * Rocket's default handler for `status`
266    ///
267    /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))`
268    /// if the handler ran to completion but failed. Returns `Ok(None)` if the
269    /// handler panicked while executing.
270    async fn invoke_catcher<'s, 'r: 's>(
271        &'s self,
272        status: Status,
273        req: &'r Request<'s>
274    ) -> Result<Response<'r>, Option<Status>> {
275        if let Some(catcher) = self.router.catch(status, req) {
276            catcher.trace_info();
277            catch_handle(catcher.name.as_deref(), || catcher.handler.handle(status, req)).await
278                .map(|result| result.map_err(Some))
279                .unwrap_or_else(|| Err(None))
280        } else {
281            info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = status.code,
282                "no registered catcher: using Rocket default");
283            Ok(catcher::default_handler(status, req))
284        }
285    }
286}