rocket_ws/
websocket.rs

1use std::io;
2
3use rocket::data::{IoHandler, IoStream};
4use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
5use rocket::response::{self, Responder, Response};
6use rocket::request::{FromRequest, Request, Outcome};
7use rocket::http::Status;
8
9use crate::{Config, Message};
10use crate::stream::DuplexStream;
11use crate::result::{Result, Error};
12
13/// A request guard identifying WebSocket requests. Converts into a [`Channel`]
14/// or [`MessageStream`].
15///
16/// For example usage, see the [crate docs](crate#usage).
17///
18/// ## Details
19///
20/// This is the entrypoint to the library. Every WebSocket response _must_
21/// initiate via the `WebSocket` request guard. The guard identifies valid
22/// WebSocket connection requests and, if the request is valid, succeeds to be
23/// converted into a streaming WebSocket response via
24/// [`Stream!`](crate::Stream!), [`WebSocket::channel()`], or
25/// [`WebSocket::stream()`]. The connection can be configured via
26/// [`WebSocket::config()`]; see [`Config`] for details on configuring a
27/// connection.
28///
29/// ### Forwarding
30///
31/// If the incoming request is not a valid WebSocket request, the guard
32/// forwards with a status of `BadRequest`. The guard never fails.
33pub struct WebSocket {
34    config: Config,
35    key: String,
36}
37
38impl WebSocket {
39    /// Change the default connection configuration to `config`.
40    ///
41    /// # Example
42    ///
43    /// ```rust
44    /// # use rocket::get;
45    /// # use rocket_ws as ws;
46    /// #
47    /// #[get("/echo")]
48    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
49    ///     let ws = ws.config(ws::Config {
50    ///         max_send_queue: Some(5),
51    ///         ..Default::default()
52    ///     });
53    ///
54    ///     ws::Stream! { ws =>
55    ///         for await message in ws {
56    ///             yield message?;
57    ///         }
58    ///     }
59    /// }
60    /// ```
61    pub fn config(mut self, config: Config) -> Self {
62        self.config = config;
63        self
64    }
65
66    /// Create a read/write channel to the client and call `handler` with it.
67    ///
68    /// This method takes a `FnOnce`, `handler`, that consumes a read/write
69    /// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`]
70    /// for details on how to make use of the channel.
71    ///
72    /// The `handler` must return a `Box`ed and `Pin`ned future: calling
73    /// [`Box::pin()`] with a future does just this as is the preferred
74    /// mechanism to create a `Box<Pin<Future>>`. The future must return a
75    /// [`Result<()>`](crate::result::Result). The WebSocket connection is
76    /// closed successfully if the future returns `Ok` and with an error if
77    /// the future returns `Err`.
78    ///
79    /// # Lifetimes
80    ///
81    /// The `Channel` may borrow from the request. If it does, the lifetime
82    /// should be specified as something other than `'static`. Otherwise, the
83    /// `'static` lifetime should be used.
84    ///
85    /// # Example
86    ///
87    /// ```rust
88    /// # use rocket::get;
89    /// # use rocket_ws as ws;
90    /// use rocket::futures::{SinkExt, StreamExt};
91    ///
92    /// #[get("/hello/<name>")]
93    /// fn hello(ws: ws::WebSocket, name: &str) -> ws::Channel<'_> {
94    ///     ws.channel(move |mut stream| Box::pin(async move {
95    ///         let message = format!("Hello, {}!", name);
96    ///         let _ = stream.send(message.into()).await;
97    ///         Ok(())
98    ///     }))
99    /// }
100    ///
101    /// #[get("/echo")]
102    /// fn echo(ws: ws::WebSocket) -> ws::Channel<'static> {
103    ///     ws.channel(move |mut stream| Box::pin(async move {
104    ///         while let Some(message) = stream.next().await {
105    ///             let _ = stream.send(message?).await;
106    ///         }
107    ///
108    ///         Ok(())
109    ///     }))
110    /// }
111    /// ```
112    pub fn channel<'r, F>(self, handler: F) -> Channel<'r>
113        where F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r
114    {
115        Channel { ws: self, handler: Box::new(handler), }
116    }
117
118    /// Create a stream that consumes client [`Message`]s and emits its own.
119    ///
120    /// This method takes a `FnOnce` `stream` that consumes a read-only stream
121    /// and returns a stream of [`Message`]s. While the returned stream can be
122    /// constructed in any manner, the [`Stream!`](crate::Stream!) macro is the
123    /// preferred method. In any case, the stream must be `Send`.
124    ///
125    /// The returned stream must emit items of type `Result<Message>`. Items
126    /// that are `Ok(Message)` are sent to the client while items of type
127    /// `Err(Error)` result in the connection being closed and the remainder of
128    /// the stream discarded.
129    ///
130    /// # Example
131    ///
132    /// ```rust
133    /// # use rocket::get;
134    /// # use rocket_ws as ws;
135    ///
136    /// // Use `Stream!`, which internally calls `WebSocket::stream()`.
137    /// #[get("/echo?stream")]
138    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
139    ///     ws::Stream! { ws =>
140    ///         for await message in ws {
141    ///             yield message?;
142    ///         }
143    ///     }
144    /// }
145    ///
146    /// // Use a raw stream.
147    /// #[get("/echo?compose")]
148    /// fn echo_compose(ws: ws::WebSocket) -> ws::Stream!['static] {
149    ///     ws.stream(|io| io)
150    /// }
151    /// ```
152    pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
153        where F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
154              S: futures::Stream<Item = Result<Message>> + Send + 'r
155    {
156        MessageStream { ws: self, handler: Box::new(stream), }
157    }
158
159    /// Returns the server's fully computed and encoded WebSocket handshake
160    /// accept key.
161    ///
162    /// > The server takes the value of the `Sec-WebSocket-Key` sent in the
163    /// > handshake request, appends `258EAFA5-E914-47DA-95CA-C5AB0DC85B11`,
164    /// > SHA-1 of the new value, and is then base64 encoded.
165    /// >
166    /// > -- [`Sec-WebSocket-Accept`]
167    ///
168    /// This is the value returned via the [`Sec-WebSocket-Accept`] header
169    /// during the acceptance response.
170    ///
171    /// [`Sec-WebSocket-Accept`]:
172    /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Accept
173    ///
174    /// # Example
175    ///
176    /// ```rust
177    /// # use rocket::get;
178    /// # use rocket_ws as ws;
179    /// #
180    /// #[get("/echo")]
181    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
182    ///     let accept_key = ws.accept_key();
183    ///     ws.stream(|io| io)
184    /// }
185    /// ```
186    pub fn accept_key(&self) -> &str {
187        &self.key
188    }
189
190}
191
192/// A streaming channel, returned by [`WebSocket::channel()`].
193///
194/// `Channel` has no methods or functionality beyond its trait implementations.
195pub struct Channel<'r> {
196    ws: WebSocket,
197    handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
198}
199
200/// A [`Stream`](futures::Stream) of [`Message`]s, returned by
201/// [`WebSocket::stream()`], used via [`Stream!`].
202///
203/// This type should not be used directly. Instead, it is used via the
204/// [`Stream!`] macro, which expands to both the type itself and an expression
205/// which evaluates to this type. See [`Stream!`] for details.
206///
207/// [`Stream!`]: crate::Stream!
208// TODO: Get rid of this or `Channel` via a single `enum`.
209pub struct MessageStream<'r, S> {
210    ws: WebSocket,
211    handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>
212}
213
214#[rocket::async_trait]
215impl<'r> FromRequest<'r> for WebSocket {
216    type Error = std::convert::Infallible;
217
218    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
219        use crate::tungstenite::handshake::derive_accept_key;
220        use rocket::http::uncased::eq;
221
222        let headers = req.headers();
223        let is_upgrade = headers.get("Connection")
224            .any(|h| h.split(',').any(|v| eq(v.trim(), "upgrade")));
225
226        let is_ws = headers.get("Upgrade")
227            .any(|h| h.split(',').any(|v| eq(v.trim(), "websocket")));
228
229        let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
230        let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
231        match key {
232            Some(key) if is_upgrade && is_ws && is_13 => {
233                Outcome::Success(WebSocket { key, config: Config::default() })
234            },
235            Some(_) | None => Outcome::Forward(Status::BadRequest)
236        }
237    }
238}
239
240impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
241    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
242        Response::build()
243            .raw_header("Sec-Websocket-Version", "13")
244            .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
245            .upgrade("websocket", self)
246            .ok()
247    }
248}
249
250impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
251    where S: futures::Stream<Item = Result<Message>> + Send + 'o
252{
253    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
254        Response::build()
255            .raw_header("Sec-Websocket-Version", "13")
256            .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
257            .upgrade("websocket", self)
258            .ok()
259    }
260}
261
262#[rocket::async_trait]
263impl IoHandler for Channel<'_> {
264    async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
265        let stream = DuplexStream::new(io, self.ws.config).await;
266        let result = (self.handler)(stream).await;
267        handle_result(result).map(|_| ())
268    }
269}
270
271#[rocket::async_trait]
272impl<'r, S> IoHandler for MessageStream<'r, S>
273    where S: futures::Stream<Item = Result<Message>> + Send + 'r
274{
275    async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
276        let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
277        let stream = (self.handler)(source);
278        rocket::tokio::pin!(stream);
279        while let Some(msg) = stream.next().await {
280            let result = match msg {
281                Ok(msg) if msg.is_close() => return Ok(()),
282                Ok(msg) => sink.send(msg).await,
283                Err(e) => Err(e)
284            };
285
286            if !handle_result(result)? {
287                return Ok(());
288            }
289        }
290
291        Ok(())
292    }
293}
294
295/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing
296/// has terminated without error, and `Err(e)` if an error has occurred.
297fn handle_result(result: Result<()>) -> io::Result<bool> {
298    match result {
299        Ok(_) => Ok(true),
300        Err(Error::ConnectionClosed) => Ok(false),
301        Err(Error::Io(e)) => Err(e),
302        Err(e) => Err(io::Error::new(io::ErrorKind::Other, e))
303    }
304}