rocket/data/transform.rs
1use std::io;
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Poll, Context};
5
6use tokio::io::ReadBuf;
7
8/// Chainable, in-place, streaming data transformer.
9///
10/// [`Transform`] operates on [`TransformBuf`]s similar to how [`AsyncRead`]
11/// operats on [`ReadBuf`]. A [`Transform`] sits somewhere in a chain of
12/// transforming readers. The head (most upstream part) of the chain is _always_
13/// an [`AsyncRead`]: the data source. The tail (all downstream parts) is
14/// composed _only_ of other [`Transform`]s:
15///
16/// ```text
17/// downstream --->
18/// AsyncRead | Transform | .. | Transform
19/// <---- upstream
20/// ```
21///
22/// When the upstream source makes data available, the
23/// [`Transform::transform()`] method is called. [`Transform`]s may obtain the
24/// subset of the filled section added by an upstream data source with
25/// [`TransformBuf::fresh()`]. They may modify this data at will, potentially
26/// changing the size of the filled section. For example,
27/// [`TransformBuf::spoil()`] "removes" all of the fresh data, and
28/// [`TransformBuf::fresh_mut()`] can be used to modify the data in-place.
29///
30/// Additionally, new data may be added in-place via the traditional approach:
31/// write to (or overwrite) the initialized section of the buffer and mark it as
32/// filled. All of the remaining filled data will be passed to downstream
33/// transforms as "fresh" data. To add data to the end of the (potentially
34/// rewritten) stream, the [`Transform::poll_finish()`] method can be
35/// implemented.
36///
37/// [`AsyncRead`]: tokio::io::AsyncRead
38pub trait Transform {
39 /// Called when data is read from the upstream source. For any given fresh
40 /// data, this method is called only once. [`TransformBuf::fresh()`] is
41 /// guaranteed to contain at least one byte.
42 ///
43 /// While this method is not _async_ (it does not return [`Poll`]), it is
44 /// nevertheless executed in an async context and should respect all such
45 /// restrictions including not blocking.
46 fn transform(
47 self: Pin<&mut Self>,
48 buf: &mut TransformBuf<'_, '_>,
49 ) -> io::Result<()>;
50
51 /// Called when the upstream is finished, that is, it has no more data to
52 /// fill. At this point, the transform becomes an async reader. This method
53 /// thus has identical semantics to [`AsyncRead::poll_read()`]. This method
54 /// may never be called if the upstream does not finish.
55 ///
56 /// The default implementation returns `Poll::Ready(Ok(()))`.
57 ///
58 /// [`AsyncRead::poll_read()`]: tokio::io::AsyncRead::poll_read()
59 fn poll_finish(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 buf: &mut ReadBuf<'_>,
63 ) -> Poll<io::Result<()>> {
64 let (_, _) = (cx, buf);
65 Poll::Ready(Ok(()))
66 }
67}
68
69/// A buffer of transformable streaming data.
70///
71/// # Overview
72///
73/// A byte buffer, similar to a [`ReadBuf`], with a "fresh" dimension. Fresh
74/// data is always a subset of the filled data, filled data is always a subset
75/// of initialized data, and initialized data is always a subset of the buffer
76/// itself. Both the filled and initialized data sections are guaranteed to be
77/// at the start of the buffer, but the fresh subset is likely to begin
78/// somewhere inside the filled section.
79///
80/// To visualize this, the diagram below represents a possible state for the
81/// byte buffer being tracked. The square `[ ]` brackets represent the complete
82/// buffer, while the curly `{ }` represent the named subset.
83///
84/// ```text
85/// [ { !! fresh !! } ]
86/// { +++ filled +++ } unfilled ]
87/// { ----- initialized ------ } uninitialized ]
88/// [ capacity ]
89/// ```
90///
91/// The same buffer represented in its true single dimension is below:
92///
93/// ```text
94/// [ ++!!!!!!!!!!!!!!---------xxxxxxxxxxxxxxxxxxxxxxxx]
95/// ```
96///
97/// * `+`: filled (implies initialized)
98/// * `!`: fresh (implies filled)
99/// * `-`: unfilled / initialized (implies initialized)
100/// * `x`: uninitialized (implies unfilled)
101///
102/// As with [`ReadBuf`], [`AsyncRead`] readers fill the initialized portion of a
103/// [`TransformBuf`] to indicate that data is available. _Filling_ initialized
104/// portions of the byte buffers is what increases the size of the _filled_
105/// section. Because a [`ReadBuf`] may already be partially filled when a reader
106/// adds bytes to it, a mechanism to track where the _newly_ filled portion
107/// exists is needed. This is exactly what the "fresh" section tracks.
108///
109/// [`AsyncRead`]: tokio::io::AsyncRead
110pub struct TransformBuf<'a, 'b> {
111 pub(crate) buf: &'a mut ReadBuf<'b>,
112 pub(crate) cursor: usize,
113}
114
115impl TransformBuf<'_, '_> {
116 /// Returns a borrow to the fresh data: data filled by the upstream source.
117 pub fn fresh(&self) -> &[u8] {
118 &self.filled()[self.cursor..]
119 }
120
121 /// Returns a mutable borrow to the fresh data: data filled by the upstream
122 /// source.
123 pub fn fresh_mut(&mut self) -> &mut [u8] {
124 let cursor = self.cursor;
125 &mut self.filled_mut()[cursor..]
126 }
127
128 /// Spoils the fresh data by resetting the filled section to its value
129 /// before any new data was added. As a result, the data will never be seen
130 /// by any downstream consumer unless it is returned via another mechanism.
131 pub fn spoil(&mut self) {
132 let cursor = self.cursor;
133 self.set_filled(cursor);
134 }
135}
136
137pub struct Inspect(pub(crate) Box<dyn FnMut(&[u8]) + Send + Sync + 'static>);
138
139impl Transform for Inspect {
140 fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
141 (self.0)(buf.fresh());
142 Ok(())
143 }
144}
145
146pub struct InPlaceMap(
147 pub(crate) Box<dyn FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static>
148);
149
150impl Transform for InPlaceMap {
151 fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>,) -> io::Result<()> {
152 (self.0)(buf)
153 }
154}
155
156impl<'a, 'b> Deref for TransformBuf<'a, 'b> {
157 type Target = ReadBuf<'b>;
158
159 fn deref(&self) -> &Self::Target {
160 self.buf
161 }
162}
163
164impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
165 fn deref_mut(&mut self) -> &mut Self::Target {
166 self.buf
167 }
168}
169
170// TODO: Test chaining various transform combinations:
171// * consume | consume
172// * add | consume
173// * consume | add
174// * add | add
175// Where `add` is a transformer that adds data to the stream, and `consume` is
176// one that removes data.
#[cfg(test)]
// `std::hash::SipHasher`, which the test below feeds into the transform chain,
// is deprecated.
#[allow(deprecated)]
mod tests {
    use std::hash::SipHasher;
    use std::sync::{Arc, atomic::{AtomicU8, AtomicU64, Ordering}};

