use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;

pub use rustls::server::{ClientHello, ServerConfig};

use crate::{Build, Ignite, Rocket};
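use crate::fairing::{self, Info, Kind};

/// Proxy type that provides `PartialEq` and `Debug` impls for a type-erased
/// [`Resolver`].
#[derive(Clone)]
pub(crate) struct DynResolver(Arc<dyn Resolver>);

/// The fairing returned by [`Resolver::fairing()`]. On ignite, it calls
/// `T::init()` and stores the resulting resolver in managed state.
pub struct Fairing<T: ?Sized>(PhantomData<T>);

/// A dynamic TLS configuration resolver.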
///
/// # Example
///
/// This is an async trait. Implement it as follows:
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use std::sync::Arc;
/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig};
/// use rocket::{Rocket, Build};
///
/// struct MyResolver(Arc<ServerConfig>);
///
/// #[rocket::async_trait]
/// impl Resolver for MyResolver {
///     async fn init(rocket: &Rocket<Build>) -> tls::Result<Self> {
///         // This is equivalent to what the default resolver would do.
///         let config: TlsConfig = rocket.figment().extract_inner("tls")?;
///         let server_config = config.server_config().await?;
///         Ok(MyResolver(Arc::new(server_config)))
///     }
///
///     async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
///         // return a `ServerConfig` based on `hello`; here we ignore it
///         Some(self.0.clone())
///     }
/// }
///
/// #[launch]
/// fn rocket() -> _ {
///     rocket::build().attach(MyResolver::fairing())
/// }
/// ```
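#[crate::async_trait]
pub trait Resolver: Send + Sync + 'static {
    /// Initializes the resolver from `rocket`.
    ///
    /// The default implementation returns an error, so a resolver attached
    /// via [`Resolver::fairing()`] must override this method.
    async fn init(rocket: &Rocket<Build>) -> crate::tls::Result<Self> where Self: Sized {
        let _rocket = rocket;
        let type_name = std::any::type_name::<Self>();
        Err(figment::Error::from(format!("{type_name}: Resolver::init() unimplemented")).into())
    }

    /// Returns the `ServerConfig` to use for the connection described by
    /// `hello`, or `None` if none is available.
    async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>>;

    /// Returns a fairing that initializes `Self` on ignite and stores it in
    /// managed state.
    fn fairing() -> Fairing<Self> where Self: Sized {
        Fairing(PhantomData)
    }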
}

#[crate::async_trait]
impl<T: Resolver> fairing::Fairing for Fairing<T> {
    fn info(&self) -> Info {
        Info {
            name: "Resolver Fairing",
            kind: Kind::Ignite | Kind::Singleton
        }
    }

    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
        let result = T::init(&rocket).await;
        match result {
            Ok(resolver) => Ok(rocket.manage(Arc::new(resolver) as Arc<dyn Resolver>)),
            Err(e) => {
                let type_name = std::any::type_name::<T>();
                error!(type_name, reason = %e, "TLS resolver failed to initialize");
                Err(rocket)
            }
        }
    }
}

impl DynResolver {
    /// Returns the resolver stored in `rocket`'s managed state, if any.
    pub fn extract(rocket: &Rocket<Ignite>) -> Option<Self> {
        rocket.state::<Arc<dyn Resolver>>().map(|r| Self(r.clone()))
    }
}

impl fmt::Debug for DynResolver {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("Resolver").finish()
    }
}

impl PartialEq for DynResolver {
    // Resolvers are opaque, so two `DynResolver`s are never considered equal.
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

impl Deref for DynResolver {
    type Target = dyn Resolver;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}