rocket/data/transform.rs

1use std::io;
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Poll, Context};
5
6use tokio::io::ReadBuf;
7
8/// Chainable, in-place, streaming data transformer.
9///
10/// [`Transform`] operates on [`TransformBuf`]s similar to how [`AsyncRead`]
11/// operats on [`ReadBuf`]. A [`Transform`] sits somewhere in a chain of
12/// transforming readers. The head (most upstream part) of the chain is _always_
13/// an [`AsyncRead`]: the data source. The tail (all downstream parts) is
14/// composed _only_ of other [`Transform`]s:
15///
16/// ```text
17///                          downstream --->
18///  AsyncRead | Transform | .. | Transform
19/// <---- upstream
20/// ```
21///
22/// When the upstream source makes data available, the
23/// [`Transform::transform()`] method is called. [`Transform`]s may obtain the
24/// subset of the filled section added by an upstream data source with
25/// [`TransformBuf::fresh()`]. They may modify this data at will, potentially
26/// changing the size of the filled section. For example,
27/// [`TransformBuf::spoil()`] "removes" all of the fresh data, and
28/// [`TransformBuf::fresh_mut()`] can be used to modify the data in-place.
29///
30/// Additionally, new data may be added in-place via the traditional approach:
31/// write to (or overwrite) the initialized section of the buffer and mark it as
32/// filled. All of the remaining filled data will be passed to downstream
33/// transforms as "fresh" data. To add data to the end of the (potentially
34/// rewritten) stream, the [`Transform::poll_finish()`] method can be
35/// implemented.
36///
37/// [`AsyncRead`]: tokio::io::AsyncRead
38pub trait Transform {
39    /// Called when data is read from the upstream source. For any given fresh
40    /// data, this method is called only once. [`TransformBuf::fresh()`] is
41    /// guaranteed to contain at least one byte.
42    ///
43    /// While this method is not _async_ (it does not return [`Poll`]), it is
44    /// nevertheless executed in an async context and should respect all such
45    /// restrictions including not blocking.
46    fn transform(
47        self: Pin<&mut Self>,
48        buf: &mut TransformBuf<'_, '_>,
49    ) -> io::Result<()>;
50
51    /// Called when the upstream is finished, that is, it has no more data to
52    /// fill. At this point, the transform becomes an async reader. This method
53    /// thus has identical semantics to [`AsyncRead::poll_read()`]. This method
54    /// may never be called if the upstream does not finish.
55    ///
56    /// The default implementation returns `Poll::Ready(Ok(()))`.
57    ///
58    /// [`AsyncRead::poll_read()`]: tokio::io::AsyncRead::poll_read()
59    fn poll_finish(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62        buf: &mut ReadBuf<'_>,
63    ) -> Poll<io::Result<()>> {
64        let (_, _) = (cx, buf);
65        Poll::Ready(Ok(()))
66    }
67}
68
69/// A buffer of transformable streaming data.
70///
71/// # Overview
72///
73/// A byte buffer, similar to a [`ReadBuf`], with a "fresh" dimension. Fresh
74/// data is always a subset of the filled data, filled data is always a subset
75/// of initialized data, and initialized data is always a subset of the buffer
76/// itself. Both the filled and initialized data sections are guaranteed to be
77/// at the start of the buffer, but the fresh subset is likely to begin
78/// somewhere inside the filled section.
79///
80/// To visualize this, the diagram below represents a possible state for the
81/// byte buffer being tracked. The square `[ ]` brackets represent the complete
82/// buffer, while the curly `{ }` represent the named subset.
83///
84/// ```text
85/// [  { !! fresh !! }                                 ]
86/// { +++ filled +++ }          unfilled               ]
87/// { ----- initialized ------ }     uninitialized     ]
88/// [                    capacity                      ]
89/// ```
90///
91/// The same buffer represented in its true single dimension is below:
92///
93/// ```text
94/// [ ++!!!!!!!!!!!!!!---------xxxxxxxxxxxxxxxxxxxxxxxx]
95/// ```
96///
97/// * `+`: filled (implies initialized)
98/// * `!`: fresh (implies filled)
99/// * `-`: unfilled / initialized (implies initialized)
100/// * `x`: uninitialized (implies unfilled)
101///
102/// As with [`ReadBuf`], [`AsyncRead`] readers fill the initialized portion of a
103/// [`TransformBuf`] to indicate that data is available. _Filling_ initialized
104/// portions of the byte buffers is what increases the size of the _filled_
105/// section. Because a [`ReadBuf`] may already be partially filled when a reader
106/// adds bytes to it, a mechanism to track where the _newly_ filled portion
107/// exists is needed. This is exactly what the "fresh" section tracks.
108///
109/// [`AsyncRead`]: tokio::io::AsyncRead
110pub struct TransformBuf<'a, 'b> {
111    pub(crate) buf: &'a mut ReadBuf<'b>,
112    pub(crate) cursor: usize,
113}
114
115impl TransformBuf<'_, '_> {
116    /// Returns a borrow to the fresh data: data filled by the upstream source.
117    pub fn fresh(&self) -> &[u8] {
118        &self.filled()[self.cursor..]
119    }
120
121    /// Returns a mutable borrow to the fresh data: data filled by the upstream
122    /// source.
123    pub fn fresh_mut(&mut self) -> &mut [u8] {
124        let cursor = self.cursor;
125        &mut self.filled_mut()[cursor..]
126    }
127
128    /// Spoils the fresh data by resetting the filled section to its value
129    /// before any new data was added. As a result, the data will never be seen
130    /// by any downstream consumer unless it is returned via another mechanism.
131    pub fn spoil(&mut self) {
132        let cursor = self.cursor;
133        self.set_filled(cursor);
134    }
135}
136
137pub struct Inspect(pub(crate) Box<dyn FnMut(&[u8]) + Send + Sync + 'static>);
138
139impl Transform for Inspect {
140    fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
141        (self.0)(buf.fresh());
142        Ok(())
143    }
144}
145
146pub struct InPlaceMap(
147    pub(crate) Box<dyn FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static>
148);
149
150impl Transform for InPlaceMap {
151    fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>,) -> io::Result<()> {
152        (self.0)(buf)
153    }
154}
155
156impl<'a, 'b> Deref for TransformBuf<'a, 'b> {
157    type Target = ReadBuf<'b>;
158
159    fn deref(&self) -> &Self::Target {
160        self.buf
161    }
162}
163
164impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
165    fn deref_mut(&mut self) -> &mut Self::Target {
166        self.buf
167    }
168}
169
170// TODO: Test chaining various transform combinations:
171//  * consume | consume
172//  * add | consume
173//  * consume | add
174//  * add | add
175// Where `add` is a transformer that adds data to the stream, and `consume` is
176// one that removes data.
#[cfg(test)]
// `SipHasher` is deprecated, but the test below asserts exact hash values and
// so needs a hasher whose algorithm is fixed. `DefaultHasher` makes no such
// stability promise across Rust releases.
#[allow(deprecated)]
mod tests {
    use std::hash::SipHasher;
    use std::sync::{Arc, atomic::{AtomicU8, AtomicU64, Ordering}};