    use parking_lot::Mutex;
    use ubyte::ToByteUnit;

    use crate::http::Method;
    use crate::local::blocking::Client;
    use crate::fairing::AdHoc;
    use crate::{route, Route, Data, Response, Request};

    mod hash_transform {
        use std::io::Cursor;
        use std::hash::Hasher;

        use tokio::io::AsyncRead;

        use super::super::*;

        /// A [`Transform`] that swallows the stream into a hasher and, once
        /// the upstream finishes, emits the big-endian bytes of the hash.
        pub struct HashTransform<H: Hasher> {
            /// Receives every byte of fresh data passed to `transform()`.
            pub(crate) hasher: H,
            /// The finished hash, created on the first `poll_finish()` call.
            pub(crate) hash: Option<Cursor<[u8; 8]>>
        }

        impl<H: Hasher + Unpin> Transform for HashTransform<H> {
            /// Feeds the fresh data to the hasher, then spoils it so that
            /// nothing is passed downstream.
            fn transform(
                self: Pin<&mut Self>,
                buf: &mut TransformBuf<'_, '_>,
            ) -> io::Result<()> {
                let this = self.get_mut();
                this.hasher.write(buf.fresh());
                buf.spoil();
                Ok(())
            }

            /// Reads the hash bytes into `buf`, computing the hash first if
            /// this is the first call.
            fn poll_finish(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                let Self { hasher, hash } = self.get_mut();
                let digest = hash.get_or_insert_with(|| {
                    Cursor::new(hasher.finish().to_be_bytes())
                });

                Pin::new(digest).poll_read(cx, buf)
            }
        }

