1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use std::io;
use std::task::{Context, Poll};
use std::pin::Pin;

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;

/// A bidirectional, raw stream to the client.
///
/// An instance of `IoStream` is passed to an [`IoHandler`] in response to a
/// successful upgrade request initiated by responders via
/// [`Response::add_upgrade()`] or the equivalent builder method
/// [`Builder::upgrade()`]. For details on upgraded connections, see
/// [`Response`#upgrading].
///
/// An `IoStream` is guaranteed to be [`AsyncRead`], [`AsyncWrite`], and
/// `Unpin`. Bytes written to the stream are sent directly to the client. Bytes
/// read from the stream are those sent directly _by_ the client. See
/// [`IoHandler`] for one example of how values of this type are used.
///
/// [`Response::add_upgrade()`]: crate::Response::add_upgrade()
/// [`Builder::upgrade()`]: crate::response::Builder::upgrade()
/// [`Response`#upgrading]: crate::response::Response#upgrading
pub struct IoStream {
    // The concrete stream being wrapped. Every `AsyncRead`/`AsyncWrite`
    // method implemented for `IoStream` forwards to this stream.
    kind: IoStreamKind,
}

/// The concrete stream wrapped by an [`IoStream`].
///
/// Only one kind exists today. It is modeled as an enum so that additional
/// stream kinds can be added later without changing `IoStream`'s public API.
enum IoStreamKind {
    /// A `hyper` upgraded connection, wrapped in `TokioIo` so that it can be
    /// driven through tokio's `AsyncRead`/`AsyncWrite` traits.
    Upgraded(TokioIo<Upgraded>)
}

/// An upgraded connection I/O handler.
///
/// An I/O handler performs raw I/O via the passed in [`IoStream`], which is
/// [`AsyncRead`], [`AsyncWrite`], and `Unpin`. Implementors must be `Send`.
///
/// # Example
///
/// The example below implements an `EchoHandler` that echoes the raw bytes
/// back to the client.
///
/// ```rust
/// use std::pin::Pin;
///
/// use rocket::tokio::io;
/// use rocket::data::{IoHandler, IoStream};
///
/// struct EchoHandler;
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
///     async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
///         let (mut reader, mut writer) = io::split(io);
///         io::copy(&mut reader, &mut writer).await?;
///         Ok(())
///     }
/// }
///
/// # use rocket::Response;
/// # rocket::async_test(async {
/// # let mut response = Response::new();
/// # response.add_upgrade("raw-echo", EchoHandler);
/// # assert!(response.upgrade("raw-echo").is_some());
/// # })
/// ```
#[crate::async_trait]
pub trait IoHandler: Send {
    /// Performs the raw I/O.
    ///
    /// Takes ownership of both the boxed handler and the [`IoStream`]. A
    /// handler signals failure by returning an `Err`.
    async fn io(self: Box<Self>, io: IoStream) -> io::Result<()>;
}

/// The unit type `()` is a no-op handler: it performs no I/O on the stream
/// and returns `Ok(())` without awaiting anything.
#[crate::async_trait]
impl IoHandler for () {
    // The stream is deliberately discarded via the `_` pattern; nothing is
    // read from or written to it.
    async fn io(self: Box<Self>, _: IoStream) -> io::Result<()> {
        Ok(())
    }
}

// Wraps a hyper upgraded connection in `TokioIo` so it can be exposed to
// handlers as an `IoStream`. Hidden from docs: this is an internal conversion.
#[doc(hidden)]
impl From<Upgraded> for IoStream {
    fn from(upgraded: Upgraded) -> Self {
        let stream = TokioIo::new(upgraded);
        Self { kind: IoStreamKind::Upgraded(stream) }
    }
}

/// A "trait alias" of sorts bundling `AsyncRead + AsyncWrite + Unpin`.
///
/// A trait object may name only one non-auto trait, so this combined trait
/// is what allows the three bounds to be used together behind `dyn`.
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}

/// Blanket implementation: every type that is `AsyncRead + AsyncWrite +
/// Unpin` is automatically an [`AsyncReadWrite`].
impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin {}

impl IoStream {
    /// Returns the wrapped stream as a pinned, type-erased mutable reference.
    ///
    /// The stream is `Unpin` (required by `AsyncReadWrite`), so it can be
    /// pinned with the safe `Pin::new`.
    fn inner_mut(&mut self) -> Pin<&mut dyn AsyncReadWrite> {
        match &mut self.kind {
            IoStreamKind::Upgraded(stream) => Pin::new(stream),
        }
    }

    /// Returns `true` if the wrapped stream reports that it supports
    /// vectored writes.
    fn inner_is_write_vectored(&self) -> bool {
        match &self.kind {
            IoStreamKind::Upgraded(stream) => stream.is_write_vectored(),
        }
    }
}

impl AsyncRead for IoStream {
    /// Delegates the read to the wrapped stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `IoStream` is `Unpin`, so the pin can be safely unwrapped.
        let this = Pin::into_inner(self);
        this.inner_mut().poll_read(cx, buf)
    }
}

impl AsyncWrite for IoStream {
    /// Delegates the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::into_inner(self);
        this.inner_mut().poll_write(cx, buf)
    }

    // The two vectored-write methods are overridden, rather than left as the
    // trait defaults, so that vectored writes and the vectored-support flag
    // come from the wrapped stream.

    /// Delegates the vectored write to the wrapped stream.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::into_inner(self);
        this.inner_mut().poll_write_vectored(cx, bufs)
    }

    /// Reports whether the wrapped stream supports vectored writes.
    fn is_write_vectored(&self) -> bool {
        self.inner_is_write_vectored()
    }

    /// Delegates the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::into_inner(self);
        this.inner_mut().poll_flush(cx)
    }

    /// Delegates the shutdown to the wrapped stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::into_inner(self);
        this.inner_mut().poll_shutdown(cx)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `IoStream` satisfies `AsyncRead + AsyncWrite
    /// + Unpin + Send`; the test passes as long as it type-checks.
    #[test]
    fn is_unpin() {
        fn assert_stream_bounds<S>()
        where
            S: AsyncRead + AsyncWrite + Unpin + Send,
        {
        }

        assert_stream_bounds::<IoStream>();
    }
}