    use parking_lot::Mutex;
    use ubyte::ToByteUnit;

    use crate::http::Method;
    use crate::local::blocking::Client;
    use crate::fairing::AdHoc;
    use crate::{route, Route, Data, Response, Request};

    mod hash_transform {
        use std::io::Cursor;
        use std::hash::Hasher;

        use tokio::io::AsyncRead;

        use super::super::*;

        /// Test transform: feeds every fresh byte into `hasher` and spoils it,
        /// so none of the original data flows downstream. Once upstream
        /// finishes, it emits the 8-byte big-endian hash instead.
        pub struct HashTransform<H: Hasher> {
            pub(crate) hasher: H,
            pub(crate) hash: Option<Cursor<[u8; 8]>>
        }

        impl<H: Hasher + Unpin> Transform for HashTransform<H> {
            fn transform(
                mut self: Pin<&mut Self>,
                buf: &mut TransformBuf<'_, '_>,
            ) -> io::Result<()> {
                self.hasher.write(buf.fresh());
                buf.spoil();
                Ok(())
            }

            fn poll_finish(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                // Finalize the hash on the first call only. Later calls keep
                // draining the same cursor until all 8 bytes are emitted and
                // it then reports EOF. Destructuring allows disjoint borrows
                // of `hasher` and `hash`.
                let Self { hasher, hash } = self.get_mut();
                let cursor = hash
                    .get_or_insert_with(|| Cursor::new(hasher.finish().to_be_bytes()));

                Pin::new(cursor).poll_read(cx, buf)
            }
        }

