1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use std::{io, time::Duration};

use crate::listener::{Listener, Endpoint};

/// How long `Bounced::accept_next` sleeps before retrying after an accept
/// error that `is_recoverable` does not classify as a per-connection failure.
///
/// Declared `const` rather than `static`: it is a plain `Copy` value that needs
/// neither a fixed address nor interior mutability.
const DURATION: Duration = Duration::from_millis(250);

/// Wrapper around a listener whose `accept` keeps retrying until the inner
/// listener yields a connection; every other operation is forwarded as-is.
pub struct Bounced<L> {
    /// The wrapped listener that all calls are ultimately delegated to.
    listener: L,
}

/// Extension trait adding `.bounced()` to any sized value.
pub trait BouncedExt: Sized {
    /// Consumes `self` and returns it wrapped in a `Bounced`.
    fn bounced(self) -> Bounced<Self> {
        let listener = self;
        Bounced { listener }
    }
}

// Blanket impl: every sized type gets `.bounced()`. The `Listener` impl for
// `Bounced` only exists when the wrapped type is itself `Listener + Sync`.
impl<T> BouncedExt for T {}

/// Returns `true` when `e` concerns a single incoming connection (refused,
/// aborted or reset) rather than the listener as a whole.
///
/// `accept_next` retries such errors immediately; any other error is retried
/// only after sleeping for `DURATION`.
fn is_recoverable(e: &io::Error) -> bool {
    const PER_CONNECTION: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION.contains(&e.kind())
}

impl<L: Listener + Sync> Bounced<L> {
    #[inline]
    pub async fn accept_next(&self) -> <Self as Listener>::Accept {
        loop {
            match self.listener.accept().await {
                Ok(accept) => return accept,
                Err(e) if is_recoverable(&e) => warn!("recoverable connection error: {e}"),
                Err(e) => {
                    warn!("accept error: {e} [retrying in {}ms]", DURATION.as_millis());
                    tokio::time::sleep(DURATION).await;
                }
            };
        }
    }
}

/// `Listener` implementation that forwards everything to the wrapped listener,
/// except `accept`, which goes through `Bounced::accept_next` and so always
/// returns `Ok`.
impl<L> Listener for Bounced<L>
where
    L: Listener + Sync,
{
    type Accept = <L as Listener>::Accept;

    type Connection = <L as Listener>::Connection;

    async fn accept(&self) -> io::Result<Self::Accept> {
        // `accept_next` retries internally, so there is never an error to report.
        let accepted = self.accept_next().await;
        Ok(accepted)
    }

    async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
        let inner = &self.listener;
        inner.connect(accept).await
    }

    fn endpoint(&self) -> io::Result<Endpoint> {
        let inner = &self.listener;
        inner.endpoint()
    }
}