rocket/data/data_stream.rs
1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::path::Path;
4use std::io::{self, Cursor};
5
6use tokio::fs::File;
7use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
8use futures::stream::Stream;
9use futures::ready;
10use yansi::Paint;
11
12use crate::http::hyper;
13use crate::ext::{PollExt, Chain};
14use crate::data::{Capped, N};
15
16/// Raw data stream of a request body.
17///
18/// This stream can only be obtained by calling
19/// [`Data::open()`](crate::data::Data::open()) with a data limit. The stream
20/// contains all of the data in the body of the request.
21///
22/// Reading from a `DataStream` is accomplished via the various methods on the
23/// structure. In general, methods exists in two variants: those that _check_
24/// whether the entire stream was read and those that don't. The former either
25/// directly or indirectly (via [`Capped`]) return an [`N`] which allows
26/// checking if the stream was read to completion while the latter do not.
27///
28/// | Read Into | Method | Notes |
29/// |-----------|--------------------------------------|----------------------------------|
30/// | `String` | [`DataStream::into_string()`] | Completeness checked. Preferred. |
31/// | `String` | [`AsyncReadExt::read_to_string()`] | Unchecked w/existing `String`. |
32/// | `Vec<u8>` | [`DataStream::into_bytes()`] | Checked. Preferred. |
33/// | `Vec<u8>` | [`DataStream::stream_to(&mut vec)`] | Checked w/existing `Vec`. |
34/// | `Vec<u8>` | [`DataStream::stream_precise_to()`] | Unchecked w/existing `Vec`. |
35/// | `File` | [`DataStream::into_file()`] | Checked. Preferred. |
36/// | `File` | [`DataStream::stream_to(&mut file)`] | Checked w/ existing `File`. |
37/// | `File` | [`DataStream::stream_precise_to()`] | Unchecked w/ existing `File`. |
38/// | `T` | [`DataStream::stream_to()`] | Checked. Any `T: AsyncWrite`. |
39/// | `T` | [`DataStream::stream_precise_to()`] | Unchecked. Any `T: AsyncWrite`. |
40///
41/// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
42/// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
43pub struct DataStream<'r> {
44 pub(crate) chain: Take<Chain<Cursor<Vec<u8>>, StreamReader<'r>>>,
45}
46
47/// An adapter: turns a `T: Stream` (in `StreamKind`) into a `tokio::AsyncRead`.
48pub struct StreamReader<'r> {
49 state: State,
50 inner: StreamKind<'r>,
51}
52
53/// The current state of `StreamReader` `AsyncRead` adapter.
54enum State {
55 Pending,
56 Partial(Cursor<hyper::body::Bytes>),
57 Done,
58}
59
60/// The kinds of streams we accept as `Data`.
61enum StreamKind<'r> {
62 Empty,
63 Body(&'r mut hyper::Body),
64 Multipart(multer::Field<'r>)
65}
66
67impl<'r> DataStream<'r> {
68 pub(crate) fn new(buf: Vec<u8>, stream: StreamReader<'r>, limit: u64) -> Self {
69 let chain = Chain::new(Cursor::new(buf), stream).take(limit).into();
70 Self { chain }
71 }
72
73 /// Whether a previous read exhausted the set limit _and then some_.
74 async fn limit_exceeded(&mut self) -> io::Result<bool> {
75 #[cold]
76 async fn _limit_exceeded(stream: &mut DataStream<'_>) -> io::Result<bool> {
77 // Read one more byte after reaching limit to see if we cut early.
78 stream.chain.set_limit(1);
79 let mut buf = [0u8; 1];
80 Ok(stream.read(&mut buf).await? != 0)
81 }
82
83 Ok(self.chain.limit() == 0 && _limit_exceeded(self).await?)
84 }
85
86 /// Number of bytes a full read from `self` will _definitely_ read.
87 ///
88 /// # Example
89 ///
90 /// ```rust
91 /// use rocket::data::{Data, ToByteUnit};
92 ///
93 /// async fn f(data: Data<'_>) {
94 /// let definitely_have_n_bytes = data.open(1.kibibytes()).hint();
95 /// }
96 /// ```
97 pub fn hint(&self) -> usize {
98 let buf_len = self.chain.get_ref().get_ref().0.get_ref().len();
99 std::cmp::min(buf_len, self.chain.limit() as usize)
100 }
101
102 /// A helper method to write the body of the request to any `AsyncWrite`
103 /// type. Returns an [`N`] which indicates how many bytes were written and
104 /// whether the entire stream was read. An additional read from `self` may
105 /// be required to check if all of the stream has been read. If that
106 /// information is not needed, use [`DataStream::stream_precise_to()`].
107 ///
108 /// This method is identical to `tokio::io::copy(&mut self, &mut writer)`
109 /// except in that it returns an `N` to check for completeness.
110 ///
111 /// # Example
112 ///
113 /// ```rust
114 /// use std::io;
115 /// use rocket::data::{Data, ToByteUnit};
116 ///
117 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
118 /// // write all of the data to stdout
119 /// let written = data.open(512.kibibytes())
120 /// .stream_to(tokio::io::stdout()).await?;
121 ///
122 /// Ok(format!("Wrote {} bytes.", written))
123 /// }
124 /// ```
125 #[inline(always)]
126 pub async fn stream_to<W>(mut self, mut writer: W) -> io::Result<N>
127 where W: AsyncWrite + Unpin
128 {
129 let written = tokio::io::copy(&mut self, &mut writer).await?;
130 Ok(N { written, complete: !self.limit_exceeded().await? })
131 }
132
133 /// Like [`DataStream::stream_to()`] except that no end-of-stream check is
134 /// conducted and thus read/write completeness is unknown.
135 ///
136 /// # Example
137 ///
138 /// ```rust
139 /// use std::io;
140 /// use rocket::data::{Data, ToByteUnit};
141 ///
142 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
143 /// // write all of the data to stdout
144 /// let written = data.open(512.kibibytes())
145 /// .stream_precise_to(tokio::io::stdout()).await?;
146 ///
147 /// Ok(format!("Wrote {} bytes.", written))
148 /// }
149 /// ```
150 #[inline(always)]
151 pub async fn stream_precise_to<W>(mut self, mut writer: W) -> io::Result<u64>
152 where W: AsyncWrite + Unpin
153 {
154 tokio::io::copy(&mut self, &mut writer).await
155 }
156
157 /// A helper method to write the body of the request to a `Vec<u8>`.
158 ///
159 /// # Example
160 ///
161 /// ```rust
162 /// use std::io;
163 /// use rocket::data::{Data, ToByteUnit};
164 ///
165 /// async fn data_guard(data: Data<'_>) -> io::Result<Vec<u8>> {
166 /// let bytes = data.open(4.kibibytes()).into_bytes().await?;
167 /// if !bytes.is_complete() {
168 /// println!("there are bytes remaining in the stream");
169 /// }
170 ///
171 /// Ok(bytes.into_inner())
172 /// }
173 /// ```
174 pub async fn into_bytes(self) -> io::Result<Capped<Vec<u8>>> {
175 let mut vec = Vec::with_capacity(self.hint());
176 let n = self.stream_to(&mut vec).await?;
177 Ok(Capped { value: vec, n })
178 }
179
180 /// A helper method to write the body of the request to a `String`.
181 ///
182 /// # Example
183 ///
184 /// ```rust
185 /// use std::io;
186 /// use rocket::data::{Data, ToByteUnit};
187 ///
188 /// async fn data_guard(data: Data<'_>) -> io::Result<String> {
189 /// let string = data.open(10.bytes()).into_string().await?;
190 /// if !string.is_complete() {
191 /// println!("there are bytes remaining in the stream");
192 /// }
193 ///
194 /// Ok(string.into_inner())
195 /// }
196 /// ```
197 pub async fn into_string(mut self) -> io::Result<Capped<String>> {
198 let mut string = String::with_capacity(self.hint());
199 let written = self.read_to_string(&mut string).await?;
200 let n = N { written: written as u64, complete: !self.limit_exceeded().await? };
201 Ok(Capped { value: string, n })
202 }
203
204 /// A helper method to write the body of the request to a file at the path
205 /// determined by `path`. If a file at the path already exists, it is
206 /// overwritten. The opened file is returned.
207 ///
208 /// # Example
209 ///
210 /// ```rust
211 /// use std::io;
212 /// use rocket::data::{Data, ToByteUnit};
213 ///
214 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
215 /// let file = data.open(1.megabytes()).into_file("/static/file").await?;
216 /// if !file.is_complete() {
217 /// println!("there are bytes remaining in the stream");
218 /// }
219 ///
220 /// Ok(format!("Wrote {} bytes to /static/file", file.n))
221 /// }
222 /// ```
223 pub async fn into_file<P: AsRef<Path>>(self, path: P) -> io::Result<Capped<File>> {
224 let mut file = File::create(path).await?;
225 let n = self.stream_to(&mut tokio::io::BufWriter::new(&mut file)).await?;
226 Ok(Capped { value: file, n })
227 }
228}
229
230// TODO.async: Consider implementing `AsyncBufRead`.
231
232impl StreamReader<'_> {
233 pub fn empty() -> Self {
234 Self { inner: StreamKind::Empty, state: State::Done }
235 }
236}
237
238impl<'r> From<&'r mut hyper::Body> for StreamReader<'r> {
239 fn from(body: &'r mut hyper::Body) -> Self {
240 Self { inner: StreamKind::Body(body), state: State::Pending }
241 }
242}
243
244impl<'r> From<multer::Field<'r>> for StreamReader<'r> {
245 fn from(field: multer::Field<'r>) -> Self {
246 Self { inner: StreamKind::Multipart(field), state: State::Pending }
247 }
248}
249
250impl AsyncRead for DataStream<'_> {
251 #[inline(always)]
252 fn poll_read(
253 mut self: Pin<&mut Self>,
254 cx: &mut Context<'_>,
255 buf: &mut ReadBuf<'_>,
256 ) -> Poll<io::Result<()>> {
257 if self.chain.limit() == 0 {
258 let stream: &StreamReader<'_> = &self.chain.get_ref().get_ref().1;
259 let kind = match stream.inner {
260 StreamKind::Empty => "an empty stream (vacuous)",
261 StreamKind::Body(_) => "the request body",
262 StreamKind::Multipart(_) => "a multipart form field",
263 };
264
265 warn_!("Data limit reached while reading {}.", kind.primary().bold());
266 }
267
268 Pin::new(&mut self.chain).poll_read(cx, buf)
269 }
270}
271
272impl Stream for StreamKind<'_> {
273 type Item = io::Result<hyper::body::Bytes>;
274
275 fn poll_next(
276 self: Pin<&mut Self>,
277 cx: &mut Context<'_>,
278 ) -> Poll<Option<Self::Item>> {
279 match self.get_mut() {
280 StreamKind::Body(body) => Pin::new(body).poll_next(cx)
281 .map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
282 StreamKind::Multipart(mp) => Pin::new(mp).poll_next(cx)
283 .map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
284 StreamKind::Empty => Poll::Ready(None),
285 }
286 }
287
288 fn size_hint(&self) -> (usize, Option<usize>) {
289 match self {
290 StreamKind::Body(body) => body.size_hint(),
291 StreamKind::Multipart(mp) => mp.size_hint(),
292 StreamKind::Empty => (0, Some(0)),
293 }
294 }
295}
296
297impl AsyncRead for StreamReader<'_> {
298 fn poll_read(
299 mut self: Pin<&mut Self>,
300 cx: &mut Context<'_>,
301 buf: &mut ReadBuf<'_>,
302 ) -> Poll<io::Result<()>> {
303 loop {
304 self.state = match self.state {
305 State::Pending => {
306 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
307 Some(Err(e)) => return Poll::Ready(Err(e)),
308 Some(Ok(bytes)) => State::Partial(Cursor::new(bytes)),
309 None => State::Done,
310 }
311 },
312 State::Partial(ref mut cursor) => {
313 let rem = buf.remaining();
314 match ready!(Pin::new(cursor).poll_read(cx, buf)) {
315 Ok(()) if rem == buf.remaining() => State::Pending,
316 result => return Poll::Ready(result),
317 }
318 }
319 State::Done => return Poll::Ready(Ok(())),
320 }
321 }
322 }
323}