1use std::fmt;
2use std::any::Any;
3use std::net::{self, AddrParseError, IpAddr, Ipv4Addr};
4use std::path::{Path, PathBuf};
5use std::str::FromStr;
6use std::sync::Arc;
7
8use figment::Figment;
9use serde::de;
10
11use crate::http::uncased::AsUncased;
12
/// Per-endpoint TLS data stored alongside `Endpoint::Tls`: an optional boxed
/// `TlsConfig` when the `tls` feature is enabled.
#[cfg(feature = "tls")]
type TlsInfo = Option<Box<crate::tls::TlsConfig>>;

/// Placeholder TLS data used when the `tls` feature is disabled.
#[cfg(not(feature = "tls"))]
type TlsInfo = Option<()>;
15
/// A user-defined endpoint value, stored in `Endpoint::Custom`.
///
/// Implemented automatically (via the blanket impl below) for every `'static`
/// type that is `Display`, `Debug`, `Send`, and `Sync`.
pub trait CustomEndpoint: fmt::Display + fmt::Debug + Sync + Send + Any {}

impl<T> CustomEndpoint for T
where
    T: fmt::Display + fmt::Debug + Sync + Send + Any,
{
}
19
20#[derive(Clone)]
42#[non_exhaustive]
43pub enum Endpoint {
44 Tcp(net::SocketAddr),
45 Quic(net::SocketAddr),
46 Unix(PathBuf),
47 Tls(Arc<Endpoint>, TlsInfo),
48 Custom(Arc<dyn CustomEndpoint>),
49}
50
51impl Endpoint {
52 pub fn new<T: CustomEndpoint>(value: T) -> Endpoint {
53 Endpoint::Custom(Arc::new(value))
54 }
55
56 pub fn tcp(&self) -> Option<net::SocketAddr> {
57 match self {
58 Endpoint::Tcp(addr) => Some(*addr),
59 Endpoint::Tls(addr, _) => addr.tcp(),
60 _ => None,
61 }
62 }
63
64 pub fn quic(&self) -> Option<net::SocketAddr> {
65 match self {
66 Endpoint::Quic(addr) => Some(*addr),
67 Endpoint::Tls(addr, _) => addr.tcp(),
68 _ => None,
69 }
70 }
71
72 pub fn socket_addr(&self) -> Option<net::SocketAddr> {
73 match self {
74 Endpoint::Quic(addr) => Some(*addr),
75 Endpoint::Tcp(addr) => Some(*addr),
76 Endpoint::Tls(inner, _) => inner.socket_addr(),
77 _ => None,
78 }
79 }
80
81 pub fn ip(&self) -> Option<IpAddr> {
82 match self {
83 Endpoint::Quic(addr) => Some(addr.ip()),
84 Endpoint::Tcp(addr) => Some(addr.ip()),
85 Endpoint::Tls(inner, _) => inner.ip(),
86 _ => None,
87 }
88 }
89
90 pub fn port(&self) -> Option<u16> {
91 match self {
92 Endpoint::Quic(addr) => Some(addr.port()),
93 Endpoint::Tcp(addr) => Some(addr.port()),
94 Endpoint::Tls(inner, _) => inner.port(),
95 _ => None,
96 }
97 }
98
99 pub fn unix(&self) -> Option<&Path> {
100 match self {
101 Endpoint::Unix(addr) => Some(addr),
102 Endpoint::Tls(addr, _) => addr.unix(),
103 _ => None,
104 }
105 }
106
107 pub fn tls(&self) -> Option<&Endpoint> {
108 match self {
109 Endpoint::Tls(addr, _) => Some(addr),
110 _ => None,
111 }
112 }
113
114 #[cfg(feature = "tls")]
115 pub fn tls_config(&self) -> Option<&crate::tls::TlsConfig> {
116 match self {
117 Endpoint::Tls(_, Some(ref config)) => Some(config),
118 _ => None,
119 }
120 }
121
122 #[cfg(feature = "mtls")]
123 pub fn mtls_config(&self) -> Option<&crate::mtls::MtlsConfig> {
124 match self {
125 Endpoint::Tls(_, Some(config)) => config.mutual(),
126 _ => None,
127 }
128 }
129
130 pub fn downcast<T: 'static>(&self) -> Option<&T> {
131 match self {
132 Endpoint::Tcp(addr) => (addr as &dyn Any).downcast_ref(),
133 Endpoint::Quic(addr) => (addr as &dyn Any).downcast_ref(),
134 Endpoint::Unix(addr) => (addr as &dyn Any).downcast_ref(),
135 Endpoint::Custom(addr) => (addr as &dyn Any).downcast_ref(),
136 Endpoint::Tls(inner, ..) => inner.downcast(),
137 }
138 }
139
140 pub fn is_tcp(&self) -> bool {
141 self.tcp().is_some()
142 }
143
144 pub fn is_quic(&self) -> bool {
145 self.quic().is_some()
146 }
147
148 pub fn is_unix(&self) -> bool {
149 self.unix().is_some()
150 }
151
152 pub fn is_tls(&self) -> bool {
153 self.tls().is_some()
154 }
155
156 #[cfg(feature = "tls")]
157 pub fn with_tls(self, tls: &crate::tls::TlsConfig) -> Endpoint {
158 if self.is_tls() {
159 return self;
160 }
161
162 Self::Tls(Arc::new(self), Some(Box::new(tls.clone())))
163 }
164
165 pub fn assume_tls(self) -> Endpoint {
166 if self.is_tls() {
167 return self;
168 }
169
170 Self::Tls(Arc::new(self), None)
171 }
172
173 pub(crate) fn fetch<T, F>(figment: &Figment, kind: &str, path: &str, f: F) -> figment::Result<T>
181 where F: FnOnce(Option<&Endpoint>) -> Option<T>
182 {
183 match figment.extract_inner::<Endpoint>(path) {
184 Ok(endpoint) => f(Some(&endpoint)).ok_or_else(|| {
185 let msg = format!("invalid {kind} endpoint: {endpoint:?}");
186 let mut error = figment::Error::from(msg).with_path(path);
187 error.profile = Some(figment.profile().clone());
188 error.metadata = figment.find_metadata(path).cloned();
189 error
190 }),
191 Err(e) if e.missing() => f(None).ok_or(e),
192 Err(e) => Err(e)
193 }
194 }
195}
196
197impl fmt::Display for Endpoint {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 use Endpoint::*;
200
201 match self {
202 Tcp(addr) | Quic(addr) => write!(f, "http://{addr}"),
203 Unix(addr) => write!(f, "unix:{}", addr.display()),
204 Custom(inner) => inner.fmt(f),
205 Tls(inner, _c) => {
206 match (inner.tcp(), inner.quic()) {
207 (Some(addr), _) => write!(f, "https://{addr} (TCP")?,
208 (_, Some(addr)) => write!(f, "https://{addr} (QUIC")?,
209 (None, None) => write!(f, "{inner} (TLS")?,
210 }
211
212 #[cfg(feature = "mtls")]
213 if _c.as_ref().and_then(|c| c.mutual()).is_some() {
214 write!(f, " + mTLS")?;
215 }
216
217 write!(f, ")")
218 }
219 }
220 }
221}
222
223impl fmt::Debug for Endpoint {
224 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225 match self {
226 Self::Tcp(a) => write!(f, "tcp:{a}"),
227 Self::Quic(a) => write!(f, "quic:{a}]"),
228 Self::Unix(a) => write!(f, "unix:{}", a.display()),
229 Self::Tls(e, _) => write!(f, "unix:{:?}", &**e),
230 Self::Custom(e) => e.fmt(f),
231 }
232 }
233}
234
235impl Default for Endpoint {
236 fn default() -> Self {
237 Endpoint::Tcp(net::SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000))
238 }
239}
240
241impl FromStr for Endpoint {
242 type Err = AddrParseError;
243
244 fn from_str(string: &str) -> Result<Self, Self::Err> {
245 fn parse_tcp(str: &str, def_port: u16) -> Result<net::SocketAddr, AddrParseError> {
246 str.parse().or_else(|_| str.parse().map(|ip| net::SocketAddr::new(ip, def_port)))
247 }
248
249 if let Some((proto, string)) = string.split_once(':') {
250 if proto.trim().as_uncased() == "tcp" {
251 return parse_tcp(string.trim(), 8000).map(Self::Tcp);
252 } else if proto.trim().as_uncased() == "unix" {
253 return Ok(Self::Unix(PathBuf::from(string.trim())));
254 }
255 }
256
257 parse_tcp(string.trim(), 8000).map(Self::Tcp)
258 }
259}
260
261impl<'de> de::Deserialize<'de> for Endpoint {
262 fn deserialize<D: de::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
263 struct Visitor;
264
265 impl<'de> de::Visitor<'de> for Visitor {
266 type Value = Endpoint;
267
268 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
269 formatter.write_str("valid TCP (ip) or unix (path) endpoint")
270 }
271
272 fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
273 v.parse::<Endpoint>().map_err(|e| E::custom(e.to_string()))
274 }
275 }
276
277 de.deserialize_any(Visitor)
278 }
279}
280
281impl Eq for Endpoint { }
282
283impl PartialEq for Endpoint {
284 fn eq(&self, other: &Self) -> bool {
285 match (self, other) {
286 (Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0,
287 (Self::Quic(l0), Self::Quic(r0)) => l0 == r0,
288 (Self::Unix(l0), Self::Unix(r0)) => l0 == r0,
289 (Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0,
290 (Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(),
291 _ => false,
292 }
293 }
294}
295
296impl PartialEq<PathBuf> for Endpoint {
297 fn eq(&self, other: &PathBuf) -> bool {
298 self.unix() == Some(other.as_path())
299 }
300}
301
302impl PartialEq<Path> for Endpoint {
303 fn eq(&self, other: &Path) -> bool {
304 self.unix() == Some(other)
305 }
306}
307
308#[cfg(unix)]
309impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
310 type Error = std::io::Error;
311
312 fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
313 v.as_pathname()
314 .ok_or_else(|| std::io::Error::other("unix socket is not path"))
315 .map(|path| Endpoint::Unix(path.to_path_buf()))
316 }
317}
318
319impl TryFrom<&str> for Endpoint {
320 type Error = AddrParseError;
321
322 fn try_from(value: &str) -> Result<Self, Self::Error> {
323 value.parse()
324 }
325}
326
327macro_rules! impl_from {
328 ($T:ty => $V:ident) => {
329 impl From<$T> for Endpoint {
330 fn from(value: $T) -> Self {
331 Self::$V(value.into())
332 }
333 }
334 }
335}
336
337impl_from!(std::net::SocketAddr => Tcp);
338impl_from!(std::net::SocketAddrV4 => Tcp);
339impl_from!(std::net::SocketAddrV6 => Tcp);
340impl_from!(PathBuf => Unix);