// rocket/shutdown/tripwire.rs
1use std::fmt;
2use std::{ops::Deref, pin::Pin, future::Future};
3use std::task::{Context, Poll};
4use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
5
6use futures::future::FusedFuture;
7use tokio::sync::futures::Notified;
8use tokio::sync::Notify;
9
10#[doc(hidden)]
11pub struct State {
12 tripped: AtomicBool,
13 notify: Notify,
14}
15
16#[must_use = "`TripWire` does nothing unless polled or `trip()`ed"]
17pub struct TripWire {
18 state: Arc<State>,
19 event: Option<Pin<Box<Notified<'static>>>>,
21}
22
23impl Deref for TripWire {
24 type Target = State;
25
26 fn deref(&self) -> &Self::Target {
27 &self.state
28 }
29}
30
31impl Clone for TripWire {
32 fn clone(&self) -> Self {
33 TripWire {
34 state: self.state.clone(),
35 event: None
36 }
37 }
38}
39
40impl Drop for TripWire {
41 fn drop(&mut self) {
42 self.event = None;
44 }
45}
46
47impl fmt::Debug for TripWire {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 f.debug_struct("TripWire")
50 .field("tripped", &self.tripped)
51 .finish()
52 }
53}
54
55impl Future for TripWire {
56 type Output = ();
57
58 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
59 if self.tripped() {
60 self.event = None;
61 return Poll::Ready(());
62 }
63
64 if self.event.is_none() {
65 let notified = self.state.notify.notified();
66
67 self.event = Some(Box::pin(unsafe { std::mem::transmute(notified) }));
69 }
70
71 if let Some(ref mut event) = self.event {
72 if event.as_mut().poll(cx).is_ready() || self.tripped() {
75 self.event = None;
76 return Poll::Ready(());
77 }
78 }
79
80 Poll::Pending
81 }
82}
83
84impl FusedFuture for TripWire {
85 fn is_terminated(&self) -> bool {
86 self.tripped()
87 }
88}
89
90impl TripWire {
91 pub fn new() -> Self {
92 TripWire {
93 state: Arc::new(State {
94 tripped: AtomicBool::new(false),
95 notify: Notify::new()
96 }),
97 event: None,
98 }
99 }
100
101 pub fn trip(&self) {
102 self.tripped.store(true, Ordering::Release);
103 self.notify.notify_waiters();
104 }
105
106 #[inline(always)]
107 pub fn tripped(&self) -> bool {
108 self.tripped.load(Ordering::Acquire)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::TripWire;

    /// `TripWire` must be shareable across tasks and threads.
    #[test]
    fn ensure_is_send_sync_clone_unpin() {
        fn assert_traits<T: Send + Sync + Clone + Unpin>() {}
        assert_traits::<TripWire>();
    }

    /// A wire tripped before it is awaited resolves immediately.
    #[tokio::test]
    async fn simple_trip() {
        let tripwire = TripWire::new();
        tripwire.trip();
        tripwire.await;
    }

    /// Untripped wires stay pending, so the timer wins the race.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn no_trip() {
        use futures::future::{BoxFuture, FutureExt};
        use futures::stream::{FuturesUnordered as Set, StreamExt};
        use tokio::time::{sleep, Duration};

        let tripwire = TripWire::new();
        let mut racers: Set<BoxFuture<'static, bool>> = (0..10)
            .map(|_| tripwire.clone().map(|_| false).boxed())
            .collect();

        racers.push(sleep(Duration::from_secs(1)).map(|_| true).boxed());
        let timer_won = racers.next().await.unwrap();
        assert!(timer_won);
    }

    /// Many already-spawned waiters are all released by a single trip.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn general_trip() {
        let tripwire = TripWire::new();
        let mut handles = Vec::with_capacity(1000);
        for _ in 0..1000 {
            handles.push(tokio::spawn(tripwire.clone()));
            tokio::task::yield_now().await;
        }

        tripwire.trip();
        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// Independent wire pairs, spawned waiter-first and tripper-first alternately.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn single_stage_trip() {
        let mut handles = Vec::with_capacity(2000);
        for i in 0..1000 {
            let waiter = TripWire::new();
            let tripper = waiter.clone();
            let tripping = async move { tripper.trip() };

            if i % 2 == 0 {
                handles.push(tokio::spawn(waiter));
                handles.push(tokio::spawn(tripping));
            } else {
                handles.push(tokio::spawn(tripping));
                handles.push(tokio::spawn(waiter));
            }
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// Trips are interleaved with waiters, including waiters spawned after a trip.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn staged_trip() {
        let tripwire = TripWire::new();
        let mut handles = Vec::with_capacity(1050);
        for i in 0..1050 {
            let wire = tripwire.clone();
            let handle = match i % 100 {
                0 => tokio::spawn(async move { wire.trip() }),
                _ => tokio::spawn(wire),
            };

            if i % 20 == 0 {
                tokio::task::yield_now().await;
            }

            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }
}