rocket/
ext.rs

1use std::{io, time::Duration};
2use std::task::{Poll, Context};
3use std::pin::Pin;
4
5use bytes::{Bytes, BytesMut};
6use pin_project_lite::pin_project;
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8use tokio::time::{sleep, Sleep};
9
10use futures::stream::Stream;
11use futures::future::{self, Future, FutureExt};
12
pin_project! {
    /// A [`Stream`] of [`Bytes`] chunks read from an [`AsyncRead`] `R`.
    ///
    /// Created by [`AsyncReadExt::into_bytes_stream`].
    pub struct ReaderStream<R> {
        // The source reader; set to `None` after EOF or an I/O error so that
        // every later poll yields `None`.
        #[pin]
        reader: Option<R>,
        // Buffer reads land in; the filled bytes are split off as each chunk.
        buf: BytesMut,
        // Capacity reserved for `buf` whenever it has no capacity left.
        cap: usize,
    }
}
21
22impl<R: AsyncRead> Stream for ReaderStream<R> {
23    type Item = std::io::Result<Bytes>;
24
25    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
26        use tokio_util::io::poll_read_buf;
27
28        let mut this = self.as_mut().project();
29
30        let reader = match this.reader.as_pin_mut() {
31            Some(r) => r,
32            None => return Poll::Ready(None),
33        };
34
35        if this.buf.capacity() == 0 {
36            this.buf.reserve(*this.cap);
37        }
38
39        match poll_read_buf(reader, cx, &mut this.buf) {
40            Poll::Pending => Poll::Pending,
41            Poll::Ready(Err(err)) => {
42                self.project().reader.set(None);
43                Poll::Ready(Some(Err(err)))
44            }
45            Poll::Ready(Ok(0)) => {
46                self.project().reader.set(None);
47                Poll::Ready(None)
48            }
49            Poll::Ready(Ok(_)) => {
50                let chunk = this.buf.split();
51                Poll::Ready(Some(Ok(chunk.freeze())))
52            }
53        }
54    }
55}
56
57pub trait AsyncReadExt: AsyncRead + Sized {
58    fn into_bytes_stream(self, cap: usize) -> ReaderStream<Self> {
59        ReaderStream { reader: Some(self), cap, buf: BytesMut::with_capacity(cap) }
60    }
61}
62
63impl<T: AsyncRead> AsyncReadExt for T { }
64
/// Extension methods for `Poll<Option<Result<T, E>>>` values.
pub trait PollExt<T, E> {
    /// Changes the error value of this `Poll` with the closure provided.
    ///
    /// `Ready(Some(Err(e)))` becomes `Ready(Some(Err(f(e))))`. `Ready(Some(Ok(_)))`,
    /// `Ready(None)` and `Pending` pass through unchanged, and `f` is not called.
    fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
        where F: FnOnce(E) -> U;
}

