1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
use std::io;
use std::pin::Pin;

use rocket::data::{IoHandler, IoStream};
use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
use rocket::response::{self, Responder, Response};
use rocket::request::{FromRequest, Request, Outcome};
use rocket::http::Status;

use crate::{Config, Message};
use crate::stream::DuplexStream;
use crate::result::{Result, Error};

/// Request guard that recognizes WebSocket upgrade requests and can be turned
/// into a [`Channel`] or a [`MessageStream`].
///
/// For example usage, see the [crate docs](crate#usage).
///
/// ## Details
///
/// `WebSocket` is the entrypoint to the library: every WebSocket response
/// _must_ begin with this request guard. The guard recognizes valid WebSocket
/// connection requests and, when the request is valid, can be converted into
/// a streaming WebSocket response via [`Stream!`](crate::Stream!),
/// [`WebSocket::channel()`], or [`WebSocket::stream()`]. Use
/// [`WebSocket::config()`] to configure the connection; see [`Config`] for
/// the available options.
///
/// ### Forwarding
///
/// When the incoming request is not a valid WebSocket request, the guard
/// forwards with a status of `BadRequest`. The guard never fails.
pub struct WebSocket {
    /// Settings handed to `DuplexStream::new` once the connection is upgraded.
    config: Config,
    /// Derived `Sec-WebSocket-Accept` value sent back in the handshake response.
    key: String,
}

impl WebSocket {
    fn new(key: String) -> WebSocket {
        WebSocket { config: Config::default(), key }
    }

    /// Change the default connection configuration to `config`.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rocket::get;
    /// # use rocket_ws as ws;
    /// #
    /// #[get("/echo")]
    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
    ///     let ws = ws.config(ws::Config {
    ///         max_send_queue: Some(5),
    ///         ..Default::default()
    ///     });
    ///
    ///     ws::Stream! { ws =>
    ///         for await message in ws {
    ///             yield message?;
    ///         }
    ///     }
    /// }
    /// ```
    pub fn config(mut self, config: Config) -> Self {
        self.config = config;
        self
    }

    /// Create a read/write channel to the client and call `handler` with it.
    ///
    /// This method takes a `FnOnce`, `handler`, that consumes a read/write
    /// WebSocket channel, [`DuplexStream`], to the client. See [`DuplexStream`]
    /// for details on how to make use of the channel.
    ///
    /// The `handler` must return a `Box`ed and `Pin`ned future. Calling
    /// [`Box::pin()`] with a future does exactly this and is the preferred
    /// way to create a `Pin<Box<Future>>`. The future must return a
    /// [`Result<()>`](crate::result::Result). The WebSocket connection is
    /// closed successfully if the future returns `Ok`, and with an error if
    /// the future returns `Err`.
    ///
    /// # Lifetimes
    ///
    /// The `Channel` may borrow from the request. If it does, the lifetime
    /// should be specified as something other than `'static`. Otherwise, the
    /// `'static` lifetime should be used.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rocket::get;
    /// # use rocket_ws as ws;
    /// use rocket::futures::{SinkExt, StreamExt};
    ///
    /// #[get("/hello/<name>")]
    /// fn hello(ws: ws::WebSocket, name: &str) -> ws::Channel<'_> {
    ///     ws.channel(move |mut stream| Box::pin(async move {
    ///         let message = format!("Hello, {}!", name);
    ///         let _ = stream.send(message.into()).await;
    ///         Ok(())
    ///     }))
    /// }
    ///
    /// #[get("/echo")]
    /// fn echo(ws: ws::WebSocket) -> ws::Channel<'static> {
    ///     ws.channel(move |mut stream| Box::pin(async move {
    ///         while let Some(message) = stream.next().await {
    ///             let _ = stream.send(message?).await;
    ///         }
    ///
    ///         Ok(())
    ///     }))
    /// }
    /// ```
    pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r>
        where F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r
    {
        Channel { ws: self, handler: Box::new(handler), }
    }

    /// Create a stream that consumes client [`Message`]s and emits its own.
    ///
    /// This method takes a `FnOnce`, `stream`, that consumes a read-only stream
    /// and returns a stream of [`Message`]s. While the returned stream can be
    /// constructed in any manner, the [`Stream!`](crate::Stream!) macro is the
    /// preferred method. In any case, the stream must be `Send`.
    ///
    /// The returned stream must emit items of type `Result<Message>`. Items
    /// that are `Ok(Message)` are sent to the client, while an item of type
    /// `Err(Error)` results in the connection being closed and the remainder
    /// of the stream being discarded.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rocket::get;
    /// # use rocket_ws as ws;
    ///
    /// // Use `Stream!`, which internally calls `WebSocket::stream()`.
    /// #[get("/echo?stream")]
    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
    ///     ws::Stream! { ws =>
    ///         for await message in ws {
    ///             yield message?;
    ///         }
    ///     }
    /// }
    ///
    /// // Use a raw stream.
    /// #[get("/echo?compose")]
    /// fn echo_compose(ws: ws::WebSocket) -> ws::Stream!['static] {
    ///     ws.stream(|io| io)
    /// }
    /// ```
    pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
        where F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
              S: futures::Stream<Item = Result<Message>> + Send + 'r
    {
        MessageStream { ws: self, handler: Box::new(stream), }
    }