        impl crate::Data<'_> {
            /// Chain an in-place hash [`Transform`] to `self`.
            pub fn chain_hash_transform<H>(&mut self, hasher: H) -> &mut Self
                where H: Hasher + Unpin + Send + Sync + 'static
            {
                let transform = HashTransform { hasher, hash: None };
                self.chain_transform(transform)
            }
        }
    }

    #[test]
    fn test_transform_series() {
        /// Drains up to 128 bytes of the request body so the transforms run.
        fn handler<'r>(_: &'r Request<'_>, data: Data<'r>) -> route::BoxFuture<'r> {
            Box::pin(async move {
                let stream = data.open(128.bytes());
                stream.stream_to(tokio::io::sink()).await.expect("read ok");
                route::Outcome::Success(Response::new())
            })
        }

        /// Decodes the 8 big-endian bytes that the hash transform emits.
        fn decode_hash(bytes: &[u8]) -> u64 {
            assert_eq!(bytes.len(), 8);
            let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
            u64::from_be_bytes(bytes)
        }

        // Shared with the fairing through managed state and checked below.
        let inspect_count = Arc::new(AtomicU8::new(0));
        let body_seen = Arc::new(Mutex::new(Vec::<u8>::new()));
        let digest = Arc::new(AtomicU64::new(0));

        let rocket = crate::build()
            .manage(Arc::clone(&digest))
            .manage(Arc::clone(&body_seen))
            .manage(Arc::clone(&inspect_count))
            .mount("/", vec![Route::new(Method::Post, "/", handler)])
            .attach(AdHoc::on_request("transforms", |req, data| Box::pin(async {
                let hash_out = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
                let hash_in = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
                let seen = req.rocket().state::<Arc<Mutex<Vec<u8>>>>().cloned().unwrap();
                let count = req.rocket().state::<Arc<AtomicU8>>().cloned().unwrap();

                // Chain: inspect (raw body) | hash | inspect (store) | inspect (verify).
                data.chain_inspect(move |bytes| { *seen.lock() = bytes.to_vec(); })
                    .chain_hash_transform(SipHasher::new())
                    .chain_inspect(move |bytes| {
                        hash_out.store(decode_hash(bytes), Ordering::Release);
                    })
                    .chain_inspect(move |bytes| {
                        let value = decode_hash(bytes);
                        let prev = hash_in.load(Ordering::Acquire);
                        assert_eq!(prev, value);
                        count.fetch_add(1, Ordering::Release);
                    });
            })));

        let assert_untouched = || {
            assert!(body_seen.lock().is_empty());
            assert_eq!(digest.load(Ordering::Acquire), 0);
            assert_eq!(inspect_count.load(Ordering::Acquire), 0);
        };

        // Make sure nothing has happened yet.
        assert_untouched();

        // Check that nothing happens if the data isn't read.
        let client = Client::debug(rocket).unwrap();
        client.get("/").body("Hello, world!").dispatch();
        assert_untouched();

        // Check inspect + hash + inspect + inspect, over two rounds.
        let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
        let rounds: [(&str, u64, u8); 2] = [
            ("Hello, world!", 0xae5020d7cf49d14f, 1),
            (string, 0x323f9aa98f907faf, 2),
        ];

        for &(body, expected_hash, expected_count) in rounds.iter() {
            client.post("/").body(body).dispatch();
            assert_eq!(body_seen.lock().as_slice(), body.as_bytes());
            assert_eq!(digest.load(Ordering::Acquire), expected_hash);
            assert_eq!(inspect_count.load(Ordering::Acquire), expected_count);
        }
    }
}