rocket/data/data_stream.rs
1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::path::Path;
4use std::io::{self, Cursor};
5
6use futures::ready;
7use futures::stream::Stream;
8use tokio::fs::File;
9use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
10use tokio_util::io::StreamReader;
11use hyper::body::{Body, Bytes, Incoming as HyperBody};
12
13use crate::data::{Capped, N};
14use crate::data::transform::Transform;
15use crate::util::Chain;
16
17use super::peekable::Peekable;
18use super::transform::TransformBuf;
19
20/// Raw data stream of a request body.
21///
22/// This stream can only be obtained by calling
23/// [`Data::open()`](crate::data::Data::open()) with a data limit. The stream
24/// contains all of the data in the body of the request.
25///
26/// Reading from a `DataStream` is accomplished via the various methods on the
27/// structure. In general, methods exists in two variants: those that _check_
28/// whether the entire stream was read and those that don't. The former either
29/// directly or indirectly (via [`Capped`]) return an [`N`] which allows
30/// checking if the stream was read to completion while the latter do not.
31///
32/// | Read Into | Method | Notes |
33/// |-----------|--------------------------------------|----------------------------------|
34/// | `String` | [`DataStream::into_string()`] | Completeness checked. Preferred. |
35/// | `String` | [`AsyncReadExt::read_to_string()`] | Unchecked w/existing `String`. |
36/// | `Vec<u8>` | [`DataStream::into_bytes()`] | Checked. Preferred. |
37/// | `Vec<u8>` | [`DataStream::stream_to(&mut vec)`] | Checked w/existing `Vec`. |
38/// | `Vec<u8>` | [`DataStream::stream_precise_to()`] | Unchecked w/existing `Vec`. |
39/// | `File` | [`DataStream::into_file()`] | Checked. Preferred. |
40/// | `File` | [`DataStream::stream_to(&mut file)`] | Checked w/ existing `File`. |
41/// | `File` | [`DataStream::stream_precise_to()`] | Unchecked w/ existing `File`. |
42/// | `T` | [`DataStream::stream_to()`] | Checked. Any `T: AsyncWrite`. |
43/// | `T` | [`DataStream::stream_precise_to()`] | Unchecked. Any `T: AsyncWrite`. |
44///
45/// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
46/// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
47#[non_exhaustive]
48pub enum DataStream<'r> {
49 #[doc(hidden)]
50 Base(BaseReader<'r>),
51 #[doc(hidden)]
52 Transform(TransformReader<'r>),
53}
54
55/// A data stream that has a `transformer` applied to it.
56pub struct TransformReader<'r> {
57 transformer: Pin<Box<dyn Transform + Send + Sync + 'r>>,
58 stream: Pin<Box<DataStream<'r>>>,
59 inner_done: bool,
60}
61
62/// Limited, pre-buffered reader to the underlying data stream.
63pub type BaseReader<'r> = Take<Chain<Cursor<Vec<u8>>, RawReader<'r>>>;
64
65/// Direct reader to the underlying data stream. Not limited in any manner.
66pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
67
68/// Raw underlying data stream.
69pub enum RawStream<'r> {
70 Empty,
71 Body(HyperBody),
72 #[cfg(feature = "http3-preview")]
73 H3Body(crate::listener::Cancellable<crate::listener::quic::QuicRx>),
74 Multipart(multer::Field<'r>),
75}
76
77impl<'r> TransformReader<'r> {
78 /// Returns the underlying `BaseReader`.
79 fn base_mut(&mut self) -> &mut BaseReader<'r> {
80 match self.stream.as_mut().get_mut() {
81 DataStream::Base(base) => base,
82 DataStream::Transform(inner) => inner.base_mut(),
83 }
84 }
85
86 /// Returns the underlying `BaseReader`.
87 fn base(&self) -> &BaseReader<'r> {
88 match self.stream.as_ref().get_ref() {
89 DataStream::Base(base) => base,
90 DataStream::Transform(inner) => inner.base(),
91 }
92 }
93}
94
95impl<'r> DataStream<'r> {
96 pub(crate) fn new(
97 transformers: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
98 Peekable { buffer, reader, .. }: Peekable<512, RawReader<'r>>,
99 limit: u64
100 ) -> Self {
101 let mut stream = DataStream::Base(Chain::new(Cursor::new(buffer), reader).take(limit));
102 for transformer in transformers {
103 stream = DataStream::Transform(TransformReader {
104 transformer,
105 stream: Box::pin(stream),
106 inner_done: false,
107 });
108 }
109
110 stream
111 }
112
113 /// Returns the underlying `BaseReader`.
114 fn base_mut(&mut self) -> &mut BaseReader<'r> {
115 match self {
116 DataStream::Base(base) => base,
117 DataStream::Transform(transform) => transform.base_mut(),
118 }
119 }
120
121 /// Returns the underlying `BaseReader`.
122 fn base(&self) -> &BaseReader<'r> {
123 match self {
124 DataStream::Base(base) => base,
125 DataStream::Transform(transform) => transform.base(),
126 }
127 }
128
129 /// Whether a previous read exhausted the set limit _and then some_.
130 async fn limit_exceeded(&mut self) -> io::Result<bool> {
131 let base = self.base_mut();
132
133 #[cold]
134 async fn _limit_exceeded(base: &mut BaseReader<'_>) -> io::Result<bool> {
135 // Read one more byte after reaching limit to see if we cut early.
136 base.set_limit(1);
137 let mut buf = [0u8; 1];
138 let exceeded = base.read(&mut buf).await? != 0;
139 base.set_limit(0);
140 Ok(exceeded)
141 }
142
143 Ok(base.limit() == 0 && _limit_exceeded(base).await?)
144 }
145
146 /// Number of bytes a full read from `self` will _definitely_ read.
147 ///
148 /// # Example
149 ///
150 /// ```rust
151 /// use rocket::data::{Data, ToByteUnit};
152 ///
153 /// async fn f(data: Data<'_>) {
154 /// let definitely_have_n_bytes = data.open(1.kibibytes()).hint();
155 /// }
156 /// ```
157 pub fn hint(&self) -> usize {
158 let base = self.base();
159 if let (Some(cursor), _) = base.get_ref().get_ref() {
160 let len = cursor.get_ref().len() as u64;
161 let position = cursor.position().min(len);
162 let remaining = len - position;
163 remaining.min(base.limit()) as usize
164 } else {
165 0
166 }
167 }
168
169 /// A helper method to write the body of the request to any `AsyncWrite`
170 /// type. Returns an [`N`] which indicates how many bytes were written and
171 /// whether the entire stream was read. An additional read from `self` may
172 /// be required to check if all of the stream has been read. If that
173 /// information is not needed, use [`DataStream::stream_precise_to()`].
174 ///
175 /// This method is identical to `tokio::io::copy(&mut self, &mut writer)`
176 /// except in that it returns an `N` to check for completeness.
177 ///
178 /// # Example
179 ///
180 /// ```rust
181 /// use std::io;
182 /// use rocket::data::{Data, ToByteUnit};
183 ///
184 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
185 /// // write all of the data to stdout
186 /// let written = data.open(512.kibibytes())
187 /// .stream_to(tokio::io::stdout()).await?;
188 ///
189 /// Ok(format!("Wrote {} bytes.", written))
190 /// }
191 /// ```
192 #[inline(always)]
193 pub async fn stream_to<W>(mut self, mut writer: W) -> io::Result<N>
194 where W: AsyncWrite + Unpin
195 {
196 let written = tokio::io::copy(&mut self, &mut writer).await?;
197 Ok(N { written, complete: !self.limit_exceeded().await? })
198 }
199
200 /// Like [`DataStream::stream_to()`] except that no end-of-stream check is
201 /// conducted and thus read/write completeness is unknown.
202 ///
203 /// # Example
204 ///
205 /// ```rust
206 /// use std::io;
207 /// use rocket::data::{Data, ToByteUnit};
208 ///
209 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
210 /// // write all of the data to stdout
211 /// let written = data.open(512.kibibytes())
212 /// .stream_precise_to(tokio::io::stdout()).await?;
213 ///
214 /// Ok(format!("Wrote {} bytes.", written))
215 /// }
216 /// ```
217 #[inline(always)]
218 pub async fn stream_precise_to<W>(mut self, mut writer: W) -> io::Result<u64>
219 where W: AsyncWrite + Unpin
220 {
221 tokio::io::copy(&mut self, &mut writer).await
222 }
223
224 /// A helper method to write the body of the request to a `Vec<u8>`.
225 ///
226 /// # Example
227 ///
228 /// ```rust
229 /// use std::io;
230 /// use rocket::data::{Data, ToByteUnit};
231 ///
232 /// async fn data_guard(data: Data<'_>) -> io::Result<Vec<u8>> {
233 /// let bytes = data.open(4.kibibytes()).into_bytes().await?;
234 /// if !bytes.is_complete() {
235 /// println!("there are bytes remaining in the stream");
236 /// }
237 ///
238 /// Ok(bytes.into_inner())
239 /// }
240 /// ```
241 pub async fn into_bytes(self) -> io::Result<Capped<Vec<u8>>> {
242 let mut vec = Vec::with_capacity(self.hint());
243 let n = self.stream_to(&mut vec).await?;
244 Ok(Capped { value: vec, n })
245 }
246
247 /// A helper method to write the body of the request to a `String`.
248 ///
249 /// # Example
250 ///
251 /// ```rust
252 /// use std::io;
253 /// use rocket::data::{Data, ToByteUnit};
254 ///
255 /// async fn data_guard(data: Data<'_>) -> io::Result<String> {
256 /// let string = data.open(10.bytes()).into_string().await?;
257 /// if !string.is_complete() {
258 /// println!("there are bytes remaining in the stream");
259 /// }
260 ///
261 /// Ok(string.into_inner())
262 /// }
263 /// ```
264 pub async fn into_string(mut self) -> io::Result<Capped<String>> {
265 let mut string = String::with_capacity(self.hint());
266 let written = self.read_to_string(&mut string).await?;
267 let n = N { written: written as u64, complete: !self.limit_exceeded().await? };
268 Ok(Capped { value: string, n })
269 }
270
271 /// A helper method to write the body of the request to a file at the path
272 /// determined by `path`. If a file at the path already exists, it is
273 /// overwritten. The opened file is returned.
274 ///
275 /// # Example
276 ///
277 /// ```rust
278 /// use std::io;
279 /// use rocket::data::{Data, ToByteUnit};
280 ///
281 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
282 /// let file = data.open(1.megabytes()).into_file("/static/file").await?;
283 /// if !file.is_complete() {
284 /// println!("there are bytes remaining in the stream");
285 /// }
286 ///
287 /// Ok(format!("Wrote {} bytes to /static/file", file.n))
288 /// }
289 /// ```
290 pub async fn into_file<P: AsRef<Path>>(self, path: P) -> io::Result<Capped<File>> {
291 let mut file = File::create(path).await?;
292 let n = self.stream_to(&mut tokio::io::BufWriter::new(&mut file)).await?;
293 Ok(Capped { value: file, n })
294 }
295}
296
297impl AsyncRead for DataStream<'_> {
298 fn poll_read(
299 self: Pin<&mut Self>,
300 cx: &mut Context<'_>,
301 buf: &mut ReadBuf<'_>,
302 ) -> Poll<io::Result<()>> {
303 match self.get_mut() {
304 DataStream::Base(inner) => Pin::new(inner).poll_read(cx, buf),
305 DataStream::Transform(inner) => Pin::new(inner).poll_read(cx, buf),
306 }
307 }
308}
309
310impl AsyncRead for TransformReader<'_> {
311 fn poll_read(
312 mut self: Pin<&mut Self>,
313 cx: &mut Context<'_>,
314 buf: &mut ReadBuf<'_>,
315 ) -> Poll<io::Result<()>> {
316 let init_fill = buf.filled().len();
317 if !self.inner_done {
318 ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
319 self.inner_done = init_fill == buf.filled().len();
320 }
321
322 if self.inner_done {
323 return self.transformer.as_mut().poll_finish(cx, buf);
324 }
325
326 let mut tbuf = TransformBuf { buf, cursor: init_fill };
327 self.transformer.as_mut().transform(&mut tbuf)?;
328 if buf.filled().len() == init_fill {
329 cx.waker().wake_by_ref();
330 return Poll::Pending;
331 }
332
333 Poll::Ready(Ok(()))
334 }
335}
336
337impl Stream for RawStream<'_> {
338 type Item = io::Result<Bytes>;
339
340 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341 match self.get_mut() {
342 // TODO: Expose trailer headers, somehow.
343 RawStream::Body(body) => {
344 Pin::new(body)
345 .poll_frame(cx)
346 .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
347 .map_err(io::Error::other)
348 },
349 #[cfg(feature = "http3-preview")]
350 RawStream::H3Body(stream) => Pin::new(stream).poll_next(cx),
351 RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
352 RawStream::Empty => Poll::Ready(None),
353 }
354 }
355
356 fn size_hint(&self) -> (usize, Option<usize>) {
357 match self {
358 RawStream::Body(body) => {
359 let hint = body.size_hint();
360 let (lower, upper) = (hint.lower(), hint.upper());
361 (lower as usize, upper.map(|x| x as usize))
362 },
363 #[cfg(feature = "http3-preview")]
364 RawStream::H3Body(_) => (0, Some(0)),
365 RawStream::Multipart(mp) => mp.size_hint(),
366 RawStream::Empty => (0, Some(0)),
367 }
368 }
369}
370
371impl std::fmt::Display for RawStream<'_> {
372 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373 match self {
374 RawStream::Empty => f.write_str("empty stream"),
375 RawStream::Body(_) => f.write_str("request body"),
376 #[cfg(feature = "http3-preview")]
377 RawStream::H3Body(_) => f.write_str("http3 quic stream"),
378 RawStream::Multipart(_) => f.write_str("multipart form field"),
379 }
380 }
381}
382
383impl<'r> From<HyperBody> for RawStream<'r> {
384 fn from(value: HyperBody) -> Self {
385 Self::Body(value)
386 }
387}
388
/// Wraps an HTTP/3 (QUIC) receive stream as a `RawStream::H3Body`.
#[cfg(feature = "http3-preview")]
impl From<crate::listener::Cancellable<crate::listener::quic::QuicRx>> for RawStream<'_> {
    fn from(rx: crate::listener::Cancellable<crate::listener::quic::QuicRx>) -> Self {
        RawStream::H3Body(rx)
    }
}
395
396impl<'r> From<multer::Field<'r>> for RawStream<'r> {
397 fn from(value: multer::Field<'r>) -> Self {
398 Self::Multipart(value)
399 }
400}