    /// Returns the server's fully computed and encoded WebSocket handshake
    /// accept key.
    ///
    /// > The server takes the value of the `Sec-WebSocket-Key` sent in the
    /// > handshake request, appends `258EAFA5-E914-47DA-95CA-C5AB0DC85B11`,
    /// > takes SHA-1 of the new value, and is then base64 encoded.
    /// >
    /// > -- [`Sec-WebSocket-Accept`]
    ///
    /// This is the value returned via the [`Sec-WebSocket-Accept`] header
    /// during the acceptance response.
    ///
    /// [`Sec-WebSocket-Accept`]:
    /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Accept
    ///
    /// # Example
    ///
    /// ```rust
    /// # use rocket::get;
    /// # use rocket_ws as ws;
    /// #
    /// #[get("/echo")]
    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
    ///     let accept_key = ws.accept_key();
    ///     ws.stream(|io| io)
    /// }
    /// ```
    pub fn accept_key(&self) -> &str {
        &self.key
    }
}

/// A streaming channel, returned by [`WebSocket::channel()`].
///
/// `Channel` provides no methods or functionality of its own; everything it
/// does comes from its trait implementations.
pub struct Channel<'r> {
    /// The accepted guard: connection config plus the computed accept key.
    ws: WebSocket,
    /// User callback that receives the upgraded duplex connection.
    handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
}

/// A [`Stream`](futures::Stream) of [`Message`]s, returned by
/// [`WebSocket::stream()`] and used via [`Stream!`].
///
/// Avoid naming this type directly. It is meant to be used through the
/// [`Stream!`] macro, which expands both to the type itself and to an
/// expression that evaluates to it. See [`Stream!`] for details.
///
/// [`Stream!`]: crate::Stream!
// TODO: Get rid of this or `Channel` via a single `enum`.
pub struct MessageStream<'r, S> {
    /// The accepted guard: connection config plus the computed accept key.
    ws: WebSocket,
    /// User callback turning the client's read half into the outgoing stream `S`.
    handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for WebSocket {
    type Error = std::convert::Infallible;

    /// Succeeds only for a version-13 WebSocket upgrade request carrying a
    /// `Sec-WebSocket-Key`; every other request forwards with `BadRequest`.
    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        use crate::tungstenite::handshake::derive_accept_key;
        use rocket::http::uncased::eq;

        // Does any comma-separated entry of any `name` header match `token`
        // (case-insensitively, ignoring surrounding whitespace)?
        let has_token = |name: &str, token: &str| {
            req.headers()
                .get(name)
                .flat_map(|value| value.split(','))
                .any(|entry| eq(entry.trim(), token))
        };

        let upgrade_requested = has_token("Connection", "upgrade");
        let wants_websocket = has_token("Upgrade", "websocket");
        let version_ok = req.headers().get_one("Sec-WebSocket-Version") == Some("13");

        match req.headers().get_one("Sec-WebSocket-Key") {
            Some(client_key) if upgrade_requested && wants_websocket && version_ok => {
                Outcome::Success(WebSocket::new(derive_accept_key(client_key.as_bytes())))
            }
            _ => Outcome::Forward(Status::BadRequest),
        }
    }
}

impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
    /// Answers the handshake with the accept key and hands the channel to
    /// Rocket as the `websocket` upgrade handler.
    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
        // Copy the key out first: `self` is moved into the upgrade below.
        let accept = self.ws.key.clone();

        let mut builder = Response::build();
        builder.raw_header("Sec-Websocket-Version", "13");
        builder.raw_header("Sec-WebSocket-Accept", accept);
        builder.upgrade("websocket", self);
        builder.ok()
    }
}

impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
    where S: futures::Stream<Item = Result<Message>> + Send + 'o
{
    /// Answers the handshake with the accept key and hands the message stream
    /// to Rocket as the `websocket` upgrade handler.
    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
        // Copy the key out first: `self` is moved into the upgrade below.
        let accept = self.ws.key.clone();

        let mut builder = Response::build();
        builder.raw_header("Sec-Websocket-Version", "13");
        builder.raw_header("Sec-WebSocket-Accept", accept);
        builder.upgrade("websocket", self);
        builder.ok()
    }
}

#[rocket::async_trait]
impl IoHandler for Channel<'_> {
    async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
        let channel = Pin::into_inner(self);
        let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await;
        handle_result(result).map(|_| ())
    }
}

#[rocket::async_trait]
impl<'r, S> IoHandler for MessageStream<'r, S>
    where S: futures::Stream<Item = Result<Message>> + Send + 'r
{
    /// Splits the upgraded connection, feeds the read half to the user's
    /// stream factory, and forwards every emitted message to the write half.
    async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
        let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
        let outgoing = (Pin::into_inner(self).handler)(source);
        rocket::tokio::pin!(outgoing);

        loop {
            let sent = match outgoing.next().await {
                None => break,
                Some(Ok(message)) => sink.send(message).await,
                Some(Err(error)) => Err(error),
            };

            // `Ok(false)` signals the connection was closed: stop quietly.
            if !handle_result(sent)? {
                break;
            }
        }

        Ok(())
    }
}

/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing
/// has terminated without error, and `Err(e)` if an error has occurred.
fn handle_result(result: Result<()>) -> io::Result<bool> {
    match result {
        Ok(_) => Ok(true),
        Err(Error::ConnectionClosed) => Ok(false),
        Err(Error::Io(e)) => Err(e),
        Err(e) => Err(io::Error::new(io::ErrorKind::Other, e))
    }
}