        impl crate::Data<'_> {
            /// Chain an in-place hash [`Transform`] to `self`.
            pub fn chain_hash_transform<H: std::hash::Hasher>(&mut self, hasher: H) -> &mut Self
                where H: Unpin + Send + Sync + 'static
            {
                self.chain_transform(HashTransform { hasher, hash: None })
            }
        }
    }

    #[test]
    fn test_transform_series() {
        fn handler<'r>(_: &'r Request<'_>, data: Data<'r>) -> route::BoxFuture<'r> {
            Box::pin(async move {
                data.open(128.bytes()).stream_to(tokio::io::sink()).await.expect("read ok");
                route::Outcome::Success(Response::new())
            })
        }

        let inspect2: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
        let raw_data: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
        let hash: Arc<AtomicU64> = Arc::new(AtomicU64::new(0));
        let rocket = crate::build()
            .manage(hash.clone())
            .manage(raw_data.clone())
            .manage(inspect2.clone())
            .mount("/", vec![Route::new(Method::Post, "/", handler)])
            .attach(AdHoc::on_request("transforms", |req, data| Box::pin(async {
                let hash1 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
                let hash2 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
                let raw_data = req.rocket().state::<Arc<Mutex<Vec<u8>>>>().cloned().unwrap();
                let inspect2 = req.rocket().state::<Arc<AtomicU8>>().cloned().unwrap();
                // `transform` runs once per chunk of fresh data. Accumulate the
                // chunks so a body delivered in several pieces is captured in
                // full, rather than only its last chunk.
                data.chain_inspect(move |bytes| { raw_data.lock().extend_from_slice(bytes); })
                    .chain_hash_transform(SipHasher::new())
                    .chain_inspect(move |bytes| {
                        assert_eq!(bytes.len(), 8);
                        let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
                        let value = u64::from_be_bytes(bytes);
                        hash1.store(value, Ordering::Release);
                    })
                    .chain_inspect(move |bytes| {
                        assert_eq!(bytes.len(), 8);
                        let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
                        let value = u64::from_be_bytes(bytes);
                        let prev = hash2.load(Ordering::Acquire);
                        assert_eq!(prev, value);
                        inspect2.fetch_add(1, Ordering::Release);
                    });
            })));

        // Make sure nothing has happened yet.
        assert!(raw_data.lock().is_empty());
        assert_eq!(hash.load(Ordering::Acquire), 0);
        assert_eq!(inspect2.load(Ordering::Acquire), 0);

        // Check that nothing happens if the data isn't read: the only route is
        // POST, so a GET never reaches the handler that reads the body.
        let client = Client::debug(rocket).unwrap();
        client.get("/").body("Hello, world!").dispatch();
        assert!(raw_data.lock().is_empty());
        assert_eq!(hash.load(Ordering::Acquire), 0);
        assert_eq!(inspect2.load(Ordering::Acquire), 0);

        // Check inspect + hash + inspect + inspect.
        client.post("/").body("Hello, world!").dispatch();
        assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
        assert_eq!(hash.load(Ordering::Acquire), 0xae5020d7cf49d14f);
        assert_eq!(inspect2.load(Ordering::Acquire), 1);

        // Check inspect + hash + inspect + inspect, round 2. The raw-data
        // inspector accumulates, so reset it before the next request.
        raw_data.lock().clear();
        let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
        client.post("/").body(string).dispatch();
        assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
        assert_eq!(hash.load(Ordering::Acquire), 0x323f9aa98f907faf);
        assert_eq!(inspect2.load(Ordering::Acquire), 2);
    }
}