rocket/
erased.rs

1use std::io;
2use std::mem::transmute;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Poll, Context};
6
7use futures::future::BoxFuture;
8use http::request::Parts;
9use tokio::io::{AsyncRead, ReadBuf};
10
11use crate::data::{Data, IoHandler, RawStream};
12use crate::{Request, Response, Rocket, Orbit};
13
// TODO: Use `async fn` in traits to get rid of the boxed, pinned futures (`BoxFuture`).
// TODO: Write safety proofs.
16
/// Compile-time check that the type named by the input tokens is covariant in
/// its single lifetime parameter.
///
/// Expands to an anonymous `const` item holding a function that returns a
/// `&'short Type<'long>` as a `&'short Type<'short>`. That only type-checks when
/// `Type<'long>` is a subtype of `Type<'short>`, so compilation fails if the type
/// ever stops being covariant.
macro_rules! static_assert_covariance {
    ($($ty:tt)*) => {
        const _: () = {
            fn _assert_covariance<'long: 'short, 'short>(
                value: &'short $($ty)*<'long>,
            ) -> &'short $($ty)*<'short> {
                value
            }
        };
    };
}
24
25#[derive(Debug)]
26pub struct ErasedRequest {
27    // XXX: SAFETY: This (dependent) field must come first due to drop order!
28    request: Request<'static>,
29    _rocket: Arc<Rocket<Orbit>>,
30    _parts: Box<Parts>,
31}
32
33impl Drop for ErasedRequest {
34    fn drop(&mut self) { }
35}
36
37#[derive(Debug)]
38pub struct ErasedResponse {
39    // XXX: SAFETY: This (dependent) field must come first due to drop order!
40    response: Response<'static>,
41    _request: Arc<ErasedRequest>,
42}
43
44impl Drop for ErasedResponse {
45    fn drop(&mut self) { }
46}
47
48pub struct ErasedIoHandler {
49    // XXX: SAFETY: This (dependent) field must come first due to drop order!
50    io: Box<dyn IoHandler + 'static>,
51    _request: Arc<ErasedRequest>,
52}
53
54impl Drop for ErasedIoHandler {
55    fn drop(&mut self) { }
56}
57
58impl ErasedRequest {
59    pub fn new(
60        rocket: Arc<Rocket<Orbit>>,
61        parts: Parts,
62        constructor: impl for<'r> FnOnce(
63            &'r Rocket<Orbit>,
64            &'r Parts
65        ) -> Request<'r>,
66    ) -> ErasedRequest {
67        let rocket: Arc<Rocket<Orbit>> = rocket;
68        let parts: Box<Parts> = Box::new(parts);
69        let request: Request<'_> = {
70            let rocket: &Rocket<Orbit> = &rocket;
71            let rocket: &'static Rocket<Orbit> = unsafe { transmute(rocket) };
72            let parts: &Parts = &parts;
73            let parts: &'static Parts = unsafe { transmute(parts) };
74            constructor(rocket, parts)
75        };
76
77        ErasedRequest { _rocket: rocket, _parts: parts, request, }
78    }
79
80    pub fn inner(&self) -> &Request<'_> {
81        static_assert_covariance!(Request);
82        &self.request
83    }
84
85    pub async fn into_response<T, D>(
86        self,
87        raw_stream: D,
88        preprocess: impl for<'r, 'x> FnOnce(
89            &'r Rocket<Orbit>,
90            &'r mut Request<'x>,
91            &'r mut Data<'x>
92        ) -> BoxFuture<'r, T>,
93        dispatch: impl for<'r> FnOnce(
94            T,
95            &'r Rocket<Orbit>,
96            &'r Request<'r>,
97            Data<'r>
98        ) -> BoxFuture<'r, Response<'r>>,
99    ) -> ErasedResponse
100        where T: Send + Sync + 'static,
101              D: for<'r> Into<RawStream<'r>>
102    {
103        let mut data: Data<'_> = Data::from(raw_stream);
104        let mut parent = Arc::new(self);
105        let token: T = {
106            let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap();
107            let rocket: &Rocket<Orbit> = &parent._rocket;
108            let request: &mut Request<'_> = &mut parent.request;
109            let data: &mut Data<'_> = &mut data;
110            preprocess(rocket, request, data).await
111        };
112
113        let parent = parent;
114        let response: Response<'_> = {
115            let parent: &ErasedRequest = &parent;
116            let parent: &'static ErasedRequest = unsafe { transmute(parent) };
117            let rocket: &Rocket<Orbit> = &parent._rocket;
118            let request: &Request<'_> = &parent.request;
119            dispatch(token, rocket, request, data).await
120        };
121
122        ErasedResponse {
123            _request: parent,
124            response,
125        }
126    }
127}
128
129impl ErasedResponse {
130    pub fn inner(&self) -> &Response<'_> {
131        static_assert_covariance!(Response);
132        &self.response
133    }
134
135    pub fn with_inner_mut<'a, T>(
136        &'a mut self,
137        f: impl for<'r> FnOnce(&'a mut Response<'r>) -> T
138    ) -> T {
139        static_assert_covariance!(Response);
140        f(&mut self.response)
141    }
142
143    pub fn make_io_handler<'a, T: 'static>(
144        &'a mut self,
145        constructor: impl for<'r> FnOnce(
146            &'r Request<'r>,
147            &'a mut Response<'r>,
148        ) -> Option<(T, Box<dyn IoHandler + 'r>)>
149    ) -> Option<(T, ErasedIoHandler)> {
150        let parent: Arc<ErasedRequest> = self._request.clone();
151        let io: Option<(T, Box<dyn IoHandler + '_>)> = {
152            let parent: &ErasedRequest = &parent;
153            let parent: &'static ErasedRequest = unsafe { transmute(parent) };
154            let request: &Request<'_> = &parent.request;
155            constructor(request, &mut self.response)
156        };
157
158        io.map(|(v, io)| (v, ErasedIoHandler { _request: parent, io }))
159    }
160}
161
162impl ErasedIoHandler {
163    pub fn with_inner_mut<'a, T: 'a>(
164        &'a mut self,
165        f: impl for<'r> FnOnce(&'a mut Box<dyn IoHandler + 'r>) -> T
166    ) -> T {
167        fn _assert_covariance<'x: 'y, 'y>(
168            x: &'y Box<dyn IoHandler + 'x>
169        ) -> &'y Box<dyn IoHandler + 'y> { x }
170
171        f(&mut self.io)
172    }
173
174    pub fn take<'a>(&'a mut self) -> Box<dyn IoHandler + 'a> {
175        fn _assert_covariance<'x: 'y, 'y>(
176            x: &'y Box<dyn IoHandler + 'x>
177        ) -> &'y Box<dyn IoHandler + 'y> { x }
178
179        self.with_inner_mut(|handler| std::mem::replace(handler, Box::new(())))
180    }
181}
182
183impl AsyncRead for ErasedResponse {
184    fn poll_read(
185        self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        buf: &mut ReadBuf<'_>,
188    ) -> Poll<io::Result<()>> {
189        self.get_mut().with_inner_mut(|r| Pin::new(r.body_mut()).poll_read(cx, buf))
190    }
191}