1use std::{io, time::Duration};
2use std::task::{Poll, Context};
3use std::pin::Pin;
4
5use bytes::{Bytes, BytesMut};
6use pin_project_lite::pin_project;
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8use tokio::time::{sleep, Sleep};
9
10use futures::stream::Stream;
11use futures::future::{self, Future, FutureExt};
12
pin_project! {
    /// Stream of [`Bytes`] chunks read from an [`AsyncRead`].
    ///
    /// Built by `AsyncReadExt::into_bytes_stream`. The stream ends at EOF, and
    /// after yielding a read error it also ends.
    pub struct ReaderStream<R> {
        // `None` once the reader hit EOF or returned an error.
        #[pin]
        reader: Option<R>,
        // Read buffer; each yielded chunk is the filled bytes split off of it.
        buf: BytesMut,
        // Capacity reserved whenever `buf` has no capacity left.
        cap: usize,
    }
}
21
22impl<R: AsyncRead> Stream for ReaderStream<R> {
23 type Item = std::io::Result<Bytes>;
24
25 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
26 use tokio_util::io::poll_read_buf;
27
28 let mut this = self.as_mut().project();
29
30 let reader = match this.reader.as_pin_mut() {
31 Some(r) => r,
32 None => return Poll::Ready(None),
33 };
34
35 if this.buf.capacity() == 0 {
36 this.buf.reserve(*this.cap);
37 }
38
39 match poll_read_buf(reader, cx, &mut this.buf) {
40 Poll::Pending => Poll::Pending,
41 Poll::Ready(Err(err)) => {
42 self.project().reader.set(None);
43 Poll::Ready(Some(Err(err)))
44 }
45 Poll::Ready(Ok(0)) => {
46 self.project().reader.set(None);
47 Poll::Ready(None)
48 }
49 Poll::Ready(Ok(_)) => {
50 let chunk = this.buf.split();
51 Poll::Ready(Some(Ok(chunk.freeze())))
52 }
53 }
54 }
55}
56
57pub trait AsyncReadExt: AsyncRead + Sized {
58 fn into_bytes_stream(self, cap: usize) -> ReaderStream<Self> {
59 ReaderStream { reader: Some(self), cap, buf: BytesMut::with_capacity(cap) }
60 }
61}
62
63impl<T: AsyncRead> AsyncReadExt for T { }
64
/// Extension methods for `Poll<Option<Result<T, E>>>`, the shape returned by
/// `Stream::poll_next` for streams of `Result`s.
pub trait PollExt<T, E> {
    /// Maps the error of a ready `Some(Err(e))` through `f`.
    ///
    /// `Ready(Some(Ok(_)))`, `Ready(None)` and `Pending` pass through unchanged,
    /// and `f` is not called for them.
    fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
    where
        F: FnOnce(E) -> U;
}

