rocket/listener/
default.rs

1use core::fmt;
2
3use serde::Deserialize;
4use tokio_util::either::Either::{Left, Right};
5use either::Either;
6
7use crate::{Ignite, Rocket};
8use crate::listener::{Bind, Endpoint, tcp::TcpListener};
9
10#[cfg(unix)] use crate::listener::unix::UnixListener;
11#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};
12
13mod private {
14    use super::*;
15    use tokio_util::either::Either;
16
17    #[cfg(feature = "tls")] type TlsListener<T> = super::TlsListener<T>;
18    #[cfg(not(feature = "tls"))] type TlsListener<T> = T;
19    #[cfg(unix)] type UnixListener = super::UnixListener;
20    #[cfg(not(unix))] type UnixListener = TcpListener;
21
22    pub type Listener = Either<
23        Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
24        Either<TcpListener, UnixListener>,
25    >;
26
27    /// The default connection listener.
28    ///
29    /// # Configuration
30    ///
31    /// Reads the following optional configuration parameters:
32    ///
33    /// | parameter   | type              | default               |
34    /// | ----------- | ----------------- | --------------------- |
35    /// | `address`   | [`Endpoint`]      | `tcp:127.0.0.1:8000`  |
36    /// | `tls`       | [`TlsConfig`]     | None                  |
37    /// | `reuse`     | boolean           | `true`                |
38    ///
39    /// # Listener
40    ///
41    /// Based on the above configuration, this listener defers to one of the
42    /// following existing listeners:
43    ///
44    /// | listener                      | `address` type     | `tls` enabled |
45    /// |-------------------------------|--------------------|---------------|
46    /// | [`TcpListener`]               | [`Endpoint::Tcp`]  | no            |
47    /// | [`UnixListener`]              | [`Endpoint::Unix`] | no            |
48    /// | [`TlsListener<TcpListener>`]  | [`Endpoint::Tcp`]  | yes           |
49    /// | [`TlsListener<UnixListener>`] | [`Endpoint::Unix`] | yes           |
50    ///
51    /// [`UnixListener`]: crate::listener::unix::UnixListener
52    /// [`TlsListener<TcpListener>`]: crate::tls::TlsListener
53    /// [`TlsListener<UnixListener>`]: crate::tls::TlsListener
54    ///
55    ///  * **address type** is the variant the `address` parameter parses as.
56    ///  * **`tls` enabled** is `yes` when the `tls` feature is enabled _and_ a
57    ///    `tls` configuration is provided.
58    #[cfg(doc)]
59    pub struct DefaultListener(());
60}
61
62#[derive(Deserialize)]
63struct Config {
64    #[serde(default)]
65    address: Endpoint,
66    #[cfg(feature = "tls")]
67    tls: Option<TlsConfig>,
68}
69
// When building docs, publish the documented placeholder struct from
// `private` in place of the real `Either`-based alias.
#[cfg(doc)] pub use self::private::DefaultListener;

// Connection type named by the rustdoc-only `Listener` impl below.
#[cfg(doc)] type Connection = crate::listener::tcp::TcpStream;
75
// Rustdoc-only `Bind` impl on the placeholder struct so the trait is listed in
// its documentation. `cfg(doc)` is only set by rustdoc, so these bodies are
// never part of a normal build; `unreachable!()` just satisfies the signature.
#[cfg(doc)]
impl Bind for DefaultListener {
    type Error = Error;

    async fn bind(_: &Rocket<Ignite>) -> Result<Self, Error> {
        unreachable!()
    }

    fn bind_endpoint(_: &Rocket<Ignite>) -> Result<Endpoint, Error> {
        unreachable!()
    }
}
82
// Rustdoc-only `Listener` impl on the placeholder struct. Its members are
// hidden from the rendered docs; only the fact that the trait is implemented
// is shown. Like the `Bind` stub, it exists only under `cfg(doc)`.
#[cfg(doc)]
impl super::Listener for DefaultListener {
    #[doc(hidden)]
    type Accept = Connection;

    #[doc(hidden)]
    type Connection = Connection;

    #[doc(hidden)]
    async fn accept(&self) -> std::io::Result<Connection> {
        unreachable!()
    }

    #[doc(hidden)]
    async fn connect(&self, _: Self::Accept) -> std::io::Result<Connection> {
        unreachable!()
    }

    #[doc(hidden)]
    fn endpoint(&self) -> std::io::Result<Endpoint> {
        unreachable!()
    }
}
94
95#[cfg(not(doc))]
96pub type DefaultListener = private::Listener;
97
98#[cfg(not(doc))]
99impl Bind for DefaultListener {
100    type Error = Error;
101
102    async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
103        let config: Config = rocket.figment().extract()?;
104        match config.address {
105            #[cfg(feature = "tls")]
106            Endpoint::Tcp(_) if config.tls.is_some() => {
107                let listener = <TlsListener<TcpListener> as Bind>::bind(rocket).await?;
108                Ok(Left(Left(listener)))
109            }
110            Endpoint::Tcp(_) => {
111                let listener = <TcpListener as Bind>::bind(rocket).await?;
112                Ok(Right(Left(listener)))
113            }
114            #[cfg(all(unix, feature = "tls"))]
115            Endpoint::Unix(_) if config.tls.is_some() => {
116                let listener = <TlsListener<UnixListener> as Bind>::bind(rocket).await?;
117                Ok(Left(Right(listener)))
118            }
119            #[cfg(unix)]
120            Endpoint::Unix(_) => {
121                let listener = <UnixListener as Bind>::bind(rocket).await?;
122                Ok(Right(Right(listener)))
123            }
124            endpoint => Err(Error::Unsupported(endpoint)),
125        }
126    }
127
128    fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
129        let config: Config = rocket.figment().extract()?;
130        Ok(config.address)
131    }
132}
133
134#[derive(Debug)]
135pub enum Error {
136    Config(figment::Error),
137    Io(std::io::Error),
138    Unsupported(Endpoint),
139    #[cfg(feature = "tls")]
140    Tls(crate::tls::Error),
141}
142
143impl From<figment::Error> for Error {
144    fn from(value: figment::Error) -> Self {
145        Error::Config(value)
146    }
147}
148
149impl From<std::io::Error> for Error {
150    fn from(value: std::io::Error) -> Self {
151        Error::Io(value)
152    }
153}
154
#[cfg(feature = "tls")]
impl From<crate::tls::Error> for Error {
    /// Wraps a TLS error as [`Error::Tls`].
    fn from(error: crate::tls::Error) -> Self {
        Self::Tls(error)
    }
}
161
162impl From<Either<figment::Error, std::io::Error>> for Error {
163    fn from(value: Either<figment::Error, std::io::Error>) -> Self {
164        value.either(Error::Config, Error::Io)
165    }
166}
167
168impl fmt::Display for Error {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        match self {
171            Error::Config(e) => e.fmt(f),
172            Error::Io(e) => e.fmt(f),
173            Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"),
174            #[cfg(feature = "tls")]
175            Error::Tls(error) => error.fmt(f),
176        }
177    }
178}
179
180impl std::error::Error for Error {
181    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
182        match self {
183            Error::Config(e) => Some(e),
184            Error::Io(e) => Some(e),
185            Error::Unsupported(_) => None,
186            #[cfg(feature = "tls")]
187            Error::Tls(e) => Some(e),
188        }
189    }
190}