impl<T, E> PollExt<T, E> for Poll<Option<Result<T, E>>> {
    /// Maps the error through `f`, leaving every other state untouched.
    fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
        where F: FnOnce(E) -> U
    {
        self.map(|item| item.map(|res| res.map_err(f)))
    }
}
83
84pin_project! {
85    /// Stream for the [`chain`](super::AsyncReadExt::chain) method.
86    #[must_use = "streams do nothing unless polled"]
87    pub struct Chain<T, U> {
88        #[pin]
89        first: T,
90        #[pin]
91        second: U,
92        done_first: bool,
93    }
94}
95
96impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
97    pub(crate) fn new(first: T, second: U) -> Self {
98        Self { first, second, done_first: false }
99    }
100}
101
102impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
103    /// Gets references to the underlying readers in this `Chain`.
104    pub fn get_ref(&self) -> (&T, &U) {
105        (&self.first, &self.second)
106    }
107}
108
109impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
110    fn poll_read(
111        self: Pin<&mut Self>,
112        cx: &mut Context<'_>,
113        buf: &mut ReadBuf<'_>,
114    ) -> Poll<io::Result<()>> {
115        let me = self.project();
116
117        if !*me.done_first {
118            let init_rem = buf.remaining();
119            futures::ready!(me.first.poll_read(cx, buf))?;
120            if buf.remaining() == init_rem {
121                *me.done_first = true;
122            } else {
123                return Poll::Ready(Ok(()));
124            }
125        }
126        me.second.poll_read(cx, buf)
127    }
128}
129
/// Phase of a `CancellableIo`'s cancellation state machine.
enum State {
    /// I/O has not been cancelled. Proceed as normal.
    Active,
    /// I/O has been cancelled. See if we can finish before the timer expires.
    ///
    /// The timer is created from the `grace` duration; I/O keeps running
    /// until it elapses.
    Grace(Pin<Box<Sleep>>),
    /// Grace period elapsed. Shutdown the connection, waiting for the timer
    /// until we force close.
    ///
    /// The timer is created from the `mercy` duration; if it elapses before
    /// the shutdown completes, the I/O object is dropped.
    Mercy(Pin<Box<Sleep>>),
}
139
pin_project! {
    /// I/O that can be cancelled when a future `F` resolves.
    ///
    /// After `trigger` resolves, I/O keeps working for up to `grace`. After
    /// that the connection is shut down, and if the shutdown takes longer than
    /// `mercy` the I/O object is dropped. The state machine only advances while
    /// one of the read/write methods is being polled.
    #[must_use = "futures do nothing unless polled"]
    pub struct CancellableIo<F, I> {
        // The wrapped I/O object; `None` once it has been dropped after
        // cancellation.
        #[pin]
        io: Option<I>,
        // Cancellation signal, wrapped with `FutureExt::fuse`.
        #[pin]
        trigger: future::Fuse<F>,
        // Current phase of the cancellation state machine.
        state: State,
        // How long I/O may continue after `trigger` resolves.
        grace: Duration,
        // How long the shutdown may take once `grace` has elapsed.
        mercy: Duration,
    }
}
153
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
    /// Wraps `io` so that it is cancelled once `trigger` resolves.
    ///
    /// After `trigger` resolves, I/O continues for up to `grace`; then the
    /// connection is shut down, and if the shutdown has not finished within
    /// `mercy`, the I/O object is dropped.
    pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self {
        CancellableIo {
            grace, mercy,
            io: Some(io),
            trigger: trigger.fuse(),
            state: State::Active,
        }
    }

    /// Returns the wrapped I/O object, or `None` once it has been dropped
    /// (after a finished or timed-out cancellation shutdown).
    pub fn io(&self) -> Option<&I> {
        self.io.as_ref()
    }

    /// Run `do_io` while connection processing should continue.
    ///
    /// Drives the cancellation state machine, which only advances when one of
    /// the I/O methods is polled:
    ///
    /// * `Active`: polls `trigger`; while it has not resolved, runs `do_io`.
    /// * `Grace`: runs `do_io` until the grace timer elapses.
    /// * `Mercy`: `do_io` is no longer run. Instead the underlying I/O is shut
    ///   down. When the shutdown completes, the I/O object is dropped and
    ///   either the shutdown's error or a `BrokenPipe` error is returned. If
    ///   the mercy timer elapses first, the I/O object is dropped and a
    ///   `TimedOut` error is returned.
    ///
    /// Once the I/O object has been dropped, every call returns a
    /// `BrokenPipe` error.
    fn poll_trigger_then<T>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
    ) -> Poll<io::Result<T>> {
        let mut me = self.as_mut().project();
        let io = match me.io.as_pin_mut() {
            Some(io) => io,
            None => return Poll::Ready(Err(gone())),
        };

        // Each arm either returns or moves to the next state and loops, so the
        // newly created timer is polled immediately.
        loop {
            match me.state {
                State::Active => {
                    if me.trigger.as_mut().poll(cx).is_ready() {
                        *me.state = State::Grace(Box::pin(sleep(*me.grace)));
                    } else {
                        return do_io(io, cx);
                    }
                }
                State::Grace(timer) => {
                    if timer.as_mut().poll(cx).is_ready() {
                        *me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
                    } else {
                        return do_io(io, cx);
                    }
                }
                State::Mercy(timer) => {
                    if timer.as_mut().poll(cx).is_ready() {
                        // Out of time: drop the I/O without a clean shutdown.
                        // `me` is not used past this point, so `self` can be
                        // re-projected to clear `io`.
                        self.project().io.set(None);
                        return Poll::Ready(Err(time_out()));
                    } else {
                        let result = futures::ready!(io.poll_shutdown(cx));
                        // The shutdown finished (successfully or not); drop the I/O.
                        self.project().io.set(None);
                        return match result {
                            Err(e) => Poll::Ready(Err(e)),
                            // The caller's requested I/O was never performed,
                            // so even a clean shutdown is reported as an error.
                            Ok(()) => Poll::Ready(Err(gone()))
                        };
                    }
                },
            }
        }
    }
}
213
/// Error reported when the mercy timer elapses before the connection
/// finished shutting down.
fn time_out() -> io::Error {
    let message = "Shutdown grace timed out";
    io::Error::new(io::ErrorKind::TimedOut, message)
}
217
/// Error reported once the wrapped I/O object is no longer available,
/// including right after a completed cancellation shutdown.
fn gone() -> io::Error {
    let message = "IO driver has terminated";
    io::Error::new(io::ErrorKind::BrokenPipe, message)
}
221
222impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
223    fn poll_read(
224        mut self: Pin<&mut Self>,
225        cx: &mut Context<'_>,
226        buf: &mut ReadBuf<'_>,
227    ) -> Poll<io::Result<()>> {
228        self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
229    }
230}
231
232impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
233    fn poll_write(
234        mut self: Pin<&mut Self>,
235        cx: &mut Context<'_>,
236        buf: &[u8],
237    ) -> Poll<io::Result<usize>> {
238        self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
239    }
240
241    fn poll_flush(
242        mut self: Pin<&mut Self>,
243        cx: &mut Context<'_>
244    ) -> Poll<io::Result<()>> {
245        self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
246    }
247
248    fn poll_shutdown(
249        mut self: Pin<&mut Self>,
250        cx: &mut Context<'_>
251    ) -> Poll<io::Result<()>> {
252        self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
253    }
254
255    fn poll_write_vectored(
256        mut self: Pin<&mut Self>,
257        cx: &mut Context<'_>,
258        bufs: &[io::IoSlice<'_>],
259    ) -> Poll<io::Result<usize>> {
260        self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
261    }
262
263    fn is_write_vectored(&self) -> bool {
264        self.io().map(|io| io.is_write_vectored()).unwrap_or(false)
265    }
266}
267
268use crate::http::private::{Listener, Connection, Certificates};
269
270impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
271    fn peer_address(&self) -> Option<std::net::SocketAddr> {
272        self.io().and_then(|io| io.peer_address())
273    }
274
275    fn peer_certificates(&self) -> Option<Certificates> {
276        self.io().and_then(|io| io.peer_certificates())
277    }
278
279    fn enable_nodelay(&self) -> io::Result<()> {
280        match self.io() {
281            Some(io) => io.enable_nodelay(),
282            None => Ok(())
283        }
284    }
285}
286
pin_project! {
    /// A listener wrapper whose accepted connections are each wrapped in a
    /// `CancellableIo` driven by a clone of `trigger`.
    pub struct CancellableListener<F, L> {
        // Future cloned into every accepted connection as its cancellation trigger.
        pub trigger: F,
        // The wrapped listener that actually accepts connections.
        #[pin]
        pub listener: L,
        // Grace period handed to every accepted `CancellableIo`.
        pub grace: Duration,
        // Mercy period handed to every accepted `CancellableIo`.
        pub mercy: Duration,
    }
}
296
297impl<F, L> CancellableListener<F, L> {
298    pub fn new(trigger: F, listener: L, grace: u64, mercy: u64) -> Self {
299        let (grace, mercy) = (Duration::from_secs(grace), Duration::from_secs(mercy));
300        CancellableListener { trigger, listener, grace, mercy }
301    }
302}
303
304impl<L: Listener, F: Future + Clone> Listener for CancellableListener<F, L> {
305    type Connection = CancellableIo<F, L::Connection>;
306
307    fn local_addr(&self) -> Option<std::net::SocketAddr> {
308        self.listener.local_addr()
309    }
310
311    fn poll_accept(
312        mut self: Pin<&mut Self>,
313        cx: &mut Context<'_>
314    ) -> Poll<io::Result<Self::Connection>> {
315        self.as_mut().project().listener
316            .poll_accept(cx)
317            .map(|res| res.map(|conn| {
318                CancellableIo::new(self.trigger.clone(), conn, self.grace, self.mercy)
319            }))
320    }
321}
322
323pub trait StreamExt: Sized + Stream {
324    fn join<U>(self, other: U) -> Join<Self, U>
325        where U: Stream<Item = Self::Item>;
326}
327
328impl<S: Stream> StreamExt for S {
329    fn join<U>(self, other: U) -> Join<Self, U>
330        where U: Stream<Item = Self::Item>
331    {
332        Join::new(self, other)
333    }
334}
335
336pin_project! {
337    /// Stream returned by the [`join`](super::StreamExt::join) method.
338    pub struct Join<T, U> {
339        #[pin]
340        a: T,
341        #[pin]
342        b: U,
343        // When `true`, poll `a` first, otherwise, `poll` b`.
344        toggle: bool,
345        // Set when either `a` or `b` return `None`.
346        done: bool,
347    }
348}
349
350impl<T, U> Join<T, U> {
351    pub(super) fn new(a: T, b: U) -> Join<T, U>
352        where T: Stream, U: Stream,
353    {
354        Join { a, b, toggle: false, done: false, }
355    }
356
357    fn poll_next<A: Stream, B: Stream<Item = A::Item>>(
358        first: Pin<&mut A>,
359        second: Pin<&mut B>,
360        done: &mut bool,
361        cx: &mut Context<'_>,
362    ) -> Poll<Option<A::Item>> {
363        match first.poll_next(cx) {
364            Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
365            Poll::Pending => match second.poll_next(cx) {
366                Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
367                Poll::Pending => Poll::Pending
368            }
369        }
370    }
371}
372
373impl<T, U> Stream for Join<T, U>
374    where T: Stream,
375          U: Stream<Item = T::Item>,
376{
377    type Item = T::Item;
378
379    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
380        if self.done {
381            return Poll::Ready(None);
382        }
383
384        let me = self.project();
385        *me.toggle = !*me.toggle;
386        match *me.toggle {
387            true => Self::poll_next(me.a, me.b, me.done, cx),
388            false => Self::poll_next(me.b, me.a, me.done, cx),
389        }
390    }
391
392    fn size_hint(&self) -> (usize, Option<usize>) {
393        let (left_low, left_high) = self.a.size_hint();
394        let (right_low, right_high) = self.b.size_hint();
395
396        let low = left_low.saturating_add(right_low);
397        let high = match (left_high, right_high) {
398            (Some(h1), Some(h2)) => h1.checked_add(h2),
399            _ => None,
400        };
401
402        (low, high)
403    }
404}