rocket/tls/listener.rs

1use std::io;
2use std::sync::Arc;
3
4use futures::TryFutureExt;
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio_rustls::LazyConfigAcceptor;
7use rustls::server::{Acceptor, ServerConfig};
8
9use crate::{Ignite, Rocket};
10use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener};
11use crate::tls::{TlsConfig, Result, Error};
12use super::resolver::DynResolver;
13
14#[doc(inline)]
15pub use tokio_rustls::server::TlsStream;
16
17/// A TLS listener over some listener interface L.
18pub struct TlsListener<L> {
19    listener: L,
20    config: TlsConfig,
21    default: Arc<ServerConfig>,
22}
23
24impl<L> TlsListener<L>
25    where L: Listener<Accept = <L as Listener>::Connection>,
26{
27    pub async fn from(listener: L, config: TlsConfig) -> Result<TlsListener<L>> {
28        Ok(TlsListener {
29            default: Arc::new(config.server_config().await?),
30            listener,
31            config,
32        })
33    }
34}
35
36impl<L: Bind> Bind for TlsListener<L>
37    where L: Listener<Accept = <L as Listener>::Connection>
38{
39    type Error = Error;
40
41    async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
42        let listener = L::bind(rocket).map_err(|e| Error::Bind(Box::new(e))).await?;
43        let mut config: TlsConfig = rocket.figment().extract_inner("tls")?;
44        config.resolver = DynResolver::extract(rocket);
45        Self::from(listener, config).await
46    }
47
48    fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
49        let config: TlsConfig = rocket.figment().extract_inner("tls")?;
50        L::bind_endpoint(rocket)
51            .map(|e| e.with_tls(&config))
52            .map_err(|e| Error::Bind(Box::new(e)))
53    }
54}
55
56impl<L> Listener for TlsListener<L>
57    where L: Listener<Accept = <L as Listener>::Connection>,
58          L::Connection: AsyncRead + AsyncWrite
59{
60    type Accept = L::Connection;
61
62    type Connection = TlsStream<L::Connection>;
63
64    async fn accept(&self) -> io::Result<Self::Accept> {
65        self.listener.accept().await
66    }
67
68    async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
69        let acceptor = LazyConfigAcceptor::new(Acceptor::default(), conn);
70        let handshake = acceptor.await?;
71        let hello = handshake.client_hello();
72        let config = match &self.config.resolver {
73            Some(r) => r.resolve(hello).await.unwrap_or_else(|| self.default.clone()),
74            None => self.default.clone(),
75        };
76
77        handshake.into_stream(config).await
78    }
79
80    fn endpoint(&self) -> io::Result<Endpoint> {
81        Ok(self.listener.endpoint()?.with_tls(&self.config))
82    }
83}
84
85impl<C: Connection> Connection for TlsStream<C> {
86    fn endpoint(&self) -> io::Result<Endpoint> {
87        Ok(self.get_ref().0.endpoint()?.assume_tls())
88    }
89
90    fn certificates(&self) -> Option<Certificates<'_>> {
91        #[cfg(feature = "mtls")] {
92            let cert_chain = self.get_ref().1.peer_certificates()?;
93            Some(Certificates::from(cert_chain))
94        }
95
96        #[cfg(not(feature = "mtls"))]
97        None
98    }
99}