rocket_ws/websocket.rs
1use std::io;
2
3use rocket::data::{IoHandler, IoStream};
4use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
5use rocket::response::{self, Responder, Response};
6use rocket::request::{FromRequest, Request, Outcome};
7use rocket::http::Status;
8
9use crate::{Config, Message};
10use crate::stream::DuplexStream;
11use crate::result::{Result, Error};
12
13/// A request guard identifying WebSocket requests. Converts into a [`Channel`]
14/// or [`MessageStream`].
15///
16/// For example usage, see the [crate docs](crate#usage).
17///
18/// ## Details
19///
20/// This is the entrypoint to the library. Every WebSocket response _must_
21/// initiate via the `WebSocket` request guard. The guard identifies valid
22/// WebSocket connection requests and, if the request is valid, succeeds to be
23/// converted into a streaming WebSocket response via
24/// [`Stream!`](crate::Stream!), [`WebSocket::channel()`], or
25/// [`WebSocket::stream()`]. The connection can be configured via
26/// [`WebSocket::config()`]; see [`Config`] for details on configuring a
27/// connection.
28///
29/// ### Forwarding
30///
31/// If the incoming request is not a valid WebSocket request, the guard
32/// forwards with a status of `BadRequest`. The guard never fails.
33pub struct WebSocket {
34 config: Config,
35 key: String,
36}
37
38impl WebSocket {
39 /// Change the default connection configuration to `config`.
40 ///
41 /// # Example
42 ///
43 /// ```rust
44 /// # use rocket::get;
45 /// # use rocket_ws as ws;
46 /// #
47 /// #[get("/echo")]
48 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
49 /// let ws = ws.config(ws::Config {
50 /// max_send_queue: Some(5),
51 /// ..Default::default()
52 /// });
53 ///
54 /// ws::Stream! { ws =>
55 /// for await message in ws {
56 /// yield message?;
57 /// }
58 /// }
59 /// }
60 /// ```
61 pub fn config(mut self, config: Config) -> Self {
62 self.config = config;
63 self
64 }
65
66 /// Create a read/write channel to the client and call `handler` with it.
67 ///
68 /// This method takes a `FnOnce`, `handler`, that consumes a read/write
69 /// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`]
70 /// for details on how to make use of the channel.
71 ///
72 /// The `handler` must return a `Box`ed and `Pin`ned future: calling
73 /// [`Box::pin()`] with a future does just this as is the preferred
74 /// mechanism to create a `Box<Pin<Future>>`. The future must return a
75 /// [`Result<()>`](crate::result::Result). The WebSocket connection is
76 /// closed successfully if the future returns `Ok` and with an error if
77 /// the future returns `Err`.
78 ///
79 /// # Lifetimes
80 ///
81 /// The `Channel` may borrow from the request. If it does, the lifetime
82 /// should be specified as something other than `'static`. Otherwise, the
83 /// `'static` lifetime should be used.
84 ///
85 /// # Example
86 ///
87 /// ```rust
88 /// # use rocket::get;
89 /// # use rocket_ws as ws;
90 /// use rocket::futures::{SinkExt, StreamExt};
91 ///
92 /// #[get("/hello/<name>")]
93 /// fn hello(ws: ws::WebSocket, name: &str) -> ws::Channel<'_> {
94 /// ws.channel(move |mut stream| Box::pin(async move {
95 /// let message = format!("Hello, {}!", name);
96 /// let _ = stream.send(message.into()).await;
97 /// Ok(())
98 /// }))
99 /// }
100 ///
101 /// #[get("/echo")]
102 /// fn echo(ws: ws::WebSocket) -> ws::Channel<'static> {
103 /// ws.channel(move |mut stream| Box::pin(async move {
104 /// while let Some(message) = stream.next().await {
105 /// let _ = stream.send(message?).await;
106 /// }
107 ///
108 /// Ok(())
109 /// }))
110 /// }
111 /// ```
112 pub fn channel<'r, F>(self, handler: F) -> Channel<'r>
113 where F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r
114 {
115 Channel { ws: self, handler: Box::new(handler), }
116 }
117
118 /// Create a stream that consumes client [`Message`]s and emits its own.
119 ///
120 /// This method takes a `FnOnce` `stream` that consumes a read-only stream
121 /// and returns a stream of [`Message`]s. While the returned stream can be
122 /// constructed in any manner, the [`Stream!`](crate::Stream!) macro is the
123 /// preferred method. In any case, the stream must be `Send`.
124 ///
125 /// The returned stream must emit items of type `Result<Message>`. Items
126 /// that are `Ok(Message)` are sent to the client while items of type
127 /// `Err(Error)` result in the connection being closed and the remainder of
128 /// the stream discarded.
129 ///
130 /// # Example
131 ///
132 /// ```rust
133 /// # use rocket::get;
134 /// # use rocket_ws as ws;
135 ///
136 /// // Use `Stream!`, which internally calls `WebSocket::stream()`.
137 /// #[get("/echo?stream")]
138 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
139 /// ws::Stream! { ws =>
140 /// for await message in ws {
141 /// yield message?;
142 /// }
143 /// }
144 /// }
145 ///
146 /// // Use a raw stream.
147 /// #[get("/echo?compose")]
148 /// fn echo_compose(ws: ws::WebSocket) -> ws::Stream!['static] {
149 /// ws.stream(|io| io)
150 /// }
151 /// ```
152 pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
153 where F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
154 S: futures::Stream<Item = Result<Message>> + Send + 'r
155 {
156 MessageStream { ws: self, handler: Box::new(stream), }
157 }
158
159 /// Returns the server's fully computed and encoded WebSocket handshake
160 /// accept key.
161 ///
162 /// > The server takes the value of the `Sec-WebSocket-Key` sent in the
163 /// > handshake request, appends `258EAFA5-E914-47DA-95CA-C5AB0DC85B11`,
164 /// > SHA-1 of the new value, and is then base64 encoded.
165 /// >
166 /// > -- [`Sec-WebSocket-Accept`]
167 ///
168 /// This is the value returned via the [`Sec-WebSocket-Accept`] header
169 /// during the acceptance response.
170 ///
171 /// [`Sec-WebSocket-Accept`]:
172 /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Accept
173 ///
174 /// # Example
175 ///
176 /// ```rust
177 /// # use rocket::get;
178 /// # use rocket_ws as ws;
179 /// #
180 /// #[get("/echo")]
181 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
182 /// let accept_key = ws.accept_key();
183 /// ws.stream(|io| io)
184 /// }
185 /// ```
186 pub fn accept_key(&self) -> &str {
187 &self.key
188 }
189
190}
191
192/// A streaming channel, returned by [`WebSocket::channel()`].
193///
194/// `Channel` has no methods or functionality beyond its trait implementations.
195pub struct Channel<'r> {
196 ws: WebSocket,
197 handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
198}
199
200/// A [`Stream`](futures::Stream) of [`Message`]s, returned by
201/// [`WebSocket::stream()`], used via [`Stream!`].
202///
203/// This type should not be used directly. Instead, it is used via the
204/// [`Stream!`] macro, which expands to both the type itself and an expression
205/// which evaluates to this type. See [`Stream!`] for details.
206///
207/// [`Stream!`]: crate::Stream!
208// TODO: Get rid of this or `Channel` via a single `enum`.
209pub struct MessageStream<'r, S> {
210 ws: WebSocket,
211 handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>
212}
213
214#[rocket::async_trait]
215impl<'r> FromRequest<'r> for WebSocket {
216 type Error = std::convert::Infallible;
217
218 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
219 use crate::tungstenite::handshake::derive_accept_key;
220 use rocket::http::uncased::eq;
221
222 let headers = req.headers();
223 let is_upgrade = headers.get("Connection")
224 .any(|h| h.split(',').any(|v| eq(v.trim(), "upgrade")));
225
226 let is_ws = headers.get("Upgrade")
227 .any(|h| h.split(',').any(|v| eq(v.trim(), "websocket")));
228
229 let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
230 let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
231 match key {
232 Some(key) if is_upgrade && is_ws && is_13 => {
233 Outcome::Success(WebSocket { key, config: Config::default() })
234 },
235 Some(_) | None => Outcome::Forward(Status::BadRequest)
236 }
237 }
238}
239
240impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
241 fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
242 Response::build()
243 .raw_header("Sec-Websocket-Version", "13")
244 .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
245 .upgrade("websocket", self)
246 .ok()
247 }
248}
249
250impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
251 where S: futures::Stream<Item = Result<Message>> + Send + 'o
252{
253 fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
254 Response::build()
255 .raw_header("Sec-Websocket-Version", "13")
256 .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
257 .upgrade("websocket", self)
258 .ok()
259 }
260}
261
262#[rocket::async_trait]
263impl IoHandler for Channel<'_> {
264 async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
265 let stream = DuplexStream::new(io, self.ws.config).await;
266 let result = (self.handler)(stream).await;
267 handle_result(result).map(|_| ())
268 }
269}
270
271#[rocket::async_trait]
272impl<'r, S> IoHandler for MessageStream<'r, S>
273 where S: futures::Stream<Item = Result<Message>> + Send + 'r
274{
275 async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
276 let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
277 let stream = (self.handler)(source);
278 rocket::tokio::pin!(stream);
279 while let Some(msg) = stream.next().await {
280 let result = match msg {
281 Ok(msg) if msg.is_close() => return Ok(()),
282 Ok(msg) => sink.send(msg).await,
283 Err(e) => Err(e)
284 };
285
286 if !handle_result(result)? {
287 return Ok(());
288 }
289 }
290
291 Ok(())
292 }
293}
294
295/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing
296/// has terminated without error, and `Err(e)` if an error has occurred.
297fn handle_result(result: Result<()>) -> io::Result<bool> {
298 match result {
299 Ok(_) => Ok(true),
300 Err(Error::ConnectionClosed) => Ok(false),
301 Err(Error::Io(e)) => Err(e),
302 Err(e) => Err(io::Error::new(io::ErrorKind::Other, e))
303 }
304}