rocket/listener/
cancellable.rs

1use std::io;
2use std::task::{Poll, Context};
3use std::pin::Pin;
4
5use futures::Stream;
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7use futures::future::FutureExt;
8use pin_project_lite::pin_project;
9
10use crate::shutdown::Stages;
11
pin_project! {
    /// I/O of type `I` that is cancelled in stages as the shutdown `Stages` resolve.
    ///
    /// The `AsyncRead`, `AsyncWrite`, and `Stream` implementations in this module
    /// all route through [`Cancellable::poll_with`], which decides on every poll
    /// whether the wrapped I/O may still proceed, must be cancelled, or has
    /// already been dropped.
    #[must_use = "futures do nothing unless polled"]
    pub struct Cancellable<I> {
        // The wrapped I/O. `poll_with` resets this to `None` once cancellation
        // completes or the mercy stage resolves; afterwards every poll fails.
        #[pin]
        io: Option<I>,
        // Shutdown signals polled by `poll_with`: `start`, `grace`, and `mercy`.
        stages: Stages,
        // The shutdown phase this I/O is currently in.
        state: State,
    }
}
22
/// Shutdown phase of a [`Cancellable`].
///
/// `poll_with` only ever moves forward through the variants, in declaration
/// order: `Active` -> `Grace` -> `Mercy`.
#[derive(Debug)]
enum State {
    /// I/O has not been cancelled. I/O proceeds as normal until `stages.start`
    /// resolves.
    Active,
    /// I/O has been cancelled. I/O may still proceed, trying to finish before
    /// `stages.grace` resolves.
    Grace,
    /// Grace has elapsed. The I/O is shut down via `poll_cancel`; if
    /// `stages.mercy` resolves first, the I/O is force closed (dropped).
    Mercy,
}
32
33pub trait CancellableExt: Sized {
34    fn cancellable(self, stages: Stages) -> Cancellable<Self> {
35        Cancellable {
36            io: Some(self),
37            state: State::Active,
38            stages,
39        }
40    }
41}
42
// Blanket implementation: every sized type gains `.cancellable(stages)`.
impl<T> CancellableExt for T { }
44
/// Builds the `TimedOut` error reported when the mercy stage resolves before
/// the I/O has finished cancelling.
fn time_out() -> io::Error {
    let kind = io::ErrorKind::TimedOut;
    io::Error::new(kind, "shutdown grace period elapsed")
}
48
/// Builds the `BrokenPipe` error reported when the wrapped I/O no longer
/// exists, or has just been cancelled successfully.
fn gone() -> io::Error {
    let kind = io::ErrorKind::BrokenPipe;
    io::Error::new(kind, "I/O driver terminated")
}
52
53impl<I: AsyncCancel> Cancellable<I> {
54    pub fn inner(&self) -> Option<&I> {
55        self.io.as_ref()
56    }
57}
58
/// I/O that can be asked to cancel itself by polling.
pub trait AsyncCancel {
    /// Attempts to cancel the I/O, returning `Poll::Pending` while cancellation
    /// is still in progress.
    ///
    /// `Cancellable::poll_with` calls this once the grace stage has resolved
    /// and drops the I/O as soon as it returns `Poll::Ready`, propagating any
    /// error it produced.
    fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
62
63impl<T: AsyncWrite> AsyncCancel for T {
64    fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
65        <T as AsyncWrite>::poll_shutdown(self, cx)
66    }
67}
68
69impl<I: AsyncCancel> Cancellable<I> {
70    /// Run `do_io` while connection processing should continue.
71    pub fn poll_with<T>(
72        mut self: Pin<&mut Self>,
73        cx: &mut Context<'_>,
74        do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
75    ) -> Poll<io::Result<T>> {
76        let me = self.as_mut().project();
77        let io = match me.io.as_pin_mut() {
78            Some(io) => io,
79            None => return Poll::Ready(Err(gone())),
80        };
81
82        loop {
83            match me.state {
84                State::Active => {
85                    if me.stages.start.poll_unpin(cx).is_ready() {
86                        *me.state = State::Grace;
87                    } else {
88                        return do_io(io, cx);
89                    }
90                }
91                State::Grace => {
92                    if me.stages.grace.poll_unpin(cx).is_ready() {
93                        *me.state = State::Mercy;
94                    } else {
95                        return do_io(io, cx);
96                    }
97                }
98                State::Mercy => {
99                    if me.stages.mercy.poll_unpin(cx).is_ready() {
100                        self.project().io.set(None);
101                        return Poll::Ready(Err(time_out()));
102                    } else {
103                        let result = futures::ready!(io.poll_cancel(cx));
104                        self.project().io.set(None);
105                        return match result {
106                            Ok(()) => Poll::Ready(Err(gone())),
107                            Err(e) => Poll::Ready(Err(e)),
108                        };
109                    }
110                },
111            }
112        }
113    }
114}
115
116impl<I: AsyncRead + AsyncCancel> AsyncRead for Cancellable<I> {
117    fn poll_read(
118        self: Pin<&mut Self>,
119        cx: &mut Context<'_>,
120        buf: &mut ReadBuf<'_>,
121    ) -> Poll<io::Result<()>> {
122        self.poll_with(cx, |io, cx| io.poll_read(cx, buf))
123    }
124}
125
126impl<I: AsyncWrite> AsyncWrite for Cancellable<I> {
127    fn poll_write(
128        self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130        buf: &[u8],
131    ) -> Poll<io::Result<usize>> {
132        self.poll_with(cx, |io, cx| io.poll_write(cx, buf))
133    }
134
135    fn poll_flush(
136        self: Pin<&mut Self>,
137        cx: &mut Context<'_>
138    ) -> Poll<io::Result<()>> {
139        self.poll_with(cx, |io, cx| io.poll_flush(cx))
140    }
141
142    fn poll_shutdown(
143        self: Pin<&mut Self>,
144        cx: &mut Context<'_>
145    ) -> Poll<io::Result<()>> {
146        self.poll_with(cx, |io, cx| io.poll_shutdown(cx))
147    }
148
149    fn poll_write_vectored(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        bufs: &[io::IoSlice<'_>],
153    ) -> Poll<io::Result<usize>> {
154        self.poll_with(cx, |io, cx| io.poll_write_vectored(cx, bufs))
155    }
156
157    fn is_write_vectored(&self) -> bool {
158        self.inner().map(|io| io.is_write_vectored()).unwrap_or(false)
159    }
160}
161
162impl<T, I: Stream<Item = io::Result<T>> + AsyncCancel> Stream for Cancellable<I> {
163    type Item = I::Item;
164
165    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
166        use futures::ready;
167
168        match ready!(self.poll_with(cx, |io, cx| io.poll_next(cx).map(Ok))) {
169            Ok(Some(v)) => Poll::Ready(Some(v)),
170            Ok(None) => Poll::Ready(None),
171            Err(e) => Poll::Ready(Some(Err(e))),
172        }
173    }
174}