// rocket/tls/listener.rs
use std::io;
use std::sync::Arc;

use futures::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::LazyConfigAcceptor;
use rustls::server::{Acceptor, ServerConfig};

use crate::{Ignite, Rocket};
use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener};
use crate::tls::{TlsConfig, Result, Error};
use super::resolver::DynResolver;

#[doc(inline)]
pub use tokio_rustls::server::TlsStream;

/// A TLS listener over some listener interface `L`.
///
/// Raw connections are accepted from the wrapped `listener` unchanged; the
/// TLS handshake is performed afterwards, in `Listener::connect`.
pub struct TlsListener<L> {
    /// The wrapped listener that raw connections are accepted from.
    listener: L,
    /// TLS settings; also carries the optional per-connection config resolver.
    config: TlsConfig,
    /// Server config built from `config` at construction time. It is used when
    /// no resolver is configured or when the resolver returns `None`.
    default: Arc<ServerConfig>,
}

impl<L> TlsListener<L>
    where L: Listener<Accept = <L as Listener>::Connection>,
{
    /// Wraps `listener` so that the connections it yields are upgraded to TLS
    /// according to `config`.
    ///
    /// The fallback rustls [`ServerConfig`] is built eagerly, here, from
    /// `config`. It is used for any handshake where no resolver is configured
    /// or the resolver declines to pick a config.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `TlsConfig::server_config()`.
    pub async fn from(listener: L, config: TlsConfig) -> Result<TlsListener<L>> {
        let fallback = config.server_config().await.map(Arc::new)?;
        Ok(TlsListener { listener, config, default: fallback })
    }
}

impl<L: Bind> Bind for TlsListener<L>
    where L: Listener<Accept = <L as Listener>::Connection>
{
    type Error = Error;

    async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
        let listener = L::bind(rocket).map_err(|e| Error::Bind(Box::new(e))).await?;
        let mut config: TlsConfig = rocket.figment().extract_inner("tls")?;
        config.resolver = DynResolver::extract(rocket);
        Self::from(listener, config).await
    }

    fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
        let config: TlsConfig = rocket.figment().extract_inner("tls")?;
        L::bind_endpoint(rocket)
            .map(|e| e.with_tls(&config))
            .map_err(|e| Error::Bind(Box::new(e)))
    }
}

impl<L> Listener for TlsListener<L>
    where L: Listener<Accept = <L as Listener>::Connection>,
          L::Connection: AsyncRead + AsyncWrite
{
    type Accept = L::Connection;

    type Connection = TlsStream<L::Connection>;

    async fn accept(&self) -> io::Result<Self::Accept> {
        self.listener.accept().await
    }

    async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
        let acceptor = LazyConfigAcceptor::new(Acceptor::default(), conn);
        let handshake = acceptor.await?;
        let hello = handshake.client_hello();
        let config = match &self.config.resolver {
            Some(r) => r.resolve(hello).await.unwrap_or_else(|| self.default.clone()),
            None => self.default.clone(),
        };

        handshake.into_stream(config).await
    }

    fn endpoint(&self) -> io::Result<Endpoint> {
        Ok(self.listener.endpoint()?.with_tls(&self.config))
    }
}

impl<C: Connection> Connection for TlsStream<C> {
    /// Returns the underlying connection's endpoint, marked as TLS via
    /// `assume_tls()`.
    fn endpoint(&self) -> io::Result<Endpoint> {
        let (io, _session) = self.get_ref();
        io.endpoint().map(|endpoint| endpoint.assume_tls())
    }

    /// Returns the certificate chain presented by the peer, if any.
    ///
    /// The chain is only read when the `mtls` feature is enabled. Without that
    /// feature, this always returns `None`.
    fn certificates(&self) -> Option<Certificates<'_>> {
        #[cfg(feature = "mtls")] {
            let (_io, session) = self.get_ref();
            session.peer_certificates().map(Certificates::from)
        }

        #[cfg(not(feature = "mtls"))]
        None
    }
}