// rocket/listener/default.rs
1use core::fmt;
2
3use serde::Deserialize;
4use tokio_util::either::Either::{Left, Right};
5use either::Either;
6
7use crate::{Ignite, Rocket};
8use crate::listener::{Bind, Endpoint, tcp::TcpListener};
9
10#[cfg(unix)] use crate::listener::unix::UnixListener;
11#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};
12
13mod private {
14 use super::*;
15 use tokio_util::either::Either;
16
17 #[cfg(feature = "tls")] type TlsListener<T> = super::TlsListener<T>;
18 #[cfg(not(feature = "tls"))] type TlsListener<T> = T;
19 #[cfg(unix)] type UnixListener = super::UnixListener;
20 #[cfg(not(unix))] type UnixListener = TcpListener;
21
22 pub type Listener = Either<
23 Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
24 Either<TcpListener, UnixListener>,
25 >;
26
27 #[cfg(doc)]
59 pub struct DefaultListener(());
60}
61
/// The subset of Rocket's configuration that decides which listener to bind.
#[derive(Deserialize)]
struct Config {
    /// Endpoint to bind; uses `Endpoint`'s `Default` value when the key is
    /// absent (`#[serde(default)]`).
    #[serde(default)]
    address: Endpoint,
    /// TLS settings; when present, `bind` selects the TLS-wrapped listener
    /// for TCP and Unix endpoints.
    #[cfg(feature = "tls")]
    tls: Option<TlsConfig>,
}
69
// Under rustdoc, expose the opaque placeholder struct from `private` instead
// of the nested `Either` alias used in real builds (presumably so the docs
// show one named type rather than the full `Either` tree).
#[cfg(doc)]
pub use private::DefaultListener;

// Placeholder connection type used only by the doc-only `Listener` impl below.
#[cfg(doc)]
type Connection = crate::listener::tcp::TcpStream;

// Doc-only `Bind` impl so rustdoc lists `Bind` for `DefaultListener`. It is
// compiled only when `cfg(doc)` is set, so the bodies are never executed;
// hence `unreachable!()`.
#[cfg(doc)]
impl Bind for DefaultListener {
    type Error = Error;
    async fn bind(_: &Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
    fn bind_endpoint(_: &Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
}

// Doc-only `Listener` impl. Every item is `#[doc(hidden)]`, so the rendered
// docs show only that `DefaultListener` implements `Listener`.
#[cfg(doc)]
impl super::Listener for DefaultListener {
    #[doc(hidden)] type Accept = Connection;
    #[doc(hidden)] type Connection = Connection;
    #[doc(hidden)]
    async fn accept(&self) -> std::io::Result<Connection> { unreachable!() }
    #[doc(hidden)]
    async fn connect(&self, _: Self::Accept) -> std::io::Result<Connection> { unreachable!() }
    #[doc(hidden)]
    fn endpoint(&self) -> std::io::Result<Endpoint> { unreachable!() }
}
94
95#[cfg(not(doc))]
96pub type DefaultListener = private::Listener;
97
98#[cfg(not(doc))]
99impl Bind for DefaultListener {
100 type Error = Error;
101
102 async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
103 let config: Config = rocket.figment().extract()?;
104 match config.address {
105 #[cfg(feature = "tls")]
106 Endpoint::Tcp(_) if config.tls.is_some() => {
107 let listener = <TlsListener<TcpListener> as Bind>::bind(rocket).await?;
108 Ok(Left(Left(listener)))
109 }
110 Endpoint::Tcp(_) => {
111 let listener = <TcpListener as Bind>::bind(rocket).await?;
112 Ok(Right(Left(listener)))
113 }
114 #[cfg(all(unix, feature = "tls"))]
115 Endpoint::Unix(_) if config.tls.is_some() => {
116 let listener = <TlsListener<UnixListener> as Bind>::bind(rocket).await?;
117 Ok(Left(Right(listener)))
118 }
119 #[cfg(unix)]
120 Endpoint::Unix(_) => {
121 let listener = <UnixListener as Bind>::bind(rocket).await?;
122 Ok(Right(Right(listener)))
123 }
124 endpoint => Err(Error::Unsupported(endpoint)),
125 }
126 }
127
128 fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
129 let config: Config = rocket.figment().extract()?;
130 Ok(config.address)
131 }
132}
133
/// Error returned by `DefaultListener`'s `Bind` methods.
#[derive(Debug)]
pub enum Error {
    /// A configuration error from `figment`, e.g. while extracting the
    /// listener configuration.
    Config(figment::Error),
    /// An I/O error reported by a concrete listener (presumably while binding).
    Io(std::io::Error),
    /// The configured endpoint kind has no listener in this build.
    Unsupported(Endpoint),
    /// An error from the TLS layer (`crate::tls::Error`).
    #[cfg(feature = "tls")]
    Tls(crate::tls::Error),
}
142
143impl From<figment::Error> for Error {
144 fn from(value: figment::Error) -> Self {
145 Error::Config(value)
146 }
147}
148
149impl From<std::io::Error> for Error {
150 fn from(value: std::io::Error) -> Self {
151 Error::Io(value)
152 }
153}
154
#[cfg(feature = "tls")]
impl From<crate::tls::Error> for Error {
    /// Wraps a TLS error as [`Error::Tls`].
    fn from(error: crate::tls::Error) -> Self {
        Self::Tls(error)
    }
}
161
162impl From<Either<figment::Error, std::io::Error>> for Error {
163 fn from(value: Either<figment::Error, std::io::Error>) -> Self {
164 value.either(Error::Config, Error::Io)
165 }
166}
167
168impl fmt::Display for Error {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 match self {
171 Error::Config(e) => e.fmt(f),
172 Error::Io(e) => e.fmt(f),
173 Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"),
174 #[cfg(feature = "tls")]
175 Error::Tls(error) => error.fmt(f),
176 }
177 }
178}
179
180impl std::error::Error for Error {
181 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
182 match self {
183 Error::Config(e) => Some(e),
184 Error::Io(e) => Some(e),
185 Error::Unsupported(_) => None,
186 #[cfg(feature = "tls")]
187 Error::Tls(e) => Some(e),
188 }
189 }
190}