rocket/data/data_stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::path::Path;
4use std::io::{self, Cursor};
5
6use futures::ready;
7use futures::stream::Stream;
8use tokio::fs::File;
9use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
10use tokio_util::io::StreamReader;
11use hyper::body::{Body, Bytes, Incoming as HyperBody};
12
13use crate::data::{Capped, N};
14use crate::data::transform::Transform;
15use crate::util::Chain;
16
17use super::peekable::Peekable;
18use super::transform::TransformBuf;
19
20/// Raw data stream of a request body.
21///
22/// This stream can only be obtained by calling
23/// [`Data::open()`](crate::data::Data::open()) with a data limit. The stream
24/// contains all of the data in the body of the request.
25///
26/// Reading from a `DataStream` is accomplished via the various methods on the
27/// structure. In general, methods exists in two variants: those that _check_
28/// whether the entire stream was read and those that don't. The former either
29/// directly or indirectly (via [`Capped`]) return an [`N`] which allows
30/// checking if the stream was read to completion while the latter do not.
31///
32/// | Read Into | Method                               | Notes                            |
33/// |-----------|--------------------------------------|----------------------------------|
34/// | `String`  | [`DataStream::into_string()`]        | Completeness checked. Preferred. |
35/// | `String`  | [`AsyncReadExt::read_to_string()`]   | Unchecked w/existing `String`.   |
36/// | `Vec<u8>` | [`DataStream::into_bytes()`]         | Checked. Preferred.              |
37/// | `Vec<u8>` | [`DataStream::stream_to(&mut vec)`]  | Checked w/existing `Vec`.        |
38/// | `Vec<u8>` | [`DataStream::stream_precise_to()`]  | Unchecked w/existing `Vec`.      |
39/// | `File`    | [`DataStream::into_file()`]          | Checked. Preferred.              |
40/// | `File`    | [`DataStream::stream_to(&mut file)`] | Checked w/ existing `File`.      |
41/// | `File`    | [`DataStream::stream_precise_to()`]  | Unchecked w/ existing `File`.    |
42/// | `T`       | [`DataStream::stream_to()`]          | Checked. Any `T: AsyncWrite`.    |
43/// | `T`       | [`DataStream::stream_precise_to()`]  | Unchecked. Any `T: AsyncWrite`.  |
44///
45/// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
46/// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
47#[non_exhaustive]
48pub enum DataStream<'r> {
49    #[doc(hidden)]
50    Base(BaseReader<'r>),
51    #[doc(hidden)]
52    Transform(TransformReader<'r>),
53}
54
/// A data stream that has a `transformer` applied to it.
///
/// Bytes are read from the inner `stream` and then passed through the
/// `transformer`. The inner stream may itself be another transformed stream.
pub struct TransformReader<'r> {
    /// The transform applied to bytes read from `stream`.
    transformer: Pin<Box<dyn Transform + Send + Sync + 'r>>,
    /// The inner stream whose output feeds `transformer`.
    stream: Pin<Box<DataStream<'r>>>,
    /// Set once a read from `stream` adds no bytes; from then on, reads are
    /// served by the transformer's `poll_finish`.
    inner_done: bool,
}
61
/// Limited, pre-buffered reader to the underlying data stream.
///
/// Chains the bytes already peeked from the request (the `Cursor<Vec<u8>>`)
/// with the remaining `RawReader`. The cursor is presumably drained first, per
/// `crate::util::Chain`. `Take` caps the total number of bytes read at the
/// data limit.
pub type BaseReader<'r> = Take<Chain<Cursor<Vec<u8>>, RawReader<'r>>>;

