// rocket/shutdown/tripwire.rs

1use std::fmt;
2use std::{ops::Deref, pin::Pin, future::Future};
3use std::task::{Context, Poll};
4use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
5
6use futures::future::FusedFuture;
7use tokio::sync::futures::Notified;
8use tokio::sync::Notify;
9
10#[doc(hidden)]
11pub struct State {
12    tripped: AtomicBool,
13    notify: Notify,
14}
15
16#[must_use = "`TripWire` does nothing unless polled or `trip()`ed"]
17pub struct TripWire {
18    state: Arc<State>,
19    // `Notified` is `!Unpin`. Even if we could name it, we'd need to pin it.
20    event: Option<Pin<Box<Notified<'static>>>>,
21}
22
23impl Deref for TripWire {
24    type Target = State;
25
26    fn deref(&self) -> &Self::Target {
27        &self.state
28    }
29}
30
31impl Clone for TripWire {
32    fn clone(&self) -> Self {
33        TripWire {
34            state: self.state.clone(),
35            event: None
36        }
37    }
38}
39
40impl Drop for TripWire {
41    fn drop(&mut self) {
42        // SAFETY: Ensure we drop the self-reference before `self`.
43        self.event = None;
44    }
45}
46
47impl fmt::Debug for TripWire {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        f.debug_struct("TripWire")
50            .field("tripped", &self.tripped)
51            .finish()
52    }
53}
54
55impl Future for TripWire {
56    type Output = ();
57
58    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
59        if self.tripped() {
60            self.event = None;
61            return Poll::Ready(());
62        }
63
64        if self.event.is_none() {
65            let notified = self.state.notify.notified();
66
67            // SAFETY: This is a self reference to the `state`.
68            self.event = Some(Box::pin(unsafe { std::mem::transmute(notified) }));
69        }
70
71        if let Some(ref mut event) = self.event {
72            // The order here is important! We need to know:
73            //   !self.tripped() => not notified == notified => self.tripped()
74            if event.as_mut().poll(cx).is_ready() || self.tripped() {
75                self.event = None;
76                return Poll::Ready(());
77            }
78        }
79
80        Poll::Pending
81    }
82}
83
84impl FusedFuture for TripWire {
85    fn is_terminated(&self) -> bool {
86        self.tripped()
87    }
88}
89
90impl TripWire {
91    pub fn new() -> Self {
92        TripWire {
93            state: Arc::new(State {
94                tripped: AtomicBool::new(false),
95                notify: Notify::new()
96            }),
97            event: None,
98        }
99    }
100
101    pub fn trip(&self) {
102        self.tripped.store(true, Ordering::Release);
103        self.notify.notify_waiters();
104    }
105
106    #[inline(always)]
107    pub fn tripped(&self) -> bool {
108        self.tripped.load(Ordering::Acquire)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::TripWire;

    #[test]
    fn ensure_is_send_sync_clone_unpin() {
        fn is_send_sync_clone_unpin<T: Send + Sync + Clone + Unpin>() {}
        is_send_sync_clone_unpin::<TripWire>();
    }

    /// Awaiting a wire that was tripped before its first poll resolves.
    #[tokio::test]
    async fn simple_trip() {
        let wire = TripWire::new();
        wire.trip();
        wire.await;
    }

    /// Never-tripped clones stay pending: the one-second sleep must finish
    /// first among ten untripped clones.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn no_trip() {
        use tokio::time::{sleep, Duration};
        use futures::stream::{FuturesUnordered as Set, StreamExt};
        use futures::future::{BoxFuture, FutureExt};

        let wire = TripWire::new();
        let mut futs: Set<BoxFuture<'static, bool>> = Set::new();
        for _ in 0..10 {
            futs.push(Box::pin(wire.clone().map(|_| false)));
        }

        let sleep = sleep(Duration::from_secs(1));
        futs.push(Box::pin(sleep.map(|_| true)));
        assert!(futs.next().await.unwrap());
    }

    /// A single `trip()` wakes every clone spawned before it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn general_trip() {
        let wire = TripWire::new();
        let mut tasks = vec![];
        for _ in 0..1000 {
            tasks.push(tokio::spawn(wire.clone()));
            tokio::task::yield_now().await;
        }

        wire.trip();
        for task in tasks {
            task.await.unwrap();
        }
    }

    /// Races one waiter against one tripper per fresh wire; every waiter
    /// must complete regardless of which side runs first.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn single_stage_trip() {
        let mut tasks = vec![];
        for i in 0..1000 {
            let wire = TripWire::new();
            let waiter = wire.clone();
            // Alternate spawn order: even iterations spawn the waiter first,
            // odd iterations spawn the tripper first.
            if i % 2 == 0 {
                tasks.push(tokio::spawn(waiter));
                tasks.push(tokio::spawn(async move { wire.trip() }));
            } else {
                tasks.push(tokio::spawn(async move { wire.trip() }));
                tasks.push(tokio::spawn(waiter));
            }
        }

        for task in tasks {
            task.await.unwrap();
        }
    }

    /// Interleaves repeated trips of one shared wire with waiters on it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn staged_trip() {
        let wire = TripWire::new();
        let mut tasks = vec![];
        for i in 0..1050 {
            let wire = wire.clone();
            // Every 100th task trips the shared wire; all others wait on it.
            // The waiters spawned after the last trip (i = 1000) are never
            // followed by another `trip()`, but must still resolve because
            // the tripped state is shared by every clone.
            let task = if i % 100 == 0 {
                tokio::spawn(async move { wire.trip() })
            } else {
                tokio::spawn(wire)
            };

            if i % 20 == 0 {
                tokio::task::yield_now().await;
            }

            tasks.push(task);
        }

        for task in tasks {
            task.await.unwrap();
        }
    }
}