rocket/util/chain.rs

1use std::io;
2use std::task::{Poll, Context};
3use std::pin::Pin;
4
5use pin_project_lite::pin_project;
6use tokio::io::{AsyncRead, ReadBuf};
7
8pin_project! {
9    /// Stream for the [`chain`](super::AsyncReadExt::chain) method.
10    #[must_use = "streams do nothing unless polled"]
11    pub struct Chain<T, U> {
12        #[pin]
13        first: Option<T>,
14        #[pin]
15        second: U,
16    }
17}
18
19impl<T, U> Chain<T, U> {
20    pub(crate) fn new(first: T, second: U) -> Self {
21        Self { first: Some(first), second }
22    }
23}
24
25impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
26    /// Gets references to the underlying readers in this `Chain`.
27    pub fn get_ref(&self) -> (Option<&T>, &U) {
28        (self.first.as_ref(), &self.second)
29    }
30}
31
32impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
33    fn poll_read(
34        mut self: Pin<&mut Self>,
35        cx: &mut Context<'_>,
36        buf: &mut ReadBuf<'_>,
37    ) -> Poll<io::Result<()>> {
38        let me = self.as_mut().project();
39        if let Some(first) = me.first.as_pin_mut() {
40            let init_rem = buf.remaining();
41            futures::ready!(first.poll_read(cx, buf))?;
42            if buf.remaining() == init_rem {
43                self.as_mut().project().first.set(None);
44            } else {
45                return Poll::Ready(Ok(()));
46            }
47        }
48
49        let me = self.as_mut().project();
50        me.second.poll_read(cx, buf)
51    }
52}