rocket/
server.rs

1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4use std::pin::Pin;
5
6use yansi::Paint;
7use tokio::sync::oneshot;
8use tokio::time::sleep;
9use futures::stream::StreamExt;
10use futures::future::{FutureExt, Future, BoxFuture};
11
12use crate::{route, Rocket, Orbit, Request, Response, Data, Config};
13use crate::form::Form;
14use crate::outcome::Outcome;
15use crate::error::{Error, ErrorKind};
16use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
17use crate::request::ConnectionMeta;
18use crate::data::IoHandler;
19
20use crate::http::{hyper, uncased, Method, Status, Header};
21use crate::http::private::{TcpListener, Listener, Connection, Incoming};
22
/// A zero-sized token returned to force the execution of one method before
/// another: `preprocess_request` returns it and `dispatch` requires it, so a
/// request is preprocessed before it is dispatched.
pub(crate) struct RequestToken;
25
26async fn handle<Fut, T, F>(name: Option<&str>, run: F) -> Option<T>
27    where F: FnOnce() -> Fut, Fut: Future<Output = T>,
28{
29    use std::panic::AssertUnwindSafe;
30
31    macro_rules! panic_info {
32        ($name:expr, $e:expr) => {{
33            match $name {
34                Some(name) => error_!("Handler {} panicked.", name.primary()),
35                None => error_!("A handler panicked.")
36            };
37
38            info_!("This is an application bug.");
39            info_!("A panic in Rust must be treated as an exceptional event.");
40            info_!("Panicking is not a suitable error handling mechanism.");
41            info_!("Unwinding, the result of a panic, is an expensive operation.");
42            info_!("Panics will degrade application performance.");
43            info_!("Instead of panicking, return `Option` and/or `Result`.");
44            info_!("Values of either type can be returned directly from handlers.");
45            warn_!("A panic is treated as an internal server error.");
46            $e
47        }}
48    }
49
50    let run = AssertUnwindSafe(run);
51    let fut = std::panic::catch_unwind(move || run())
52        .map_err(|e| panic_info!(name, e))
53        .ok()?;
54
55    AssertUnwindSafe(fut)
56        .catch_unwind()
57        .await
58        .map_err(|e| panic_info!(name, e))
59        .ok()
60}
61
62// This function tries to hide all of the Hyper-ness from Rocket. It essentially
63// converts Hyper types into Rocket types, then calls the `dispatch` function,
64// which knows nothing about Hyper. Because responding depends on the
65// `HyperResponse` type, this function does the actual response processing.
66async fn hyper_service_fn(
67    rocket: Arc<Rocket<Orbit>>,
68    conn: ConnectionMeta,
69    mut hyp_req: hyper::Request<hyper::Body>,
70) -> Result<hyper::Response<hyper::Body>, io::Error> {
71    // This future must return a hyper::Response, but the response body might
72    // borrow from the request. Instead, write the body in another future that
73    // sends the response metadata (and a body channel) prior.
74    let (tx, rx) = oneshot::channel();
75
76    #[cfg(not(broken_fmt))]
77    debug!("received request: {:#?}", hyp_req);
78
79    tokio::spawn(async move {
80        // We move the request next, so get the upgrade future now.
81        let pending_upgrade = hyper::upgrade::on(&mut hyp_req);
82
83        // Convert a Hyper request into a Rocket request.
84        let (h_parts, mut h_body) = hyp_req.into_parts();
85        match Request::from_hyp(&rocket, &h_parts, Some(conn)) {
86            Ok(mut req) => {
87                // Convert into Rocket `Data`, dispatch request, write response.
88                let mut data = Data::from(&mut h_body);
89                let token = rocket.preprocess_request(&mut req, &mut data).await;
90                let mut response = rocket.dispatch(token, &req, data).await;
91                let upgrade = response.take_upgrade(req.headers().get("upgrade"));
92                if let Ok(Some((proto, handler))) = upgrade {
93                    rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await;
94                } else {
95                    if upgrade.is_err() {
96                        warn_!("Request wants upgrade but no I/O handler matched.");
97                        info_!("Request is not being upgraded.");
98                    }
99
100                    rocket.send_response(response, tx).await;
101                }
102            },
103            Err(e) => {
104                warn!("Bad incoming HTTP request.");
105                e.errors.iter().for_each(|e| warn_!("Error: {}.", e));
106                warn_!("Dispatching salvaged request to catcher: {}.", e.request);
107
108                let response = rocket.handle_error(Status::BadRequest, &e.request).await;
109                rocket.send_response(response, tx).await;
110            }
111        }
112    });
113
114    // Receive the response written to `tx` by the task above.
115    rx.await.map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))
116}
117
118impl Rocket<Orbit> {
119    /// Wrapper around `_send_response` to log a success or error.
120    #[inline]
121    async fn send_response(
122        &self,
123        response: Response<'_>,
124        tx: oneshot::Sender<hyper::Response<hyper::Body>>,
125    ) {
126        let remote_hungup = |e: &io::Error| match e.kind() {
127            | io::ErrorKind::BrokenPipe
128            | io::ErrorKind::ConnectionReset
129            | io::ErrorKind::ConnectionAborted => true,
130            _ => false,
131        };
132
133        match self._send_response(response, tx).await {
134            Ok(()) => info_!("{}", "Response succeeded.".green()),
135            Err(e) if remote_hungup(&e) => warn_!("Remote left: {}.", e),
136            Err(e) => warn_!("Failed to write response: {}.", e),
137        }
138    }
139
140    /// Attempts to create a hyper response from `response` and send it to `tx`.
141    #[inline]
142    async fn _send_response(
143        &self,
144        mut response: Response<'_>,
145        tx: oneshot::Sender<hyper::Response<hyper::Body>>,
146    ) -> io::Result<()> {
147        let mut hyp_res = hyper::Response::builder();
148
149        hyp_res = hyp_res.status(response.status().code);
150        for header in response.headers().iter() {
151            let name = header.name.as_str();
152            let value = header.value.as_bytes();
153            hyp_res = hyp_res.header(name, value);
154        }
155
156        let body = response.body_mut();
157        if let Some(n) = body.size().await {
158            hyp_res = hyp_res.header(hyper::header::CONTENT_LENGTH, n);
159        }
160
161        let (mut sender, hyp_body) = hyper::Body::channel();
162        let hyp_response = hyp_res.body(hyp_body)
163            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
164
165        #[cfg(not(broken_fmt))]
166        debug!("sending response: {:#?}", hyp_response);
167
168        tx.send(hyp_response).map_err(|_| {
169            let msg = "client disconnect before response started";
170            io::Error::new(io::ErrorKind::BrokenPipe, msg)
171        })?;
172
173        let max_chunk_size = body.max_chunk_size();
174        let mut stream = body.into_bytes_stream(max_chunk_size);
175        while let Some(next) = stream.next().await {
176            sender.send_data(next?).await
177                .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
178        }
179
180        Ok(())
181    }
182
183    async fn handle_upgrade<'r>(
184        &self,
185        mut response: Response<'r>,
186        proto: uncased::Uncased<'r>,
187        io_handler: Pin<Box<dyn IoHandler + 'r>>,
188        pending_upgrade: hyper::upgrade::OnUpgrade,
189        tx: oneshot::Sender<hyper::Response<hyper::Body>>,
190    ) {
191        info_!("Upgrading connection to {}.", Paint::white(&proto).bold());
192        response.set_status(Status::SwitchingProtocols);
193        response.set_raw_header("Connection", "Upgrade");
194        response.set_raw_header("Upgrade", proto.clone().into_cow());
195        self.send_response(response, tx).await;
196
197        match pending_upgrade.await {
198            Ok(io_stream) => {
199                info_!("Upgrade successful.");
200                if let Err(e) = io_handler.io(io_stream.into()).await {
201                    if e.kind() == io::ErrorKind::BrokenPipe {
202                        warn!("Upgraded {} I/O handler was closed.", proto);
203                    } else {
204                        error!("Upgraded {} I/O handler failed: {}", proto, e);
205                    }
206                }
207            },
208            Err(e) => {
209                warn!("Response indicated upgrade, but upgrade failed.");
210                warn_!("Upgrade error: {}", e);
211            }
212        }
213    }
214
215    /// Preprocess the request for Rocket things. Currently, this means:
216    ///
217    ///   * Rewriting the method in the request if _method form field exists.
218    ///   * Run the request fairings.
219    ///
220    /// Keep this in-sync with derive_form when preprocessing form fields.
221    pub(crate) async fn preprocess_request(
222        &self,
223        req: &mut Request<'_>,
224        data: &mut Data<'_>
225    ) -> RequestToken {
226        // Check if this is a form and if the form contains the special _method
227        // field which we use to reinterpret the request's method.
228        let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
229        let peek_buffer = data.peek(max_len).await;
230        let is_form = req.content_type().map_or(false, |ct| ct.is_form());
231
232        if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len {
233            let method = std::str::from_utf8(peek_buffer).ok()
234                .and_then(|raw_form| Form::values(raw_form).next())
235                .filter(|field| field.name == "_method")
236                .and_then(|field| field.value.parse().ok());
237
238            if let Some(method) = method {
239                req._set_method(method);
240            }
241        }
242
243        // Run request fairings.
244        self.fairings.handle_request(req, data).await;
245
246        RequestToken
247    }
248
249    #[inline]
250    pub(crate) async fn dispatch<'s, 'r: 's>(
251        &'s self,
252        _token: RequestToken,
253        request: &'r Request<'s>,
254        data: Data<'r>
255    ) -> Response<'r> {
256        info!("{}:", request);
257
258        // Remember if the request is `HEAD` for later body stripping.
259        let was_head_request = request.method() == Method::Head;
260
261        // Route the request and run the user's handlers.
262        let mut response = self.route_and_process(request, data).await;
263
264        // Add a default 'Server' header if it isn't already there.
265        // TODO: If removing Hyper, write out `Date` header too.
266        if let Some(ident) = request.rocket().config.ident.as_str() {
267            if !response.headers().contains("Server") {
268                response.set_header(Header::new("Server", ident));
269            }
270        }
271
272        // Run the response fairings.
273        self.fairings.handle_response(request, &mut response).await;
274
275        // Strip the body if this is a `HEAD` request.
276        if was_head_request {
277            response.strip_body();
278        }
279
280        response
281    }
282
283    async fn route_and_process<'s, 'r: 's>(
284        &'s self,
285        request: &'r Request<'s>,
286        data: Data<'r>
287    ) -> Response<'r> {
288        let mut response = match self.route(request, data).await {
289            Outcome::Success(response) => response,
290            Outcome::Forward((data, _)) if request.method() == Method::Head => {
291                info_!("Autohandling {} request.", "HEAD".primary().bold());
292
293                // Dispatch the request again with Method `GET`.
294                request._set_method(Method::Get);
295                match self.route(request, data).await {
296                    Outcome::Success(response) => response,
297                    Outcome::Error(status) => self.handle_error(status, request).await,
298                    Outcome::Forward((_, status)) => self.handle_error(status, request).await,
299                }
300            }
301            Outcome::Forward((_, status)) => self.handle_error(status, request).await,
302            Outcome::Error(status) => self.handle_error(status, request).await,
303        };
304
305        // Set the cookies. Note that error responses will only include cookies
306        // set by the error handler. See `handle_error` for more.
307        let delta_jar = request.cookies().take_delta_jar();
308        for cookie in delta_jar.delta() {
309            response.adjoin_header(cookie);
310        }
311
312        response
313    }
314
315    /// Tries to find a `Responder` for a given `request`. It does this by
316    /// routing the request and calling the handler for each matching route
317    /// until one of the handlers returns success or error, or there are no
318    /// additional routes to try (forward). The corresponding outcome for each
319    /// condition is returned.
320    #[inline]
321    async fn route<'s, 'r: 's>(
322        &'s self,
323        request: &'r Request<'s>,
324        mut data: Data<'r>,
325    ) -> route::Outcome<'r> {
326        // Go through all matching routes until we fail or succeed or run out of
327        // routes to try, in which case we forward with the last status.
328        let mut status = Status::NotFound;
329        for route in self.router.route(request) {
330            // Retrieve and set the requests parameters.
331            info_!("Matched: {}", route);
332            request.set_route(route);
333
334            let name = route.name.as_deref();
335            let outcome = handle(name, || route.handler.handle(request, data)).await
336                .unwrap_or(Outcome::Error(Status::InternalServerError));
337
338            // Check if the request processing completed (Some) or if the
339            // request needs to be forwarded. If it does, continue the loop
340            // (None) to try again.
341            info_!("{}", outcome.log_display());
342            match outcome {
343                o@Outcome::Success(_) | o@Outcome::Error(_) => return o,
344                Outcome::Forward(forwarded) => (data, status) = forwarded,
345            }
346        }
347
348        error_!("No matching routes for {}.", request);
349        Outcome::Forward((data, status))
350    }
351
352    /// Invokes the handler with `req` for catcher with status `status`.
353    ///
354    /// In order of preference, invoked handler is:
355    ///   * the user's registered handler for `status`
356    ///   * the user's registered `default` handler
357    ///   * Rocket's default handler for `status`
358    ///
359    /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))`
360    /// if the handler ran to completion but failed. Returns `Ok(None)` if the
361    /// handler panicked while executing.
362    async fn invoke_catcher<'s, 'r: 's>(
363        &'s self,
364        status: Status,
365        req: &'r Request<'s>
366    ) -> Result<Response<'r>, Option<Status>> {
367        // For now, we reset the delta state to prevent any modifications
368        // from earlier, unsuccessful paths from being reflected in error
369        // response. We may wish to relax this in the future.
370        req.cookies().reset_delta();
371
372        if let Some(catcher) = self.router.catch(status, req) {
373            warn_!("Responding with registered {} catcher.", catcher);
374            let name = catcher.name.as_deref();
375            handle(name, || catcher.handler.handle(status, req)).await
376                .map(|result| result.map_err(Some))
377                .unwrap_or_else(|| Err(None))
378        } else {
379            let code = status.code.blue().bold();
380            warn_!("No {} catcher registered. Using Rocket default.", code);
381            Ok(crate::catcher::default_handler(status, req))
382        }
383    }
384
385    // Invokes the catcher for `status`. Returns the response on success.
386    //
387    // On catcher error, the 500 error catcher is attempted. If _that_ errors,
388    // the (infallible) default 500 error cather is used.
389    pub(crate) async fn handle_error<'s, 'r: 's>(
390        &'s self,
391        mut status: Status,
392        req: &'r Request<'s>
393    ) -> Response<'r> {
394        // Dispatch to the `status` catcher.
395        if let Ok(r) = self.invoke_catcher(status, req).await {
396            return r;
397        }
398
399        // If it fails and it's not a 500, try the 500 catcher.
400        if status != Status::InternalServerError {
401            error_!("Catcher failed. Attempting 500 error catcher.");
402            status = Status::InternalServerError;
403            if let Ok(r) = self.invoke_catcher(status, req).await {
404                return r;
405            }
406        }
407
408        // If it failed again or if it was already a 500, use Rocket's default.
409        error_!("{} catcher failed. Using Rocket default 500.", status.code);
410        crate::catcher::default_handler(Status::InternalServerError, req)
411    }
412
413    pub(crate) async fn default_tcp_http_server<C>(mut self, ready: C) -> Result<Self, Error>
414        where C: for<'a> Fn(&'a Self) -> BoxFuture<'a, ()>
415    {
416        use std::net::ToSocketAddrs;
417
418        // Determine the address we're going to serve on.
419        let addr = format!("{}:{}", self.config.address, self.config.port);
420        let mut addr = addr.to_socket_addrs()
421            .map(|mut addrs| addrs.next().expect(">= 1 socket addr"))
422            .map_err(|e| Error::new(ErrorKind::Io(e)))?;
423
424        #[cfg(feature = "tls")]
425        if self.config.tls_enabled() {
426            if let Some(ref config) = self.config.tls {
427                use crate::http::tls::TlsListener;
428
429                let conf = config.to_native_config().map_err(ErrorKind::Io)?;
430                let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::Bind)?;
431                addr = l.local_addr().unwrap_or(addr);
432                self.config.address = addr.ip();
433                self.config.port = addr.port();
434                ready(&mut self).await;
435                return self.http_server(l).await;
436            }
437        }
438
439        let l = TcpListener::bind(addr).await.map_err(ErrorKind::Bind)?;
440        addr = l.local_addr().unwrap_or(addr);
441        self.config.address = addr.ip();
442        self.config.port = addr.port();
443        ready(&mut self).await;
444        self.http_server(l).await
445    }
446
447    // TODO.async: Solidify the Listener APIs and make this function public
448    pub(crate) async fn http_server<L>(self, listener: L) -> Result<Self, Error>
449        where L: Listener + Send, <L as Listener>::Connection: Send + Unpin + 'static
450    {
451        // Emit a warning if we're not running inside of Rocket's async runtime.
452        if self.config.profile == Config::DEBUG_PROFILE {
453            tokio::task::spawn_blocking(|| {
454                let this  = std::thread::current();
455                if !this.name().map_or(false, |s| s.starts_with("rocket-worker")) {
456                    warn!("Rocket is executing inside of a custom runtime.");
457                    info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`.");
458                    info_!("Forced shutdown is disabled. Runtime settings may be suboptimal.");
459                }
460            });
461        }
462
463        // Set up cancellable I/O from the given listener. Shutdown occurs when
464        // `Shutdown` (`TripWire`) resolves. This can occur directly through a
465        // notification or indirectly through an external signal which, when
466        // received, results in triggering the notify.
467        let shutdown = self.shutdown();
468        let sig_stream = self.config.shutdown.signal_stream();
469        let grace = self.config.shutdown.grace as u64;
470        let mercy = self.config.shutdown.mercy as u64;
471
472        // Start a task that listens for external signals and notifies shutdown.
473        if let Some(mut stream) = sig_stream {
474            let shutdown = shutdown.clone();
475            tokio::spawn(async move {
476                while let Some(sig) = stream.next().await {
477                    if shutdown.0.tripped() {
478                        warn!("Received {}. Shutdown already in progress.", sig);
479                    } else {
480                        warn!("Received {}. Requesting shutdown.", sig);
481                    }
482
483                    shutdown.0.trip();
484                }
485            });
486        }
487
488        // Save the keep-alive value for later use; we're about to move `self`.
489        let keep_alive = self.config.keep_alive;
490
491        // Create the Hyper `Service`.
492        let rocket = Arc::new(self);
493        let service_fn = |conn: &CancellableIo<_, L::Connection>| {
494            let rocket = rocket.clone();
495            let connection = ConnectionMeta {
496                remote: conn.peer_address(),
497                client_certificates: conn.peer_certificates(),
498            };
499
500            async move {
501                Ok::<_, std::convert::Infallible>(hyper::service::service_fn(move |req| {
502                    hyper_service_fn(rocket.clone(), connection.clone(), req)
503                }))
504            }
505        };
506
507        // NOTE: `hyper` uses `tokio::spawn()` as the default executor.
508        let listener = CancellableListener::new(shutdown.clone(), listener, grace, mercy);
509        let builder = hyper::server::Server::builder(Incoming::new(listener).nodelay(true));
510
511        #[cfg(feature = "http2")]
512        let builder = builder.http2_keep_alive_interval(match keep_alive {
513            0 => None,
514            n => Some(Duration::from_secs(n as u64))
515        });
516
517        let server = builder
518            .http1_keepalive(keep_alive != 0)
519            .http1_preserve_header_case(true)
520            .serve(hyper::service::make_service_fn(service_fn))
521            .with_graceful_shutdown(shutdown.clone());
522
523        // This deserves some explanation.
524        //
525        // This is largely to deal with Hyper's dreadful and largely nonexistent
526        // handling of shutdown, in general, nevermind graceful.
527        //
528        // When Hyper receives a "graceful shutdown" request, it stops accepting
529        // new requests. That's it. It continues to process existing requests
530        // and outgoing responses forever and never cancels them. As a result,
531        // Rocket must take it upon itself to cancel any existing I/O.
532        //
533        // To do so, Rocket wraps all connections in a `CancellableIo` struct,
534        // an internal structure that gracefully closes I/O when it receives a
535        // signal. That signal is the `shutdown` future. When the future
536        // resolves, `CancellableIo` begins to terminate in grace, mercy, and
537        // finally force close phases. Since all connections are wrapped in
538        // `CancellableIo`, this eventually ends all I/O.
539        //
540        // At that point, unless a user spawned an infinite, stand-alone task
541        // that isn't monitoring `Shutdown`, all tasks should resolve. This
542        // means that all instances of the shared `Arc<Rocket>` are dropped and
543        // we can return the owned instance of `Rocket`.
544        //
545        // Unfortunately, the Hyper `server` future resolves as soon as it has
546        // finishes processing requests without respect for ongoing responses.
547        // That is, `server` resolves even when there are running tasks that are
548        // generating a response. So, `server` resolving implies little to
549        // nothing about the state of connections. As a result, we depend on the
550        // timing of grace + mercy + some buffer to determine when all
551        // connections should be closed, thus all tasks should be complete, thus
552        // all references to `Arc<Rocket>` should be dropped and we can get a
553        // unique reference.
554        tokio::pin!(server);
555        tokio::select! {
556            biased;
557
558            _ = shutdown => {
559                // Run shutdown fairings. We compute `sleep()` for grace periods
560                // beforehand to ensure we don't add shutdown fairing completion
561                // time, which is arbitrary, to these periods.
562                info!("Shutdown requested. Waiting for pending I/O...");
563                let grace_timer = sleep(Duration::from_secs(grace));
564                let mercy_timer = sleep(Duration::from_secs(grace + mercy));
565                let shutdown_timer = sleep(Duration::from_secs(grace + mercy + 1));
566                rocket.fairings.handle_shutdown(&*rocket).await;
567
568                tokio::pin!(grace_timer, mercy_timer, shutdown_timer);
569                tokio::select! {
570                    biased;
571
572                    result = &mut server => {
573                        if let Err(e) = result {
574                            warn!("Server failed while shutting down: {}", e);
575                            return Err(Error::shutdown(rocket.clone(), e));
576                        }
577
578                        if Arc::strong_count(&rocket) != 1 { grace_timer.await; }
579                        if Arc::strong_count(&rocket) != 1 { mercy_timer.await; }
580                        if Arc::strong_count(&rocket) != 1 { shutdown_timer.await; }
581                        match Arc::try_unwrap(rocket) {
582                            Ok(rocket) => {
583                                info!("Graceful shutdown completed successfully.");
584                                Ok(rocket)
585                            }
586                            Err(rocket) => {
587                                warn!("Shutdown failed: outstanding background I/O.");
588                                Err(Error::shutdown(rocket, None))
589                            }
590                        }
591                    }
592                    _ = &mut shutdown_timer => {
593                        warn!("Shutdown failed: server executing after timeouts.");
594                        return Err(Error::shutdown(rocket.clone(), None));
595                    },
596                }
597            }
598            result = &mut server => {
599                match result {
600                    Ok(()) => {
601                        info!("Server shutdown nominally.");
602                        Ok(Arc::try_unwrap(rocket).map_err(|r| Error::shutdown(r, None))?)
603                    }
604                    Err(e) => {
605                        info!("Server failed prior to shutdown: {}:", e);
606                        Err(Error::shutdown(rocket.clone(), e))
607                    }
608                }
609            }
610        }
611    }
612}