rocket/data/io_stream.rs
1use std::io;
2use std::task::{Context, Poll};
3use std::pin::Pin;
4
5use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
6use hyper::upgrade::Upgraded;
7use hyper_util::rt::TokioIo;
8
9/// A bidirectional, raw stream to the client.
10///
11/// An instance of `IoStream` is passed to an [`IoHandler`] in response to a
12/// successful upgrade request initiated by responders via
13/// [`Response::add_upgrade()`] or the equivalent builder method
14/// [`Builder::upgrade()`]. For details on upgrade connections, see
15/// [`Response`#upgrading].
16///
17/// An `IoStream` is guaranteed to be [`AsyncRead`], [`AsyncWrite`], and
18/// `Unpin`. Bytes written to the stream are sent directly to the client. Bytes
19/// read from the stream are those sent directly _by_ the client. See
20/// [`IoHandler`] for one example of how values of this type are used.
21///
22/// [`Response::add_upgrade()`]: crate::Response::add_upgrade()
23/// [`Builder::upgrade()`]: crate::response::Builder::upgrade()
24/// [`Response`#upgrading]: crate::response::Response#upgrading
25pub struct IoStream {
26 kind: IoStreamKind,
27}
28
29/// Just in case we want to add stream kinds in the future.
30enum IoStreamKind {
31 Upgraded(TokioIo<Upgraded>)
32}
33
34/// An upgraded connection I/O handler.
35///
36/// An I/O handler performs raw I/O via the passed in [`IoStream`], which is
37/// [`AsyncRead`], [`AsyncWrite`], and `Unpin`.
38///
39/// # Example
40///
41/// The example below implements an `EchoHandler` that echos the raw bytes back
42/// to the client.
43///
44/// ```rust
45/// use std::pin::Pin;
46///
47/// use rocket::tokio::io;
48/// use rocket::data::{IoHandler, IoStream};
49///
50/// struct EchoHandler;
51///
52/// #[rocket::async_trait]
53/// impl IoHandler for EchoHandler {
54/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
55/// let (mut reader, mut writer) = io::split(io);
56/// io::copy(&mut reader, &mut writer).await?;
57/// Ok(())
58/// }
59/// }
60///
61/// # use rocket::Response;
62/// # rocket::async_test(async {
63/// # let mut response = Response::new();
64/// # response.add_upgrade("raw-echo", EchoHandler);
65/// # assert!(response.upgrade("raw-echo").is_some());
66/// # })
67/// ```
68#[crate::async_trait]
69pub trait IoHandler: Send {
70 /// Performs the raw I/O.
71 async fn io(self: Box<Self>, io: IoStream) -> io::Result<()>;
72}
73
74#[crate::async_trait]
75impl IoHandler for () {
76 async fn io(self: Box<Self>, _: IoStream) -> io::Result<()> {
77 Ok(())
78 }
79}
80
81#[doc(hidden)]
82impl From<Upgraded> for IoStream {
83 fn from(io: Upgraded) -> Self {
84 IoStream { kind: IoStreamKind::Upgraded(TokioIo::new(io)) }
85 }
86}
87
88/// A "trait alias" of sorts so we can use `AsyncRead + AsyncWrite + Unpin` in `dyn`.
89pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin { }
90
91/// Implemented for all `AsyncRead + AsyncWrite + Unpin`, of course.
92impl<T: AsyncRead + AsyncWrite + Unpin> AsyncReadWrite for T { }
93
94impl IoStream {
95 /// Returns the internal I/O stream.
96 fn inner_mut(&mut self) -> Pin<&mut dyn AsyncReadWrite> {
97 match self.kind {
98 IoStreamKind::Upgraded(ref mut io) => Pin::new(io),
99 }
100 }
101
102 /// Returns `true` if the inner I/O stream is write vectored.
103 fn inner_is_write_vectored(&self) -> bool {
104 match self.kind {
105 IoStreamKind::Upgraded(ref io) => io.is_write_vectored(),
106 }
107 }
108}
109
110impl AsyncRead for IoStream {
111 fn poll_read(
112 self: Pin<&mut Self>,
113 cx: &mut Context<'_>,
114 buf: &mut ReadBuf<'_>,
115 ) -> Poll<io::Result<()>> {
116 self.get_mut().inner_mut().poll_read(cx, buf)
117 }
118}
119
120impl AsyncWrite for IoStream {
121 fn poll_write(
122 self: Pin<&mut Self>,
123 cx: &mut Context<'_>,
124 buf: &[u8],
125 ) -> Poll<io::Result<usize>> {
126 self.get_mut().inner_mut().poll_write(cx, buf)
127 }
128
129 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
130 self.get_mut().inner_mut().poll_flush(cx)
131 }
132
133 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
134 self.get_mut().inner_mut().poll_shutdown(cx)
135 }
136
137 fn poll_write_vectored(
138 self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 bufs: &[io::IoSlice<'_>],
141 ) -> Poll<io::Result<usize>> {
142 self.get_mut().inner_mut().poll_write_vectored(cx, bufs)
143 }
144
145 fn is_write_vectored(&self) -> bool {
146 self.inner_is_write_vectored()
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `IoStream` satisfies the trait bounds its
    /// documentation promises, plus `Send`.
    #[test]
    fn is_unpin() {
        fn assert_io_traits<T>()
        where
            T: AsyncRead + AsyncWrite + Unpin + Send,
        {
        }

        assert_io_traits::<IoStream>();
    }
}