/// Direct reader to the underlying data stream. Not limited in any manner.
///
/// Adapts the chunked `RawStream` into an `AsyncRead` via `StreamReader`.
pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
67
/// Raw underlying data stream.
///
/// Each variant is a different source of body bytes. Chunks are produced by
/// this type's `Stream` implementation.
pub enum RawStream<'r> {
    /// A source with no data: polling it yields end-of-stream immediately.
    Empty,
    /// An HTTP request body received via hyper.
    Body(HyperBody),
    /// A cancellable QUIC receive stream (`http3-preview` feature only).
    #[cfg(feature = "http3-preview")]
    H3Body(crate::listener::Cancellable<crate::listener::quic::QuicRx>),
    /// A single field of a multipart form, as parsed by `multer`.
    Multipart(multer::Field<'r>),
}
76
77impl<'r> TransformReader<'r> {
78    /// Returns the underlying `BaseReader`.
79    fn base_mut(&mut self) -> &mut BaseReader<'r> {
80        match self.stream.as_mut().get_mut() {
81            DataStream::Base(base) => base,
82            DataStream::Transform(inner) => inner.base_mut(),
83        }
84    }
85
86    /// Returns the underlying `BaseReader`.
87    fn base(&self) -> &BaseReader<'r> {
88        match self.stream.as_ref().get_ref() {
89            DataStream::Base(base) => base,
90            DataStream::Transform(inner) => inner.base(),
91        }
92    }
93}
94
95impl<'r> DataStream<'r> {
96    pub(crate) fn new(
97        transformers: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
98        Peekable { buffer, reader, .. }: Peekable<512, RawReader<'r>>,
99        limit: u64
100    ) -> Self {
101        let mut stream = DataStream::Base(Chain::new(Cursor::new(buffer), reader).take(limit));
102        for transformer in transformers {
103            stream = DataStream::Transform(TransformReader {
104                transformer,
105                stream: Box::pin(stream),
106                inner_done: false,
107            });
108        }
109
110        stream
111    }
112
113    /// Returns the underlying `BaseReader`.
114    fn base_mut(&mut self) -> &mut BaseReader<'r> {
115        match self {
116            DataStream::Base(base) => base,
117            DataStream::Transform(transform) => transform.base_mut(),
118        }
119    }
120
121    /// Returns the underlying `BaseReader`.
122    fn base(&self) -> &BaseReader<'r> {
123        match self {
124            DataStream::Base(base) => base,
125            DataStream::Transform(transform) => transform.base(),
126        }
127    }
128
129    /// Whether a previous read exhausted the set limit _and then some_.
130    async fn limit_exceeded(&mut self) -> io::Result<bool> {
131        let base = self.base_mut();
132
133        #[cold]
134        async fn _limit_exceeded(base: &mut BaseReader<'_>) -> io::Result<bool> {
135            // Read one more byte after reaching limit to see if we cut early.
136            base.set_limit(1);
137            let mut buf = [0u8; 1];
138            let exceeded = base.read(&mut buf).await? != 0;
139            base.set_limit(0);
140            Ok(exceeded)
141        }
142
143        Ok(base.limit() == 0 && _limit_exceeded(base).await?)
144    }
145
146    /// Number of bytes a full read from `self` will _definitely_ read.
147    ///
148    /// # Example
149    ///
150    /// ```rust
151    /// use rocket::data::{Data, ToByteUnit};
152    ///
153    /// async fn f(data: Data<'_>) {
154    ///     let definitely_have_n_bytes = data.open(1.kibibytes()).hint();
155    /// }
156    /// ```
157    pub fn hint(&self) -> usize {
158        let base = self.base();
159        if let (Some(cursor), _) = base.get_ref().get_ref() {
160            let len = cursor.get_ref().len() as u64;
161            let position = cursor.position().min(len);
162            let remaining = len - position;
163            remaining.min(base.limit()) as usize
164        } else {
165            0
166        }
167    }
168
169    /// A helper method to write the body of the request to any `AsyncWrite`
170    /// type. Returns an [`N`] which indicates how many bytes were written and
171    /// whether the entire stream was read. An additional read from `self` may
172    /// be required to check if all of the stream has been read. If that
173    /// information is not needed, use [`DataStream::stream_precise_to()`].
174    ///
175    /// This method is identical to `tokio::io::copy(&mut self, &mut writer)`
176    /// except in that it returns an `N` to check for completeness.
177    ///
178    /// # Example
179    ///
180    /// ```rust
181    /// use std::io;
182    /// use rocket::data::{Data, ToByteUnit};
183    ///
184    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
185    ///     // write all of the data to stdout
186    ///     let written = data.open(512.kibibytes())
187    ///         .stream_to(tokio::io::stdout()).await?;
188    ///
189    ///     Ok(format!("Wrote {} bytes.", written))
190    /// }
191    /// ```
192    #[inline(always)]
193    pub async fn stream_to<W>(mut self, mut writer: W) -> io::Result<N>
194        where W: AsyncWrite + Unpin
195    {
196        let written = tokio::io::copy(&mut self, &mut writer).await?;
197        Ok(N { written, complete: !self.limit_exceeded().await? })
198    }
199
200    /// Like [`DataStream::stream_to()`] except that no end-of-stream check is
201    /// conducted and thus read/write completeness is unknown.
202    ///
203    /// # Example
204    ///
205    /// ```rust
206    /// use std::io;
207    /// use rocket::data::{Data, ToByteUnit};
208    ///
209    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
210    ///     // write all of the data to stdout
211    ///     let written = data.open(512.kibibytes())
212    ///         .stream_precise_to(tokio::io::stdout()).await?;
213    ///
214    ///     Ok(format!("Wrote {} bytes.", written))
215    /// }
216    /// ```
217    #[inline(always)]
218    pub async fn stream_precise_to<W>(mut self, mut writer: W) -> io::Result<u64>
219        where W: AsyncWrite + Unpin
220    {
221        tokio::io::copy(&mut self, &mut writer).await
222    }
223
224    /// A helper method to write the body of the request to a `Vec<u8>`.
225    ///
226    /// # Example
227    ///
228    /// ```rust
229    /// use std::io;
230    /// use rocket::data::{Data, ToByteUnit};
231    ///
232    /// async fn data_guard(data: Data<'_>) -> io::Result<Vec<u8>> {
233    ///     let bytes = data.open(4.kibibytes()).into_bytes().await?;
234    ///     if !bytes.is_complete() {
235    ///         println!("there are bytes remaining in the stream");
236    ///     }
237    ///
238    ///     Ok(bytes.into_inner())
239    /// }
240    /// ```
241    pub async fn into_bytes(self) -> io::Result<Capped<Vec<u8>>> {
242        let mut vec = Vec::with_capacity(self.hint());
243        let n = self.stream_to(&mut vec).await?;
244        Ok(Capped { value: vec, n })
245    }
246
247    /// A helper method to write the body of the request to a `String`.
248    ///
249    /// # Example
250    ///
251    /// ```rust
252    /// use std::io;
253    /// use rocket::data::{Data, ToByteUnit};
254    ///
255    /// async fn data_guard(data: Data<'_>) -> io::Result<String> {
256    ///     let string = data.open(10.bytes()).into_string().await?;
257    ///     if !string.is_complete() {
258    ///         println!("there are bytes remaining in the stream");
259    ///     }
260    ///
261    ///     Ok(string.into_inner())
262    /// }
263    /// ```
264    pub async fn into_string(mut self) -> io::Result<Capped<String>> {
265        let mut string = String::with_capacity(self.hint());
266        let written = self.read_to_string(&mut string).await?;
267        let n = N { written: written as u64, complete: !self.limit_exceeded().await? };
268        Ok(Capped { value: string, n })
269    }
270
271    /// A helper method to write the body of the request to a file at the path
272    /// determined by `path`. If a file at the path already exists, it is
273    /// overwritten. The opened file is returned.
274    ///
275    /// # Example
276    ///
277    /// ```rust
278    /// use std::io;
279    /// use rocket::data::{Data, ToByteUnit};
280    ///
281    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
282    ///     let file = data.open(1.megabytes()).into_file("/static/file").await?;
283    ///     if !file.is_complete() {
284    ///         println!("there are bytes remaining in the stream");
285    ///     }
286    ///
287    ///     Ok(format!("Wrote {} bytes to /static/file", file.n))
288    /// }
289    /// ```
290    pub async fn into_file<P: AsRef<Path>>(self, path: P) -> io::Result<Capped<File>> {
291        let mut file = File::create(path).await?;
292        let n = self.stream_to(&mut tokio::io::BufWriter::new(&mut file)).await?;
293        Ok(Capped { value: file, n })
294    }
295}
296
297impl AsyncRead for DataStream<'_> {
298    fn poll_read(
299        self: Pin<&mut Self>,
300        cx: &mut Context<'_>,
301        buf: &mut ReadBuf<'_>,
302    ) -> Poll<io::Result<()>> {
303        match self.get_mut() {
304            DataStream::Base(inner) => Pin::new(inner).poll_read(cx, buf),
305            DataStream::Transform(inner) => Pin::new(inner).poll_read(cx, buf),
306        }
307    }
308}
309
310impl AsyncRead for TransformReader<'_> {
311    fn poll_read(
312        mut self: Pin<&mut Self>,
313        cx: &mut Context<'_>,
314        buf: &mut ReadBuf<'_>,
315    ) -> Poll<io::Result<()>> {
316        let init_fill = buf.filled().len();
317        if !self.inner_done {
318            ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
319            self.inner_done = init_fill == buf.filled().len();
320        }
321
322        if self.inner_done {
323            return self.transformer.as_mut().poll_finish(cx, buf);
324        }
325
326        let mut tbuf = TransformBuf { buf, cursor: init_fill };
327        self.transformer.as_mut().transform(&mut tbuf)?;
328        if buf.filled().len() == init_fill {
329            cx.waker().wake_by_ref();
330            return Poll::Pending;
331        }
332
333        Poll::Ready(Ok(()))
334    }
335}
336
337impl Stream for RawStream<'_> {
338    type Item = io::Result<Bytes>;
339
340    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341        match self.get_mut() {
342            // TODO: Expose trailer headers, somehow.
343            RawStream::Body(body) => {
344                Pin::new(body)
345                    .poll_frame(cx)
346                    .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
347                    .map_err(io::Error::other)
348            },
349            #[cfg(feature = "http3-preview")]
350            RawStream::H3Body(stream) => Pin::new(stream).poll_next(cx),
351            RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
352            RawStream::Empty => Poll::Ready(None),
353        }
354    }
355
356    fn size_hint(&self) -> (usize, Option<usize>) {
357        match self {
358            RawStream::Body(body) => {
359                let hint = body.size_hint();
360                let (lower, upper) = (hint.lower(), hint.upper());
361                (lower as usize, upper.map(|x| x as usize))
362            },
363            #[cfg(feature = "http3-preview")]
364            RawStream::H3Body(_) => (0, Some(0)),
365            RawStream::Multipart(mp) => mp.size_hint(),
366            RawStream::Empty => (0, Some(0)),
367        }
368    }
369}
370
371impl std::fmt::Display for RawStream<'_> {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        match self {
374            RawStream::Empty => f.write_str("empty stream"),
375            RawStream::Body(_) => f.write_str("request body"),
376            #[cfg(feature = "http3-preview")]
377            RawStream::H3Body(_) => f.write_str("http3 quic stream"),
378            RawStream::Multipart(_) => f.write_str("multipart form field"),
379        }
380    }
381}
382
383impl<'r> From<HyperBody> for RawStream<'r> {
384    fn from(value: HyperBody) -> Self {
385        Self::Body(value)
386    }
387}
388
#[cfg(feature = "http3-preview")]
impl<'r> From<crate::listener::Cancellable<crate::listener::quic::QuicRx>> for RawStream<'r> {
    /// Wraps a cancellable QUIC receive stream as `RawStream::H3Body`.
    fn from(rx: crate::listener::Cancellable<crate::listener::quic::QuicRx>) -> Self {
        RawStream::H3Body(rx)
    }
}
395
396impl<'r> From<multer::Field<'r>> for RawStream<'r> {
397    fn from(value: multer::Field<'r>) -> Self {
398        Self::Multipart(value)
399    }
400}