rocket_ws/websocket.rs
1use std::io;
2use std::pin::Pin;
3
4use rocket::data::{IoHandler, IoStream};
5use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
6use rocket::response::{self, Responder, Response};
7use rocket::request::{FromRequest, Request, Outcome};
8use rocket::http::Status;
9
10use crate::{Config, Message};
11use crate::stream::DuplexStream;
12use crate::result::{Result, Error};
13
/// A request guard identifying WebSocket requests. Converts into a [`Channel`]
/// or [`MessageStream`].
///
/// For example usage, see the [crate docs](crate#usage).
///
/// ## Details
///
/// This is the entrypoint to the library. Every WebSocket response _must_
/// initiate via the `WebSocket` request guard. The guard identifies valid
/// WebSocket connection requests and, if the request is valid, succeeds so
/// that it can be converted into a streaming WebSocket response via
/// [`Stream!`](crate::Stream!), [`WebSocket::channel()`], or
/// [`WebSocket::stream()`]. The connection can be configured via
/// [`WebSocket::config()`]; see [`Config`] for details on configuring a
/// connection.
///
/// ### Forwarding
///
/// If the incoming request is not a valid WebSocket request, the guard
/// forwards with a status of `BadRequest`. The guard never fails.
pub struct WebSocket {
    // Connection configuration; starts as `Config::default()` and can be
    // replaced via `WebSocket::config()`.
    config: Config,
    // The accept key derived from the request's `Sec-WebSocket-Key` header;
    // sent back in the `Sec-WebSocket-Accept` response header.
    key: String,
}
38
39impl WebSocket {
40 fn new(key: String) -> WebSocket {
41 WebSocket { config: Config::default(), key }
42 }
43
44 /// Change the default connection configuration to `config`.
45 ///
46 /// # Example
47 ///
48 /// ```rust
49 /// # use rocket::get;
50 /// # use rocket_ws as ws;
51 /// #
52 /// #[get("/echo")]
53 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
54 /// let ws = ws.config(ws::Config {
55 /// max_send_queue: Some(5),
56 /// ..Default::default()
57 /// });
58 ///
59 /// ws::Stream! { ws =>
60 /// for await message in ws {
61 /// yield message?;
62 /// }
63 /// }
64 /// }
65 /// ```
66 pub fn config(mut self, config: Config) -> Self {
67 self.config = config;
68 self
69 }
70
71 /// Create a read/write channel to the client and call `handler` with it.
72 ///
73 /// This method takes a `FnOnce`, `handler`, that consumes a read/write
74 /// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`]
75 /// for details on how to make use of the channel.
76 ///
77 /// The `handler` must return a `Box`ed and `Pin`ned future: calling
78 /// [`Box::pin()`] with a future does just this as is the preferred
79 /// mechanism to create a `Box<Pin<Future>>`. The future must return a
80 /// [`Result<()>`](crate::result::Result). The WebSocket connection is
81 /// closed successfully if the future returns `Ok` and with an error if
82 /// the future returns `Err`.
83 ///
84 /// # Lifetimes
85 ///
86 /// The `Channel` may borrow from the request. If it does, the lifetime
87 /// should be specified as something other than `'static`. Otherwise, the
88 /// `'static` lifetime should be used.
89 ///
90 /// # Example
91 ///
92 /// ```rust
93 /// # use rocket::get;
94 /// # use rocket_ws as ws;
95 /// use rocket::futures::{SinkExt, StreamExt};
96 ///
97 /// #[get("/hello/<name>")]
98 /// fn hello(ws: ws::WebSocket, name: &str) -> ws::Channel<'_> {
99 /// ws.channel(move |mut stream| Box::pin(async move {
100 /// let message = format!("Hello, {}!", name);
101 /// let _ = stream.send(message.into()).await;
102 /// Ok(())
103 /// }))
104 /// }
105 ///
106 /// #[get("/echo")]
107 /// fn echo(ws: ws::WebSocket) -> ws::Channel<'static> {
108 /// ws.channel(move |mut stream| Box::pin(async move {
109 /// while let Some(message) = stream.next().await {
110 /// let _ = stream.send(message?).await;
111 /// }
112 ///
113 /// Ok(())
114 /// }))
115 /// }
116 /// ```
117 pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r>
118 where F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r
119 {
120 Channel { ws: self, handler: Box::new(handler), }
121 }
122
123 /// Create a stream that consumes client [`Message`]s and emits its own.
124 ///
125 /// This method takes a `FnOnce` `stream` that consumes a read-only stream
126 /// and returns a stream of [`Message`]s. While the returned stream can be
127 /// constructed in any manner, the [`Stream!`](crate::Stream!) macro is the
128 /// preferred method. In any case, the stream must be `Send`.
129 ///
130 /// The returned stream must emit items of type `Result<Message>`. Items
131 /// that are `Ok(Message)` are sent to the client while items of type
132 /// `Err(Error)` result in the connection being closed and the remainder of
133 /// the stream discarded.
134 ///
135 /// # Example
136 ///
137 /// ```rust
138 /// # use rocket::get;
139 /// # use rocket_ws as ws;
140 ///
141 /// // Use `Stream!`, which internally calls `WebSocket::stream()`.
142 /// #[get("/echo?stream")]
143 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
144 /// ws::Stream! { ws =>
145 /// for await message in ws {
146 /// yield message?;
147 /// }
148 /// }
149 /// }
150 ///
151 /// // Use a raw stream.
152 /// #[get("/echo?compose")]
153 /// fn echo_compose(ws: ws::WebSocket) -> ws::Stream!['static] {
154 /// ws.stream(|io| io)
155 /// }
156 /// ```
157 pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
158 where F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
159 S: futures::Stream<Item = Result<Message>> + Send + 'r
160 {
161 MessageStream { ws: self, handler: Box::new(stream), }
162 }
163
164 /// Returns the server's fully computed and encoded WebSocket handshake
165 /// accept key.
166 ///
167 /// > The server takes the value of the `Sec-WebSocket-Key` sent in the
168 /// > handshake request, appends `258EAFA5-E914-47DA-95CA-C5AB0DC85B11`,
169 /// > SHA-1 of the new value, and is then base64 encoded.
170 /// >
171 /// > -- [`Sec-WebSocket-Accept`]
172 ///
173 /// This is the value returned via the [`Sec-WebSocket-Accept`] header
174 /// during the acceptance response.
175 ///
176 /// [`Sec-WebSocket-Accept`]:
177 /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Accept
178 ///
179 /// # Example
180 ///
181 /// ```rust
182 /// # use rocket::get;
183 /// # use rocket_ws as ws;
184 /// #
185 /// #[get("/echo")]
186 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
187 /// let accept_key = ws.accept_key();
188 /// ws.stream(|io| io)
189 /// }
190 /// ```
191 pub fn accept_key(&self) -> &str {
192 &self.key
193 }
194
195}
196
/// A streaming channel, returned by [`WebSocket::channel()`].
///
/// `Channel` has no methods or functionality beyond its trait implementations.
pub struct Channel<'r> {
    // The originating guard; supplies the connection configuration and the
    // accept key written into the handshake response.
    ws: WebSocket,
    // The user callback given to `WebSocket::channel()`. It is called once
    // with the `DuplexStream` built in this type's `IoHandler::io()`.
    handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
}
204
/// A [`Stream`](futures::Stream) of [`Message`]s, returned by
/// [`WebSocket::stream()`], used via [`Stream!`].
///
/// This type should not be used directly. Instead, it is used via the
/// [`Stream!`] macro, which expands to both the type itself and an expression
/// which evaluates to this type. See [`Stream!`] for details.
///
/// [`Stream!`]: crate::Stream!
// TODO: Get rid of this or `Channel` via a single `enum`.
pub struct MessageStream<'r, S> {
    // The originating guard; supplies the connection configuration and the
    // accept key written into the handshake response.
    ws: WebSocket,
    // The user callback given to `WebSocket::stream()`. It is called once with
    // the read half of the connection and returns the stream of messages to
    // send to the client.
    handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>
}
218
219#[rocket::async_trait]
220impl<'r> FromRequest<'r> for WebSocket {
221 type Error = std::convert::Infallible;
222
223 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
224 use crate::tungstenite::handshake::derive_accept_key;
225 use rocket::http::uncased::eq;
226
227 let headers = req.headers();
228 let is_upgrade = headers.get("Connection")
229 .any(|h| h.split(',').any(|v| eq(v.trim(), "upgrade")));
230
231 let is_ws = headers.get("Upgrade")
232 .any(|h| h.split(',').any(|v| eq(v.trim(), "websocket")));
233
234 let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
235 let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
236 match key {
237 Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket::new(key)),
238 Some(_) | None => Outcome::Forward(Status::BadRequest)
239 }
240 }
241}
242
243impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
244 fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
245 Response::build()
246 .raw_header("Sec-Websocket-Version", "13")
247 .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
248 .upgrade("websocket", self)
249 .ok()
250 }
251}
252
253impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
254 where S: futures::Stream<Item = Result<Message>> + Send + 'o
255{
256 fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
257 Response::build()
258 .raw_header("Sec-Websocket-Version", "13")
259 .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
260 .upgrade("websocket", self)
261 .ok()
262 }
263}
264
265#[rocket::async_trait]
266impl IoHandler for Channel<'_> {
267 async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
268 let channel = Pin::into_inner(self);
269 let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await;
270 handle_result(result).map(|_| ())
271 }
272}
273
274#[rocket::async_trait]
275impl<'r, S> IoHandler for MessageStream<'r, S>
276 where S: futures::Stream<Item = Result<Message>> + Send + 'r
277{
278 async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
279 let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
280 let stream = (Pin::into_inner(self).handler)(source);
281 rocket::tokio::pin!(stream);
282 while let Some(msg) = stream.next().await {
283 let result = match msg {
284 Ok(msg) => sink.send(msg).await,
285 Err(e) => Err(e)
286 };
287
288 if !handle_result(result)? {
289 return Ok(());
290 }
291 }
292
293 Ok(())
294 }
295}
296
297/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing
298/// has terminated without error, and `Err(e)` if an error has occurred.
299fn handle_result(result: Result<()>) -> io::Result<bool> {
300 match result {
301 Ok(_) => Ok(true),
302 Err(Error::ConnectionClosed) => Ok(false),
303 Err(Error::Io(e)) => Err(e),
304 Err(e) => Err(io::Error::new(io::ErrorKind::Other, e))
305 }
306}