rocket/data/
data_stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::path::Path;
4use std::io::{self, Cursor};
5
6use tokio::fs::File;
7use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
8use futures::stream::Stream;
9use futures::ready;
10use yansi::Paint;
11
12use crate::http::hyper;
13use crate::ext::{PollExt, Chain};
14use crate::data::{Capped, N};
15
16/// Raw data stream of a request body.
17///
18/// This stream can only be obtained by calling
19/// [`Data::open()`](crate::data::Data::open()) with a data limit. The stream
20/// contains all of the data in the body of the request.
21///
22/// Reading from a `DataStream` is accomplished via the various methods on the
23/// structure. In general, methods exists in two variants: those that _check_
24/// whether the entire stream was read and those that don't. The former either
25/// directly or indirectly (via [`Capped`]) return an [`N`] which allows
26/// checking if the stream was read to completion while the latter do not.
27///
28/// | Read Into | Method                               | Notes                            |
29/// |-----------|--------------------------------------|----------------------------------|
30/// | `String`  | [`DataStream::into_string()`]        | Completeness checked. Preferred. |
31/// | `String`  | [`AsyncReadExt::read_to_string()`]   | Unchecked w/existing `String`.   |
32/// | `Vec<u8>` | [`DataStream::into_bytes()`]         | Checked. Preferred.              |
33/// | `Vec<u8>` | [`DataStream::stream_to(&mut vec)`]  | Checked w/existing `Vec`.        |
34/// | `Vec<u8>` | [`DataStream::stream_precise_to()`]  | Unchecked w/existing `Vec`.      |
35/// | `File`    | [`DataStream::into_file()`]          | Checked. Preferred.              |
36/// | `File`    | [`DataStream::stream_to(&mut file)`] | Checked w/ existing `File`.      |
37/// | `File`    | [`DataStream::stream_precise_to()`]  | Unchecked w/ existing `File`.    |
38/// | `T`       | [`DataStream::stream_to()`]          | Checked. Any `T: AsyncWrite`.    |
39/// | `T`       | [`DataStream::stream_precise_to()`]  | Unchecked. Any `T: AsyncWrite`.  |
40///
41/// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
42/// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
43pub struct DataStream<'r> {
44    pub(crate) chain: Take<Chain<Cursor<Vec<u8>>, StreamReader<'r>>>,
45}
46
47/// An adapter: turns a `T: Stream` (in `StreamKind`) into a `tokio::AsyncRead`.
48pub struct StreamReader<'r> {
49    state: State,
50    inner: StreamKind<'r>,
51}
52
53/// The current state of `StreamReader` `AsyncRead` adapter.
54enum State {
55    Pending,
56    Partial(Cursor<hyper::body::Bytes>),
57    Done,
58}
59
60/// The kinds of streams we accept as `Data`.
61enum StreamKind<'r> {
62    Empty,
63    Body(&'r mut hyper::Body),
64    Multipart(multer::Field<'r>)
65}
66
67impl<'r> DataStream<'r> {
68    pub(crate) fn new(buf: Vec<u8>, stream: StreamReader<'r>, limit: u64) -> Self {
69        let chain = Chain::new(Cursor::new(buf), stream).take(limit).into();
70        Self { chain }
71    }
72
73    /// Whether a previous read exhausted the set limit _and then some_.
74    async fn limit_exceeded(&mut self) -> io::Result<bool> {
75        #[cold]
76        async fn _limit_exceeded(stream: &mut DataStream<'_>) -> io::Result<bool> {
77            // Read one more byte after reaching limit to see if we cut early.
78            stream.chain.set_limit(1);
79            let mut buf = [0u8; 1];
80            Ok(stream.read(&mut buf).await? != 0)
81        }
82
83        Ok(self.chain.limit() == 0 && _limit_exceeded(self).await?)
84    }
85
86    /// Number of bytes a full read from `self` will _definitely_ read.
87    ///
88    /// # Example
89    ///
90    /// ```rust
91    /// use rocket::data::{Data, ToByteUnit};
92    ///
93    /// async fn f(data: Data<'_>) {
94    ///     let definitely_have_n_bytes = data.open(1.kibibytes()).hint();
95    /// }
96    /// ```
97    pub fn hint(&self) -> usize {
98        let buf_len = self.chain.get_ref().get_ref().0.get_ref().len();
99        std::cmp::min(buf_len, self.chain.limit() as usize)
100    }
101
102    /// A helper method to write the body of the request to any `AsyncWrite`
103    /// type. Returns an [`N`] which indicates how many bytes were written and
104    /// whether the entire stream was read. An additional read from `self` may
105    /// be required to check if all of the stream has been read. If that
106    /// information is not needed, use [`DataStream::stream_precise_to()`].
107    ///
108    /// This method is identical to `tokio::io::copy(&mut self, &mut writer)`
109    /// except in that it returns an `N` to check for completeness.
110    ///
111    /// # Example
112    ///
113    /// ```rust
114    /// use std::io;
115    /// use rocket::data::{Data, ToByteUnit};
116    ///
117    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
118    ///     // write all of the data to stdout
119    ///     let written = data.open(512.kibibytes())
120    ///         .stream_to(tokio::io::stdout()).await?;
121    ///
122    ///     Ok(format!("Wrote {} bytes.", written))
123    /// }
124    /// ```
125    #[inline(always)]
126    pub async fn stream_to<W>(mut self, mut writer: W) -> io::Result<N>
127        where W: AsyncWrite + Unpin
128    {
129        let written = tokio::io::copy(&mut self, &mut writer).await?;
130        Ok(N { written, complete: !self.limit_exceeded().await? })
131    }
132
133    /// Like [`DataStream::stream_to()`] except that no end-of-stream check is
134    /// conducted and thus read/write completeness is unknown.
135    ///
136    /// # Example
137    ///
138    /// ```rust
139    /// use std::io;
140    /// use rocket::data::{Data, ToByteUnit};
141    ///
142    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
143    ///     // write all of the data to stdout
144    ///     let written = data.open(512.kibibytes())
145    ///         .stream_precise_to(tokio::io::stdout()).await?;
146    ///
147    ///     Ok(format!("Wrote {} bytes.", written))
148    /// }
149    /// ```
150    #[inline(always)]
151    pub async fn stream_precise_to<W>(mut self, mut writer: W) -> io::Result<u64>
152        where W: AsyncWrite + Unpin
153    {
154        tokio::io::copy(&mut self, &mut writer).await
155    }
156
157    /// A helper method to write the body of the request to a `Vec<u8>`.
158    ///
159    /// # Example
160    ///
161    /// ```rust
162    /// use std::io;
163    /// use rocket::data::{Data, ToByteUnit};
164    ///
165    /// async fn data_guard(data: Data<'_>) -> io::Result<Vec<u8>> {
166    ///     let bytes = data.open(4.kibibytes()).into_bytes().await?;
167    ///     if !bytes.is_complete() {
168    ///         println!("there are bytes remaining in the stream");
169    ///     }
170    ///
171    ///     Ok(bytes.into_inner())
172    /// }
173    /// ```
174    pub async fn into_bytes(self) -> io::Result<Capped<Vec<u8>>> {
175        let mut vec = Vec::with_capacity(self.hint());
176        let n = self.stream_to(&mut vec).await?;
177        Ok(Capped { value: vec, n })
178    }
179
180    /// A helper method to write the body of the request to a `String`.
181    ///
182    /// # Example
183    ///
184    /// ```rust
185    /// use std::io;
186    /// use rocket::data::{Data, ToByteUnit};
187    ///
188    /// async fn data_guard(data: Data<'_>) -> io::Result<String> {
189    ///     let string = data.open(10.bytes()).into_string().await?;
190    ///     if !string.is_complete() {
191    ///         println!("there are bytes remaining in the stream");
192    ///     }
193    ///
194    ///     Ok(string.into_inner())
195    /// }
196    /// ```
197    pub async fn into_string(mut self) -> io::Result<Capped<String>> {
198        let mut string = String::with_capacity(self.hint());
199        let written = self.read_to_string(&mut string).await?;
200        let n = N { written: written as u64, complete: !self.limit_exceeded().await? };
201        Ok(Capped { value: string, n })
202    }
203
204    /// A helper method to write the body of the request to a file at the path
205    /// determined by `path`. If a file at the path already exists, it is
206    /// overwritten. The opened file is returned.
207    ///
208    /// # Example
209    ///
210    /// ```rust
211    /// use std::io;
212    /// use rocket::data::{Data, ToByteUnit};
213    ///
214    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
215    ///     let file = data.open(1.megabytes()).into_file("/static/file").await?;
216    ///     if !file.is_complete() {
217    ///         println!("there are bytes remaining in the stream");
218    ///     }
219    ///
220    ///     Ok(format!("Wrote {} bytes to /static/file", file.n))
221    /// }
222    /// ```
223    pub async fn into_file<P: AsRef<Path>>(self, path: P) -> io::Result<Capped<File>> {
224        let mut file = File::create(path).await?;
225        let n = self.stream_to(&mut tokio::io::BufWriter::new(&mut file)).await?;
226        Ok(Capped { value: file, n })
227    }
228}
229
230// TODO.async: Consider implementing `AsyncBufRead`.
231
232impl StreamReader<'_> {
233    pub fn empty() -> Self {
234        Self { inner: StreamKind::Empty, state: State::Done }
235    }
236}
237
238impl<'r> From<&'r mut hyper::Body> for StreamReader<'r> {
239    fn from(body: &'r mut hyper::Body) -> Self {
240        Self { inner: StreamKind::Body(body), state: State::Pending }
241    }
242}
243
244impl<'r> From<multer::Field<'r>> for StreamReader<'r> {
245    fn from(field: multer::Field<'r>) -> Self {
246        Self { inner: StreamKind::Multipart(field), state: State::Pending }
247    }
248}
249
250impl AsyncRead for DataStream<'_> {
251    #[inline(always)]
252    fn poll_read(
253        mut self: Pin<&mut Self>,
254        cx: &mut Context<'_>,
255        buf: &mut ReadBuf<'_>,
256    ) -> Poll<io::Result<()>> {
257        if self.chain.limit() == 0 {
258            let stream: &StreamReader<'_> = &self.chain.get_ref().get_ref().1;
259            let kind = match stream.inner {
260                StreamKind::Empty => "an empty stream (vacuous)",
261                StreamKind::Body(_) => "the request body",
262                StreamKind::Multipart(_) => "a multipart form field",
263            };
264
265            warn_!("Data limit reached while reading {}.", kind.primary().bold());
266        }
267
268        Pin::new(&mut self.chain).poll_read(cx, buf)
269    }
270}
271
272impl Stream for StreamKind<'_> {
273    type Item = io::Result<hyper::body::Bytes>;
274
275    fn poll_next(
276        self: Pin<&mut Self>,
277        cx: &mut Context<'_>,
278    ) -> Poll<Option<Self::Item>> {
279        match self.get_mut() {
280            StreamKind::Body(body) => Pin::new(body).poll_next(cx)
281                .map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
282            StreamKind::Multipart(mp) => Pin::new(mp).poll_next(cx)
283                .map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
284            StreamKind::Empty => Poll::Ready(None),
285        }
286    }
287
288    fn size_hint(&self) -> (usize, Option<usize>) {
289        match self {
290            StreamKind::Body(body) => body.size_hint(),
291            StreamKind::Multipart(mp) => mp.size_hint(),
292            StreamKind::Empty => (0, Some(0)),
293        }
294    }
295}
296
297impl AsyncRead for StreamReader<'_> {
298    fn poll_read(
299        mut self: Pin<&mut Self>,
300        cx: &mut Context<'_>,
301        buf: &mut ReadBuf<'_>,
302    ) -> Poll<io::Result<()>> {
303        loop {
304            self.state = match self.state {
305                State::Pending => {
306                    match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
307                        Some(Err(e)) => return Poll::Ready(Err(e)),
308                        Some(Ok(bytes)) => State::Partial(Cursor::new(bytes)),
309                        None => State::Done,
310                    }
311                },
312                State::Partial(ref mut cursor) => {
313                    let rem = buf.remaining();
314                    match ready!(Pin::new(cursor).poll_read(cx, buf)) {
315                        Ok(()) if rem == buf.remaining() => State::Pending,
316                        result => return Poll::Ready(result),
317                    }
318                }
319                State::Done => return Poll::Ready(Ok(())),
320            }
321        }
322    }
323}