rocket/tls/resolver.rs

1use std::fmt;
2use std::marker::PhantomData;
3use std::ops::Deref;
4use std::sync::Arc;
5
6pub use rustls::server::{ClientHello, ServerConfig};
7
8use crate::{Build, Ignite, Rocket};
9use crate::fairing::{self, Info, Kind};
10
11/// Proxy type to get PartialEq + Debug impls.
12#[derive(Clone)]
13pub(crate) struct DynResolver(Arc<dyn Resolver>);
14
15pub struct Fairing<T: ?Sized>(PhantomData<T>);
16
17/// A dynamic TLS configuration resolver.
18///
19/// # Example
20///
21/// This is an async trait. Implement it as follows:
22///
23/// ```rust
24/// # #[macro_use] extern crate rocket;
25/// use std::sync::Arc;
26/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig};
27/// use rocket::{Rocket, Build};
28///
29/// struct MyResolver(Arc<ServerConfig>);
30///
31/// #[rocket::async_trait]
32/// impl Resolver for MyResolver {
33///     async fn init(rocket: &Rocket<Build>) -> tls::Result<Self> {
34///         // This is equivalent to what the default resolver would do.
35///         let config: TlsConfig = rocket.figment().extract_inner("tls")?;
36///         let server_config = config.server_config().await?;
37///         Ok(MyResolver(Arc::new(server_config)))
38///     }
39///
40///     async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
41///         // return a `ServerConfig` based on `hello`; here we ignore it
42///         Some(self.0.clone())
43///     }
44/// }
45///
46/// #[launch]
47/// fn rocket() -> _ {
48///     rocket::build().attach(MyResolver::fairing())
49/// }
50/// ```
51#[crate::async_trait]
52pub trait Resolver: Send + Sync + 'static {
53    async fn init(rocket: &Rocket<Build>) -> crate::tls::Result<Self> where Self: Sized {
54        let _rocket = rocket;
55        let type_name = std::any::type_name::<Self>();
56        Err(figment::Error::from(format!("{type_name}: Resolver::init() unimplemented")).into())
57    }
58
59    async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>>;
60
61    fn fairing() -> Fairing<Self> where Self: Sized {
62        Fairing(PhantomData)
63    }
64}
65
66#[crate::async_trait]
67impl<T: Resolver> fairing::Fairing for Fairing<T> {
68    fn info(&self) -> Info {
69        Info {
70            name: "Resolver Fairing",
71            kind: Kind::Ignite | Kind::Singleton
72        }
73    }
74
75    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
76        let result = T::init(&rocket).await;
77        match result {
78            Ok(resolver) => Ok(rocket.manage(Arc::new(resolver) as Arc<dyn Resolver>)),
79            Err(e) => {
80                let type_name = std::any::type_name::<T>();
81                error!(type_name, reason = %e, "TLS resolver failed to initialize");
82                Err(rocket)
83            }
84        }
85    }
86}
87
88impl DynResolver {
89    pub fn extract(rocket: &Rocket<Ignite>) -> Option<Self> {
90        rocket.state::<Arc<dyn Resolver>>().map(|r| Self(r.clone()))
91    }
92}
93
94impl fmt::Debug for DynResolver {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        f.debug_tuple("Resolver").finish()
97    }
98}
99
100impl PartialEq for DynResolver {
101    fn eq(&self, _: &Self) -> bool {
102        false
103    }
104}
105
106impl Deref for DynResolver {
107    type Target = dyn Resolver;
108
109    fn deref(&self) -> &Self::Target {
110        &*self.0
111    }
112}