impl<T, E> PollExt<T, E> for Poll<Option<Result<T, E>>> {
    fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
    where
        F: FnOnce(E) -> U,
    {
        // The standard library already provides exactly this combinator as an
        // inherent method (stable since Rust 1.51); delegate instead of
        // re-implementing the four-way match.
        self.map_err(f)
    }
}
83
pin_project! {
    /// Reader adapter that yields all bytes of `first`, then all bytes of `second`.
    // NOTE(review): `Chain` implements `AsyncRead`, not `Stream`; the lint
    // message below could say "readers" instead of "streams".
    #[must_use = "streams do nothing unless polled"]
    pub struct Chain<T, U> {
        #[pin]
        first: T,
        #[pin]
        second: U,
        // Set once `first` has reported EOF; afterwards only `second` is read.
        done_first: bool,
    }
}
95
96impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
97 pub(crate) fn new(first: T, second: U) -> Self {
98 Self { first, second, done_first: false }
99 }
100}
101
102impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
103 pub fn get_ref(&self) -> (&T, &U) {
105 (&self.first, &self.second)
106 }
107}
108
109impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
110 fn poll_read(
111 self: Pin<&mut Self>,
112 cx: &mut Context<'_>,
113 buf: &mut ReadBuf<'_>,
114 ) -> Poll<io::Result<()>> {
115 let me = self.project();
116
117 if !*me.done_first {
118 let init_rem = buf.remaining();
119 futures::ready!(me.first.poll_read(cx, buf))?;
120 if buf.remaining() == init_rem {
121 *me.done_first = true;
122 } else {
123 return Poll::Ready(Ok(()));
124 }
125 }
126 me.second.poll_read(cx, buf)
127 }
128}
129
/// Shutdown progress of a `CancellableIo`.
enum State {
    /// The trigger has not completed yet; I/O proceeds normally.
    Active,
    /// The trigger completed; I/O keeps running until this grace timer fires.
    Grace(Pin<Box<Sleep>>),
    /// Grace expired; the I/O is being shut down, bounded by this mercy timer.
    Mercy(Pin<Box<Sleep>>),
}
139
pin_project! {
    /// I/O wrapper that is gracefully shut down once `trigger` resolves.
    // NOTE(review): this type is an `AsyncRead`/`AsyncWrite`, not a future; the
    // lint message below could be reworded.
    #[must_use = "futures do nothing unless polled"]
    pub struct CancellableIo<F, I> {
        // The wrapped I/O; set to `None` once it has been torn down.
        #[pin]
        io: Option<I>,
        // Shutdown signal; fused so polling it again after completion is harmless.
        #[pin]
        trigger: future::Fuse<F>,
        // Current phase: active, grace period, or mercy (forced shutdown) period.
        state: State,
        // How long I/O continues normally after `trigger` completes.
        grace: Duration,
        // How long `poll_shutdown` may take after the grace period ends.
        mercy: Duration,
    }
}
153
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
    /// Wraps `io` so that it is shut down after `trigger` resolves.
    ///
    /// After `trigger` completes, I/O continues normally for `grace`. The wrapped
    /// I/O then gets up to `mercy` to finish `poll_shutdown` before it is dropped.
    pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self {
        CancellableIo {
            grace, mercy,
            io: Some(io),
            trigger: trigger.fuse(),
            state: State::Active,
        }
    }

    /// Returns the wrapped I/O, or `None` once it has been dropped.
    ///
    /// It is dropped after the shutdown completes or the mercy period expires.
    pub fn io(&self) -> Option<&I> {
        self.io.as_ref()
    }

    /// Advances the shutdown state machine, then runs `do_io` if still allowed.
    ///
    /// - `Active`: polls `trigger`. While it is pending, `do_io` runs. Once it
    ///   completes, a `grace` timer is started.
    /// - `Grace`: `do_io` keeps running until the grace timer fires. Then a
    ///   `mercy` timer is started.
    /// - `Mercy`: `do_io` is no longer called; the inner I/O is driven through
    ///   `poll_shutdown` instead. If the mercy timer fires first, the I/O is
    ///   dropped and a `TimedOut` error is returned. When shutdown finishes, the
    ///   I/O is dropped and the result is either the shutdown error or a
    ///   `BrokenPipe` error.
    ///
    /// Once the I/O has been dropped, every call returns the `BrokenPipe` error.
    fn poll_trigger_then<T>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
    ) -> Poll<io::Result<T>> {
        let mut me = self.as_mut().project();
        let io = match me.io.as_pin_mut() {
            Some(io) => io,
            // A previous call already tore the I/O down.
            None => return Poll::Ready(Err(gone())),
        };

        // Looping after a state transition means the new state's timer is polled
        // in this same call, so a wakeup is registered with it.
        loop {
            match me.state {
                State::Active => {
                    if me.trigger.as_mut().poll(cx).is_ready() {
                        *me.state = State::Grace(Box::pin(sleep(*me.grace)));
                    } else {
                        return do_io(io, cx);
                    }
                }
                State::Grace(timer) => {
                    if timer.as_mut().poll(cx).is_ready() {
                        *me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
                    } else {
                        return do_io(io, cx);
                    }
                }
                State::Mercy(timer) => {
                    if timer.as_mut().poll(cx).is_ready() {
                        // Mercy expired before shutdown finished: drop the I/O outright.
                        self.project().io.set(None);
                        return Poll::Ready(Err(time_out()));
                    } else {
                        // `do_io` is skipped from here on; drive the shutdown instead.
                        let result = futures::ready!(io.poll_shutdown(cx));
                        self.project().io.set(None);
                        // Even a clean shutdown is reported as an error, because the
                        // operation the caller asked for was not performed.
                        return match result {
                            Err(e) => Poll::Ready(Err(e)),
                            Ok(()) => Poll::Ready(Err(gone()))
                        };
                    }
                },
            }
        }
    }
}
213
/// Builds the `TimedOut` error returned when the mercy period expires before
/// the wrapped I/O has finished shutting down.
fn time_out() -> io::Error {
    let kind = io::ErrorKind::TimedOut;
    io::Error::new(kind, "Shutdown grace timed out")
}

