// rocket/tls/listener.rs
1use std::io;
2use std::sync::Arc;
3
4use futures::TryFutureExt;
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio_rustls::LazyConfigAcceptor;
7use rustls::server::{Acceptor, ServerConfig};
8
9use crate::{Ignite, Rocket};
10use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener};
11use crate::tls::{TlsConfig, Result, Error};
12use super::resolver::DynResolver;
13
14#[doc(inline)]
15pub use tokio_rustls::server::TlsStream;
16
/// A listener adapter that wraps an underlying listener `L` and performs a
/// TLS handshake on each connection it accepts.
pub struct TlsListener<L> {
    /// The wrapped listener; accepting connections is delegated to it.
    listener: L,
    /// The TLS configuration this listener was built with. If its `resolver`
    /// is set, the resolver is consulted for a per-connection server config.
    config: TlsConfig,
    /// Fallback rustls server configuration built from `config` at
    /// construction; used when there is no resolver or it yields none.
    default: Arc<ServerConfig>,
}

24impl<L> TlsListener<L>
25 where L: Listener<Accept = <L as Listener>::Connection>,
26{
27 pub async fn from(listener: L, config: TlsConfig) -> Result<TlsListener<L>> {
28 Ok(TlsListener {
29 default: Arc::new(config.server_config().await?),
30 listener,
31 config,
32 })
33 }
34}
35
36impl<L: Bind> Bind for TlsListener<L>
37 where L: Listener<Accept = <L as Listener>::Connection>
38{
39 type Error = Error;
40
41 async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
42 let listener = L::bind(rocket).map_err(|e| Error::Bind(Box::new(e))).await?;
43 let mut config: TlsConfig = rocket.figment().extract_inner("tls")?;
44 config.resolver = DynResolver::extract(rocket);
45 Self::from(listener, config).await
46 }
47
48 fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
49 let config: TlsConfig = rocket.figment().extract_inner("tls")?;
50 L::bind_endpoint(rocket)
51 .map(|e| e.with_tls(&config))
52 .map_err(|e| Error::Bind(Box::new(e)))
53 }
54}
55
56impl<L> Listener for TlsListener<L>
57 where L: Listener<Accept = <L as Listener>::Connection>,
58 L::Connection: AsyncRead + AsyncWrite
59{
60 type Accept = L::Connection;
61
62 type Connection = TlsStream<L::Connection>;
63
64 async fn accept(&self) -> io::Result<Self::Accept> {
65 self.listener.accept().await
66 }
67
68 async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
69 let acceptor = LazyConfigAcceptor::new(Acceptor::default(), conn);
70 let handshake = acceptor.await?;
71 let hello = handshake.client_hello();
72 let config = match &self.config.resolver {
73 Some(r) => r.resolve(hello).await.unwrap_or_else(|| self.default.clone()),
74 None => self.default.clone(),
75 };
76
77 handshake.into_stream(config).await
78 }
79
80 fn endpoint(&self) -> io::Result<Endpoint> {
81 Ok(self.listener.endpoint()?.with_tls(&self.config))
82 }
83}
84
85impl<C: Connection> Connection for TlsStream<C> {
86 fn endpoint(&self) -> io::Result<Endpoint> {
87 Ok(self.get_ref().0.endpoint()?.assume_tls())
88 }
89
90 fn certificates(&self) -> Option<Certificates<'_>> {
91 #[cfg(feature = "mtls")] {
92 let cert_chain = self.get_ref().1.peer_certificates()?;
93 Some(Certificates::from(cert_chain))
94 }
95
96 #[cfg(not(feature = "mtls"))]
97 None
98 }
99}