1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;

pub use rustls::server::{ClientHello, ServerConfig};

use crate::{Build, Ignite, Rocket};
use crate::fairing::{self, Info, Kind};

/// Type-erased, cheaply clonable handle to a [`Resolver`].
///
/// Wraps an `Arc<dyn Resolver>` so that types containing a resolver can have
/// `Debug` and `PartialEq` impls, which the `Resolver` trait object does not
/// provide on its own. Cloning only bumps the `Arc` reference count.
#[derive(Clone)]
pub(crate) struct DynResolver(Arc<dyn Resolver>);

/// Fairing that initializes and registers a [`Resolver`] of type `T`.
///
/// Obtained via [`Resolver::fairing()`]. It holds no data: `T` only selects
/// which resolver implementation is initialized when the fairing runs.
pub struct Fairing<T>(PhantomData<T>) where T: ?Sized;

/// A dynamic TLS configuration resolver.
///
/// # Example
///
/// This is an async trait. Implement it as follows:
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use std::sync::Arc;
/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig};
/// use rocket::{Rocket, Build};
///
/// struct MyResolver(Arc<ServerConfig>);
///
/// #[rocket::async_trait]
/// impl Resolver for MyResolver {
///     async fn init(rocket: &Rocket<Build>) -> tls::Result<Self> {
///         // This is equivalent to what the default resolver would do.
///         let config: TlsConfig = rocket.figment().extract_inner("tls")?;
///         let server_config = config.server_config().await?;
///         Ok(MyResolver(Arc::new(server_config)))
///     }
///
///     async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
///         // return a `ServerConfig` based on `hello`; here we ignore it
///         Some(self.0.clone())
///     }
/// }
///
/// #[launch]
/// fn rocket() -> _ {
///     rocket::build().attach(MyResolver::fairing())
/// }
/// ```
#[crate::async_trait]
pub trait Resolver: Send + Sync + 'static {
    async fn init(rocket: &Rocket<Build>) -> crate::tls::Result<Self> where Self: Sized {
        let _rocket = rocket;
        let type_name = std::any::type_name::<Self>();
        Err(figment::Error::from(format!("{type_name}: Resolver::init() unimplemented")).into())
    }

    async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>>;

    fn fairing() -> Fairing<Self> where Self: Sized {
        Fairing(PhantomData)
    }
}

/// Ignite-time fairing that initializes the resolver `T`. On success, it
/// stores the resolver in managed state as an `Arc<dyn Resolver>`.
#[crate::async_trait]
impl<T: Resolver> fairing::Fairing for Fairing<T> {
    fn info(&self) -> Info {
        // Singleton: attaching the same resolver fairing again is not additive.
        let kind = Kind::Ignite | Kind::Singleton;
        Info { name: "Resolver Fairing", kind }
    }

    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
        let outcome = T::init(&rocket).await;
        let resolver = match outcome {
            Ok(resolver) => resolver,
            Err(e) => {
                // `type_name` doubles as the tracing field name; keep it as-is.
                let type_name = std::any::type_name::<T>();
                error!(type_name, reason = %e, "TLS resolver failed to initialize");
                return Err(rocket);
            }
        };

        // Erase the concrete type so it can be retrieved without knowing `T`.
        let shared: Arc<dyn Resolver> = Arc::new(resolver);
        Ok(rocket.manage(shared))
    }
}

impl DynResolver {
    pub fn extract(rocket: &Rocket<Ignite>) -> Option<Self> {
        rocket.state::<Arc<dyn Resolver>>().map(|r| Self(r.clone()))
    }
}

/// Opaque `Debug` output: the wrapped resolver need not implement `Debug`,
/// so only the name `Resolver` is printed.
impl fmt::Debug for DynResolver {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Resolver")
    }
}

impl PartialEq for DynResolver {
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

/// Lets a `DynResolver` be used directly as a `&dyn Resolver`, e.g. to call
/// [`Resolver::resolve`] on it.
impl Deref for DynResolver {
    type Target = dyn Resolver;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}