rocket/listener/
endpoint.rs

1use std::fmt;
2use std::any::Any;
3use std::net::{self, AddrParseError, IpAddr, Ipv4Addr};
4use std::path::{Path, PathBuf};
5use std::str::FromStr;
6use std::sync::Arc;
7
8use figment::Figment;
9use serde::de;
10
11use crate::http::uncased::AsUncased;
12
/// TLS configuration carried by `Endpoint::Tls`; `None` when TLS is only
/// assumed rather than configured.
#[cfg(feature = "tls")]
type TlsInfo = Option<Box<crate::tls::TlsConfig>>;

/// Placeholder carried by `Endpoint::Tls` when the `tls` feature is disabled.
#[cfg(not(feature = "tls"))]
type TlsInfo = Option<()>;
15
/// A user-defined endpoint type, stored behind an `Arc` in `Endpoint::Custom`.
///
/// This trait is implemented automatically for every type that is
/// `Display + Debug + Sync + Send + Any`; there is nothing to implement by hand.
pub trait CustomEndpoint: fmt::Display + fmt::Debug + Sync + Send + Any {}

impl<T> CustomEndpoint for T
where
    T: fmt::Display + fmt::Debug + Sync + Send + Any,
{
}
19
20/// # Conversions
21///
22/// * [`&str`] - parse with [`FromStr`]
23/// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`]
24/// * [`PathBuf`] - infallibly as [`Endpoint::Unix`]
25///
26/// # Syntax
27///
28/// The string syntax is:
29///
30/// ```text
31/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
32/// socket := IP_ADDR | SOCKET_ADDR
33/// path := PATH
34///
35/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
36/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
37/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
38/// ```
39///
40/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
41#[derive(Clone)]
42#[non_exhaustive]
43pub enum Endpoint {
44    Tcp(net::SocketAddr),
45    Quic(net::SocketAddr),
46    Unix(PathBuf),
47    Tls(Arc<Endpoint>, TlsInfo),
48    Custom(Arc<dyn CustomEndpoint>),
49}
50
51impl Endpoint {
52    pub fn new<T: CustomEndpoint>(value: T) -> Endpoint {
53        Endpoint::Custom(Arc::new(value))
54    }
55
56    pub fn tcp(&self) -> Option<net::SocketAddr> {
57        match self {
58            Endpoint::Tcp(addr) => Some(*addr),
59            Endpoint::Tls(addr, _) => addr.tcp(),
60            _ => None,
61        }
62    }
63
64    pub fn quic(&self) -> Option<net::SocketAddr> {
65        match self {
66            Endpoint::Quic(addr) => Some(*addr),
67            Endpoint::Tls(addr, _) => addr.tcp(),
68            _ => None,
69        }
70    }
71
72    pub fn socket_addr(&self) -> Option<net::SocketAddr> {
73        match self {
74            Endpoint::Quic(addr) => Some(*addr),
75            Endpoint::Tcp(addr) => Some(*addr),
76            Endpoint::Tls(inner, _) => inner.socket_addr(),
77            _ => None,
78        }
79    }
80
81    pub fn ip(&self) -> Option<IpAddr> {
82        match self {
83            Endpoint::Quic(addr) => Some(addr.ip()),
84            Endpoint::Tcp(addr) => Some(addr.ip()),
85            Endpoint::Tls(inner, _) => inner.ip(),
86            _ => None,
87        }
88    }
89
90    pub fn port(&self) -> Option<u16> {
91        match self {
92            Endpoint::Quic(addr) => Some(addr.port()),
93            Endpoint::Tcp(addr) => Some(addr.port()),
94            Endpoint::Tls(inner, _) => inner.port(),
95            _ => None,
96        }
97    }
98
99    pub fn unix(&self) -> Option<&Path> {
100        match self {
101            Endpoint::Unix(addr) => Some(addr),
102            Endpoint::Tls(addr, _) => addr.unix(),
103            _ => None,
104        }
105    }
106
107    pub fn tls(&self) -> Option<&Endpoint> {
108        match self {
109            Endpoint::Tls(addr, _) => Some(addr),
110            _ => None,
111        }
112    }
113
114    #[cfg(feature = "tls")]
115    pub fn tls_config(&self) -> Option<&crate::tls::TlsConfig> {
116        match self {
117            Endpoint::Tls(_, Some(ref config)) => Some(config),
118            _ => None,
119        }
120    }
121
122    #[cfg(feature = "mtls")]
123    pub fn mtls_config(&self) -> Option<&crate::mtls::MtlsConfig> {
124        match self {
125            Endpoint::Tls(_, Some(config)) => config.mutual(),
126            _ => None,
127        }
128    }
129
130    pub fn downcast<T: 'static>(&self) -> Option<&T> {
131        match self {
132            Endpoint::Tcp(addr) => (addr as &dyn Any).downcast_ref(),
133            Endpoint::Quic(addr) => (addr as &dyn Any).downcast_ref(),
134            Endpoint::Unix(addr) => (addr as &dyn Any).downcast_ref(),
135            Endpoint::Custom(addr) => (addr as &dyn Any).downcast_ref(),
136            Endpoint::Tls(inner, ..) => inner.downcast(),
137        }
138    }
139
140    pub fn is_tcp(&self) -> bool {
141        self.tcp().is_some()
142    }
143
144    pub fn is_quic(&self) -> bool {
145        self.quic().is_some()
146    }
147
148    pub fn is_unix(&self) -> bool {
149        self.unix().is_some()
150    }
151
152    pub fn is_tls(&self) -> bool {
153        self.tls().is_some()
154    }
155
156    #[cfg(feature = "tls")]
157    pub fn with_tls(self, tls: &crate::tls::TlsConfig) -> Endpoint {
158        if self.is_tls() {
159            return self;
160        }
161
162        Self::Tls(Arc::new(self), Some(Box::new(tls.clone())))
163    }
164
165    pub fn assume_tls(self) -> Endpoint {
166        if self.is_tls() {
167            return self;
168        }
169
170        Self::Tls(Arc::new(self), None)
171    }
172
173    /// Fetch the endpoint at `path` in `figment` of kind `kind` (e.g, "tcp")
174    /// then map the value using `f(Some(value))` if present and `f(None)` if
175    /// missing into a different value of typr `T`.
176    ///
177    /// If the conversion succeeds, returns `Ok(value)`. If the conversion fails
178    /// and `Some` value was passed in, returns an error indicating the endpoint
179    /// was an invalid `kind` and otherwise returns a "missing field" error.
180    pub(crate) fn fetch<T, F>(figment: &Figment, kind: &str, path: &str, f: F) -> figment::Result<T>
181        where F: FnOnce(Option<&Endpoint>) -> Option<T>
182    {
183        match figment.extract_inner::<Endpoint>(path) {
184            Ok(endpoint) => f(Some(&endpoint)).ok_or_else(|| {
185                let msg = format!("invalid {kind} endpoint: {endpoint:?}");
186                let mut error = figment::Error::from(msg).with_path(path);
187                error.profile = Some(figment.profile().clone());
188                error.metadata = figment.find_metadata(path).cloned();
189                error
190            }),
191            Err(e) if e.missing() => f(None).ok_or(e),
192            Err(e) => Err(e)
193        }
194    }
195}
196
197impl fmt::Display for Endpoint {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        use Endpoint::*;
200
201        match self {
202            Tcp(addr) | Quic(addr) => write!(f, "http://{addr}"),
203            Unix(addr) => write!(f, "unix:{}", addr.display()),
204            Custom(inner) => inner.fmt(f),
205            Tls(inner, _c) => {
206                match (inner.tcp(), inner.quic()) {
207                    (Some(addr), _) => write!(f, "https://{addr} (TCP")?,
208                    (_, Some(addr)) => write!(f, "https://{addr} (QUIC")?,
209                    (None, None) => write!(f, "{inner} (TLS")?,
210                }
211
212                #[cfg(feature = "mtls")]
213                if _c.as_ref().and_then(|c| c.mutual()).is_some() {
214                    write!(f, " + mTLS")?;
215                }
216
217                write!(f, ")")
218            }
219        }
220    }
221}
222
223impl fmt::Debug for Endpoint {
224    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225        match self {
226            Self::Tcp(a) => write!(f, "tcp:{a}"),
227            Self::Quic(a) => write!(f, "quic:{a}]"),
228            Self::Unix(a) => write!(f, "unix:{}", a.display()),
229            Self::Tls(e, _) => write!(f, "unix:{:?}", &**e),
230            Self::Custom(e) => e.fmt(f),
231        }
232    }
233}
234
235impl Default for Endpoint {
236    fn default() -> Self {
237        Endpoint::Tcp(net::SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000))
238    }
239}
240
241impl FromStr for Endpoint {
242    type Err = AddrParseError;
243
244    fn from_str(string: &str) -> Result<Self, Self::Err> {
245        fn parse_tcp(str: &str, def_port: u16) -> Result<net::SocketAddr, AddrParseError> {
246            str.parse().or_else(|_| str.parse().map(|ip| net::SocketAddr::new(ip, def_port)))
247        }
248
249        if let Some((proto, string)) = string.split_once(':') {
250            if proto.trim().as_uncased() == "tcp" {
251                return parse_tcp(string.trim(), 8000).map(Self::Tcp);
252            } else if proto.trim().as_uncased() == "unix" {
253                return Ok(Self::Unix(PathBuf::from(string.trim())));
254            }
255        }
256
257        parse_tcp(string.trim(), 8000).map(Self::Tcp)
258    }
259}
260
261impl<'de> de::Deserialize<'de> for Endpoint {
262    fn deserialize<D: de::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
263        struct Visitor;
264
265        impl<'de> de::Visitor<'de> for Visitor {
266            type Value = Endpoint;
267
268            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
269                formatter.write_str("valid TCP (ip) or unix (path) endpoint")
270            }
271
272            fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
273                v.parse::<Endpoint>().map_err(|e| E::custom(e.to_string()))
274            }
275        }
276
277        de.deserialize_any(Visitor)
278    }
279}
280
281impl Eq for Endpoint { }
282
283impl PartialEq for Endpoint {
284    fn eq(&self, other: &Self) -> bool {
285        match (self, other) {
286            (Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0,
287            (Self::Quic(l0), Self::Quic(r0)) => l0 == r0,
288            (Self::Unix(l0), Self::Unix(r0)) => l0 == r0,
289            (Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0,
290            (Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(),
291            _ => false,
292        }
293    }
294}
295
296impl PartialEq<PathBuf> for Endpoint {
297    fn eq(&self, other: &PathBuf) -> bool {
298        self.unix() == Some(other.as_path())
299    }
300}
301
302impl PartialEq<Path> for Endpoint {
303    fn eq(&self, other: &Path) -> bool {
304        self.unix() == Some(other)
305    }
306}
307
308#[cfg(unix)]
309impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
310    type Error = std::io::Error;
311
312    fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
313        v.as_pathname()
314            .ok_or_else(|| std::io::Error::other("unix socket is not path"))
315            .map(|path| Endpoint::Unix(path.to_path_buf()))
316    }
317}
318
319impl TryFrom<&str> for Endpoint {
320    type Error = AddrParseError;
321
322    fn try_from(value: &str) -> Result<Self, Self::Error> {
323        value.parse()
324    }
325}
326
327macro_rules! impl_from {
328    ($T:ty => $V:ident) => {
329        impl From<$T> for Endpoint {
330            fn from(value: $T) -> Self {
331                Self::$V(value.into())
332            }
333        }
334    }
335}
336
337impl_from!(std::net::SocketAddr => Tcp);
338impl_from!(std::net::SocketAddrV4 => Tcp);
339impl_from!(std::net::SocketAddrV6 => Tcp);
340impl_from!(PathBuf => Unix);