/// Builds the `BrokenPipe` error returned once the wrapped I/O has been dropped.
fn gone() -> io::Error {
    let kind = io::ErrorKind::BrokenPipe;
    io::Error::new(kind, "IO driver has terminated")
}
221
222impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
223 fn poll_read(
224 mut self: Pin<&mut Self>,
225 cx: &mut Context<'_>,
226 buf: &mut ReadBuf<'_>,
227 ) -> Poll<io::Result<()>> {
228 self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
229 }
230}
231
232impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
233 fn poll_write(
234 mut self: Pin<&mut Self>,
235 cx: &mut Context<'_>,
236 buf: &[u8],
237 ) -> Poll<io::Result<usize>> {
238 self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
239 }
240
241 fn poll_flush(
242 mut self: Pin<&mut Self>,
243 cx: &mut Context<'_>
244 ) -> Poll<io::Result<()>> {
245 self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
246 }
247
248 fn poll_shutdown(
249 mut self: Pin<&mut Self>,
250 cx: &mut Context<'_>
251 ) -> Poll<io::Result<()>> {
252 self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
253 }
254
255 fn poll_write_vectored(
256 mut self: Pin<&mut Self>,
257 cx: &mut Context<'_>,
258 bufs: &[io::IoSlice<'_>],
259 ) -> Poll<io::Result<usize>> {
260 self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
261 }
262
263 fn is_write_vectored(&self) -> bool {
264 self.io().map(|io| io.is_write_vectored()).unwrap_or(false)
265 }
266}
267
268use crate::http::private::{Listener, Connection, Certificates};
269
270impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
271 fn peer_address(&self) -> Option<std::net::SocketAddr> {
272 self.io().and_then(|io| io.peer_address())
273 }
274
275 fn peer_certificates(&self) -> Option<Certificates> {
276 self.io().and_then(|io| io.peer_certificates())
277 }
278
279 fn enable_nodelay(&self) -> io::Result<()> {
280 match self.io() {
281 Some(io) => io.enable_nodelay(),
282 None => Ok(())
283 }
284 }
285}
286
pin_project! {
    /// Listener whose accepted connections are wrapped in `CancellableIo`.
    ///
    /// Each connection receives its own clone of `trigger`.
    pub struct CancellableListener<F, L> {
        // Shutdown signal, cloned into every accepted connection.
        pub trigger: F,
        #[pin]
        pub listener: L,
        // Grace and mercy periods handed to each `CancellableIo`.
        pub grace: Duration,
        pub mercy: Duration,
    }
}
296
297impl<F, L> CancellableListener<F, L> {
298 pub fn new(trigger: F, listener: L, grace: u64, mercy: u64) -> Self {
299 let (grace, mercy) = (Duration::from_secs(grace), Duration::from_secs(mercy));
300 CancellableListener { trigger, listener, grace, mercy }
301 }
302}
303
304impl<L: Listener, F: Future + Clone> Listener for CancellableListener<F, L> {
305 type Connection = CancellableIo<F, L::Connection>;
306
307 fn local_addr(&self) -> Option<std::net::SocketAddr> {
308 self.listener.local_addr()
309 }
310
311 fn poll_accept(
312 mut self: Pin<&mut Self>,
313 cx: &mut Context<'_>
314 ) -> Poll<io::Result<Self::Connection>> {
315 self.as_mut().project().listener
316 .poll_accept(cx)
317 .map(|res| res.map(|conn| {
318 CancellableIo::new(self.trigger.clone(), conn, self.grace, self.mercy)
319 }))
320 }
321}
322
323pub trait StreamExt: Sized + Stream {
324 fn join<U>(self, other: U) -> Join<Self, U>
325 where U: Stream<Item = Self::Item>;
326}
327
328impl<S: Stream> StreamExt for S {
329 fn join<U>(self, other: U) -> Join<Self, U>
330 where U: Stream<Item = Self::Item>
331 {
332 Join::new(self, other)
333 }
334}
335
pin_project! {
    /// Stream that interleaves items from `a` and `b`.
    ///
    /// It ends as soon as either input yields `None`.
    pub struct Join<T, U> {
        #[pin]
        a: T,
        #[pin]
        b: U,
        // Flipped on every poll: when `true`, `a` is polled first; otherwise `b`.
        toggle: bool,
        // Set once either `a` or `b` yields `None`; the join then stays finished.
        done: bool,
    }
}
349
350impl<T, U> Join<T, U> {
351 pub(super) fn new(a: T, b: U) -> Join<T, U>
352 where T: Stream, U: Stream,
353 {
354 Join { a, b, toggle: false, done: false, }
355 }
356
357 fn poll_next<A: Stream, B: Stream<Item = A::Item>>(
358 first: Pin<&mut A>,
359 second: Pin<&mut B>,
360 done: &mut bool,
361 cx: &mut Context<'_>,
362 ) -> Poll<Option<A::Item>> {
363 match first.poll_next(cx) {
364 Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
365 Poll::Pending => match second.poll_next(cx) {
366 Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
367 Poll::Pending => Poll::Pending
368 }
369 }
370 }
371}
372
373impl<T, U> Stream for Join<T, U>
374 where T: Stream,
375 U: Stream<Item = T::Item>,
376{
377 type Item = T::Item;
378
379 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
380 if self.done {
381 return Poll::Ready(None);
382 }
383
384 let me = self.project();
385 *me.toggle = !*me.toggle;
386 match *me.toggle {
387 true => Self::poll_next(me.a, me.b, me.done, cx),
388 false => Self::poll_next(me.b, me.a, me.done, cx),
389 }
390 }
391
392 fn size_hint(&self) -> (usize, Option<usize>) {
393 let (left_low, left_high) = self.a.size_hint();
394 let (right_low, right_high) = self.b.size_hint();
395
396 let low = left_low.saturating_add(right_low);
397 let high = match (left_high, right_high) {
398 (Some(h1), Some(h2)) => h1.checked_add(h2),
399 _ => None,
400 };
401
402 (low, high)
403 }
404}