rocket/local/asynchronous/response.rs

1use std::io;
2use std::future::Future;
3use std::{pin::Pin, task::{Context, Poll}};
4
5use tokio::io::{AsyncRead, ReadBuf};
6
7use crate::http::CookieJar;
8use crate::{Request, Response};
9
10/// An `async` response from a dispatched [`LocalRequest`](super::LocalRequest).
11///
12/// This `LocalResponse` implements [`tokio::io::AsyncRead`]. As such, if
13/// [`into_string()`](LocalResponse::into_string()) and
14/// [`into_bytes()`](LocalResponse::into_bytes()) do not suffice, the response's
15/// body can be read directly:
16///
17/// ```rust
18/// # #[macro_use] extern crate rocket;
19/// use std::io;
20///
21/// use rocket::local::asynchronous::Client;
22/// use rocket::tokio::io::AsyncReadExt;
23/// use rocket::http::Status;
24///
25/// #[get("/")]
26/// fn hello_world() -> &'static str {
27///     "Hello, world!"
28/// }
29///
30/// #[launch]
31/// fn rocket() -> _ {
32///     rocket::build().mount("/", routes![hello_world])
33///     #    .configure(rocket::Config::debug_default())
34/// }
35///
36/// # async fn read_body_manually() -> io::Result<()> {
37/// // Dispatch a `GET /` request.
38/// let client = Client::tracked(rocket()).await.expect("valid rocket");
39/// let mut response = client.get("/").dispatch().await;
40///
41/// // Check metadata validity.
42/// assert_eq!(response.status(), Status::Ok);
43/// assert_eq!(response.body().preset_size(), Some(13));
44///
45/// // Read 10 bytes of the body. Note: in reality, we'd use `into_string()`.
46/// let mut buffer = [0; 10];
47/// response.read(&mut buffer).await?;
48/// assert_eq!(buffer, "Hello, wor".as_bytes());
49/// # Ok(())
50/// # }
51/// # rocket::async_test(read_body_manually()).expect("read okay");
52/// ```
53///
54/// For more, see [the top-level documentation](../index.html#localresponse).
55pub struct LocalResponse<'c> {
56    _request: Box<Request<'c>>,
57    response: Response<'c>,
58    cookies: CookieJar<'c>,
59}
60
61impl<'c> LocalResponse<'c> {
62    pub(crate) fn new<F, O>(req: Request<'c>, f: F) -> impl Future<Output = LocalResponse<'c>>
63        where F: FnOnce(&'c Request<'c>) -> O + Send,
64              O: Future<Output = Response<'c>> + Send
65    {
66        // `LocalResponse` is a self-referential structure. In particular,
67        // `inner` can refer to `_request` and its contents. As such, we must
68        //   1) Ensure `Request` has a stable address.
69        //
70        //      This is done by `Box`ing the `Request`, using only the stable
71        //      address thereafter.
72        //
73        //   2) Ensure no refs to `Request` or its contents leak with a lifetime
74        //      extending beyond that of `&self`.
75        //
76        //      We have no methods that return an `&Request`. However, we must
77        //      also ensure that `Response` doesn't leak any such references. To
78        //      do so, we don't expose the `Response` directly in any way;
79        //      otherwise, methods like `.headers()` could, in conjunction with
80        //      particular crafted `Responder`s, potentially be used to obtain a
81        //      reference to contents of `Request`. All methods, instead, return
82        //      references bounded by `self`. This is easily verified by noting
83        //      that 1) `LocalResponse` fields are private, and 2) all `impl`s
84        //      of `LocalResponse` aside from this method abstract the lifetime
85        //      away as `'_`, ensuring it is not used for any output value.
86        let boxed_req = Box::new(req);
87        let request: &'c Request<'c> = unsafe { &*(&*boxed_req as *const _) };
88
89        async move {
90            let response: Response<'c> = f(request).await;
91            let mut cookies = CookieJar::new(request.rocket().config());
92            for cookie in response.cookies() {
93                cookies.add_original(cookie.into_owned());
94            }
95
96            LocalResponse { cookies, _request: boxed_req, response, }
97        }
98    }
99}
100
101impl LocalResponse<'_> {
102    pub(crate) fn _response(&self) -> &Response<'_> {
103        &self.response
104    }
105
106    pub(crate) fn _cookies(&self) -> &CookieJar<'_> {
107        &self.cookies
108    }
109
110    pub(crate) async fn _into_string(mut self) -> io::Result<String> {
111        self.response.body_mut().to_string().await
112    }
113
114    pub(crate) async fn _into_bytes(mut self) -> io::Result<Vec<u8>> {
115        self.response.body_mut().to_bytes().await
116    }
117
118    #[cfg(feature = "json")]
119    async fn _into_json<T: Send + 'static>(self) -> Option<T>
120        where T: serde::de::DeserializeOwned
121    {
122        self.blocking_read(|r| serde_json::from_reader(r)).await?.ok()
123    }
124
125    #[cfg(feature = "msgpack")]
126    async fn _into_msgpack<T: Send + 'static>(self) -> Option<T>
127        where T: serde::de::DeserializeOwned
128    {
129        self.blocking_read(|r| rmp_serde::from_read(r)).await?.ok()
130    }
131
132    #[cfg(any(feature = "json", feature = "msgpack"))]
133    async fn blocking_read<T, F>(mut self, f: F) -> Option<T>
134        where T: Send + 'static,
135              F: FnOnce(&mut dyn io::Read) -> T + Send + 'static
136    {
137        use tokio::sync::mpsc;
138        use tokio::io::AsyncReadExt;
139
140        struct ChanReader {
141            last: Option<io::Cursor<Vec<u8>>>,
142            rx: mpsc::Receiver<io::Result<Vec<u8>>>,
143        }
144
145        impl std::io::Read for ChanReader {
146            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
147                loop {
148                    if let Some(ref mut cursor) = self.last {
149                        if cursor.position() < cursor.get_ref().len() as u64 {
150                            return std::io::Read::read(cursor, buf);
151                        }
152                    }
153
154                    if let Some(buf) = self.rx.blocking_recv() {
155                        self.last = Some(io::Cursor::new(buf?));
156                    } else {
157                        return Ok(0);
158                    }
159                }
160            }
161        }
162
163        let (tx, rx) = mpsc::channel(2);
164        let reader = tokio::task::spawn_blocking(move || {
165            let mut reader = ChanReader { last: None, rx };
166            f(&mut reader)
167        });
168
169        loop {
170            // TODO: Try to fill as much as the buffer before send it off?
171            let mut buf = Vec::with_capacity(1024);
172            match self.read_buf(&mut buf).await {
173                Ok(n) if n == 0 => break,
174                Ok(_) => tx.send(Ok(buf)).await.ok()?,
175                Err(e) => {
176                    tx.send(Err(e)).await.ok()?;
177                    break;
178                }
179            }
180        }
181
182        // NOTE: We _must_ drop tx now to prevent a deadlock!
183        drop(tx);
184
185        reader.await.ok()
186    }
187
188    // Generates the public API methods, which call the private methods above.
189    pub_response_impl!("# use rocket::local::asynchronous::Client;\n\
190        use rocket::local::asynchronous::LocalResponse;" async await);
191}
192
193impl AsyncRead for LocalResponse<'_> {
194    fn poll_read(
195        mut self: Pin<&mut Self>,
196        cx: &mut Context<'_>,
197        buf: &mut ReadBuf<'_>,
198    ) -> Poll<io::Result<()>> {
199        Pin::new(self.response.body_mut()).poll_read(cx, buf)
200    }
201}
202
203impl std::fmt::Debug for LocalResponse<'_> {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        self._response().fmt(f)
206    }
207}