rocket_ws/
websocket.rs

1use std::io;
2use std::pin::Pin;
3
4use rocket::data::{IoHandler, IoStream};
5use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
6use rocket::response::{self, Responder, Response};
7use rocket::request::{FromRequest, Request, Outcome};
8use rocket::http::Status;
9
10use crate::{Config, Message};
11use crate::stream::DuplexStream;
12use crate::result::{Result, Error};
13
14/// A request guard identifying WebSocket requests. Converts into a [`Channel`]
15/// or [`MessageStream`].
16///
17/// For example usage, see the [crate docs](crate#usage).
18///
19/// ## Details
20///
21/// This is the entrypoint to the library. Every WebSocket response _must_
22/// initiate via the `WebSocket` request guard. The guard identifies valid
23/// WebSocket connection requests and, if the request is valid, succeeds to be
24/// converted into a streaming WebSocket response via
25/// [`Stream!`](crate::Stream!), [`WebSocket::channel()`], or
26/// [`WebSocket::stream()`]. The connection can be configured via
27/// [`WebSocket::config()`]; see [`Config`] for details on configuring a
28/// connection.
29///
30/// ### Forwarding
31///
32/// If the incoming request is not a valid WebSocket request, the guard
33/// forwards with a status of `BadRequest`. The guard never fails.
34pub struct WebSocket {
35    config: Config,
36    key: String,
37}
38
39impl WebSocket {
40    fn new(key: String) -> WebSocket {
41        WebSocket { config: Config::default(), key }
42    }
43
44    /// Change the default connection configuration to `config`.
45    ///
46    /// # Example
47    ///
48    /// ```rust
49    /// # use rocket::get;
50    /// # use rocket_ws as ws;
51    /// #
52    /// #[get("/echo")]
53    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
54    ///     let ws = ws.config(ws::Config {
55    ///         max_send_queue: Some(5),
56    ///         ..Default::default()
57    ///     });
58    ///
59    ///     ws::Stream! { ws =>
60    ///         for await message in ws {
61    ///             yield message?;
62    ///         }
63    ///     }
64    /// }
65    /// ```
66    pub fn config(mut self, config: Config) -> Self {
67        self.config = config;
68        self
69    }
70
71    /// Create a read/write channel to the client and call `handler` with it.
72    ///
73    /// This method takes a `FnOnce`, `handler`, that consumes a read/write
74    /// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`]
75    /// for details on how to make use of the channel.
76    ///
77    /// The `handler` must return a `Box`ed and `Pin`ned future: calling
78    /// [`Box::pin()`] with a future does just this as is the preferred
79    /// mechanism to create a `Box<Pin<Future>>`. The future must return a
80    /// [`Result<()>`](crate::result::Result). The WebSocket connection is
81    /// closed successfully if the future returns `Ok` and with an error if
82    /// the future returns `Err`.
83    ///
84    /// # Lifetimes
85    ///
86    /// The `Channel` may borrow from the request. If it does, the lifetime
87    /// should be specified as something other than `'static`. Otherwise, the
88    /// `'static` lifetime should be used.
89    ///
90    /// # Example
91    ///
92    /// ```rust
93    /// # use rocket::get;
94    /// # use rocket_ws as ws;
95    /// use rocket::futures::{SinkExt, StreamExt};
96    ///
97    /// #[get("/hello/<name>")]
98    /// fn hello(ws: ws::WebSocket, name: &str) -> ws::Channel<'_> {
99    ///     ws.channel(move |mut stream| Box::pin(async move {
100    ///         let message = format!("Hello, {}!", name);
101    ///         let _ = stream.send(message.into()).await;
102    ///         Ok(())
103    ///     }))
104    /// }
105    ///
106    /// #[get("/echo")]
107    /// fn echo(ws: ws::WebSocket) -> ws::Channel<'static> {
108    ///     ws.channel(move |mut stream| Box::pin(async move {
109    ///         while let Some(message) = stream.next().await {
110    ///             let _ = stream.send(message?).await;
111    ///         }
112    ///
113    ///         Ok(())
114    ///     }))
115    /// }
116    /// ```
117    pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r>
118        where F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r
119    {
120        Channel { ws: self, handler: Box::new(handler), }
121    }
122
123    /// Create a stream that consumes client [`Message`]s and emits its own.
124    ///
125    /// This method takes a `FnOnce` `stream` that consumes a read-only stream
126    /// and returns a stream of [`Message`]s. While the returned stream can be
127    /// constructed in any manner, the [`Stream!`](crate::Stream!) macro is the
128    /// preferred method. In any case, the stream must be `Send`.
129    ///
130    /// The returned stream must emit items of type `Result<Message>`. Items
131    /// that are `Ok(Message)` are sent to the client while items of type
132    /// `Err(Error)` result in the connection being closed and the remainder of
133    /// the stream discarded.
134    ///
135    /// # Example
136    ///
137    /// ```rust
138    /// # use rocket::get;
139    /// # use rocket_ws as ws;
140    ///
141    /// // Use `Stream!`, which internally calls `WebSocket::stream()`.
142    /// #[get("/echo?stream")]
143    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
144    ///     ws::Stream! { ws =>
145    ///         for await message in ws {
146    ///             yield message?;
147    ///         }
148    ///     }
149    /// }
150    ///
151    /// // Use a raw stream.
152    /// #[get("/echo?compose")]
153    /// fn echo_compose(ws: ws::WebSocket) -> ws::Stream!['static] {
154    ///     ws.stream(|io| io)
155    /// }
156    /// ```
157    pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
158        where F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
159              S: futures::Stream<Item = Result<Message>> + Send + 'r
160    {
161        MessageStream { ws: self, handler: Box::new(stream), }
162    }
163
164    /// Returns the server's fully computed and encoded WebSocket handshake
165    /// accept key.
166    ///
167    /// > The server takes the value of the `Sec-WebSocket-Key` sent in the
168    /// > handshake request, appends `258EAFA5-E914-47DA-95CA-C5AB0DC85B11`,
169    /// > SHA-1 of the new value, and is then base64 encoded.
170    /// >
171    /// > -- [`Sec-WebSocket-Accept`]
172    ///
173    /// This is the value returned via the [`Sec-WebSocket-Accept`] header
174    /// during the acceptance response.
175    ///
176    /// [`Sec-WebSocket-Accept`]:
177    /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Accept
178    ///
179    /// # Example
180    ///
181    /// ```rust
182    /// # use rocket::get;
183    /// # use rocket_ws as ws;
184    /// #
185    /// #[get("/echo")]
186    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
187    ///     let accept_key = ws.accept_key();
188    ///     ws.stream(|io| io)
189    /// }
190    /// ```
191    pub fn accept_key(&self) -> &str {
192        &self.key
193    }
194
195}
196
197/// A streaming channel, returned by [`WebSocket::channel()`].
198///
199/// `Channel` has no methods or functionality beyond its trait implementations.
200pub struct Channel<'r> {
201    ws: WebSocket,
202    handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
203}
204
205/// A [`Stream`](futures::Stream) of [`Message`]s, returned by
206/// [`WebSocket::stream()`], used via [`Stream!`].
207///
208/// This type should not be used directly. Instead, it is used via the
209/// [`Stream!`] macro, which expands to both the type itself and an expression
210/// which evaluates to this type. See [`Stream!`] for details.
211///
212/// [`Stream!`]: crate::Stream!
213// TODO: Get rid of this or `Channel` via a single `enum`.
214pub struct MessageStream<'r, S> {
215    ws: WebSocket,
216    handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>
217}
218
219#[rocket::async_trait]
220impl<'r> FromRequest<'r> for WebSocket {
221    type Error = std::convert::Infallible;
222
223    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
224        use crate::tungstenite::handshake::derive_accept_key;
225        use rocket::http::uncased::eq;
226
227        let headers = req.headers();
228        let is_upgrade = headers.get("Connection")
229            .any(|h| h.split(',').any(|v| eq(v.trim(), "upgrade")));
230
231        let is_ws = headers.get("Upgrade")
232            .any(|h| h.split(',').any(|v| eq(v.trim(), "websocket")));
233
234        let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
235        let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
236        match key {
237            Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket::new(key)),
238            Some(_) | None => Outcome::Forward(Status::BadRequest)
239        }
240    }
241}
242
243impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
244    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
245        Response::build()
246            .raw_header("Sec-Websocket-Version", "13")
247            .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
248            .upgrade("websocket", self)
249            .ok()
250    }
251}
252
253impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
254    where S: futures::Stream<Item = Result<Message>> + Send + 'o
255{
256    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
257        Response::build()
258            .raw_header("Sec-Websocket-Version", "13")
259            .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
260            .upgrade("websocket", self)
261            .ok()
262    }
263}
264
265#[rocket::async_trait]
266impl IoHandler for Channel<'_> {
267    async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
268        let channel = Pin::into_inner(self);
269        let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await;
270        handle_result(result).map(|_| ())
271    }
272}
273
274#[rocket::async_trait]
275impl<'r, S> IoHandler for MessageStream<'r, S>
276    where S: futures::Stream<Item = Result<Message>> + Send + 'r
277{
278    async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
279        let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
280        let stream = (Pin::into_inner(self).handler)(source);
281        rocket::tokio::pin!(stream);
282        while let Some(msg) = stream.next().await {
283            let result = match msg {
284                Ok(msg) => sink.send(msg).await,
285                Err(e) => Err(e)
286            };
287
288            if !handle_result(result)? {
289                return Ok(());
290            }
291        }
292
293        Ok(())
294    }
295}
296
297/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing
298/// has terminated without error, and `Err(e)` if an error has occurred.
299fn handle_result(result: Result<()>) -> io::Result<bool> {
300    match result {
301        Ok(_) => Ok(true),
302        Err(Error::ConnectionClosed) => Ok(false),
303        Err(Error::Io(e)) => Err(e),
304        Err(e) => Err(io::Error::new(io::ErrorKind::Other, e))
305    }
306}