rocket/state.rs
1use std::fmt;
2use std::ops::Deref;
3use std::any::type_name;
4
5use ref_cast::RefCast;
6
7use crate::{Phase, Rocket, Ignite, Sentinel};
8use crate::request::{self, FromRequest, Request};
9use crate::outcome::Outcome;
10use crate::http::Status;
11
12/// Request guard to retrieve managed state.
13///
14/// A reference `&State<T>` type is a request guard which retrieves the managed
15/// state managing for some type `T`. A value for the given type must previously
16/// have been registered to be managed by Rocket via [`Rocket::manage()`]. The
17/// type being managed must be thread safe and sendable across thread
18/// boundaries as multiple handlers in multiple threads may be accessing the
19/// value at once. In other words, it must implement [`Send`] + [`Sync`] +
20/// `'static`.
21///
22/// # Example
23///
24/// Imagine you have some configuration struct of the type `MyConfig` that you'd
25/// like to initialize at start-up and later access it in several handlers. The
26/// following example does just this:
27///
28/// ```rust,no_run
29/// # #[macro_use] extern crate rocket;
30/// use rocket::State;
31///
32/// // In a real application, this would likely be more complex.
33/// struct MyConfig {
34/// user_val: String
35/// }
36///
37/// #[get("/")]
38/// fn index(state: &State<MyConfig>) -> String {
39/// format!("The config value is: {}", state.user_val)
40/// }
41///
42/// #[get("/raw")]
43/// fn raw_config_value(state: &State<MyConfig>) -> &str {
44/// &state.user_val
45/// }
46///
47/// #[launch]
48/// fn rocket() -> _ {
49/// rocket::build()
50/// .mount("/", routes![index, raw_config_value])
51/// .manage(MyConfig { user_val: "user input".to_string() })
52/// }
53/// ```
54///
55/// # Within Request Guards
56///
57/// Because `State` is itself a request guard, managed state can be retrieved
58/// from another request guard's implementation using either
59/// [`Request::guard()`] or [`Rocket::state()`]. In the following code example,
60/// the `Item` request guard retrieves `MyConfig` from managed state:
61///
62/// ```rust
63/// use rocket::State;
64/// use rocket::request::{self, Request, FromRequest};
65/// use rocket::outcome::IntoOutcome;
66/// use rocket::http::Status;
67///
68/// # struct MyConfig { user_val: String };
69/// struct Item<'r>(&'r str);
70///
71/// #[rocket::async_trait]
72/// impl<'r> FromRequest<'r> for Item<'r> {
73/// type Error = ();
74///
75/// async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
76/// // Using `State` as a request guard. Use `inner()` to get an `'r`.
77/// let outcome = request.guard::<&State<MyConfig>>().await
78/// .map(|my_config| Item(&my_config.user_val));
79///
80/// // Or alternatively, using `Rocket::state()`:
81/// let outcome = request.rocket().state::<MyConfig>()
82/// .map(|my_config| Item(&my_config.user_val))
83/// .or_forward(Status::InternalServerError);
84///
85/// outcome
86/// }
87/// }
88/// ```
89///
90/// # Testing with `State`
91///
92/// When unit testing your application, you may find it necessary to manually
93/// construct a type of `State` to pass to your functions. To do so, use the
94/// [`State::get()`] static method or the `From<&T>` implementation:
95///
96/// ```rust
97/// # #[macro_use] extern crate rocket;
98/// use rocket::State;
99///
100/// struct MyManagedState(usize);
101///
102/// #[get("/")]
103/// fn handler(state: &State<MyManagedState>) -> String {
104/// state.0.to_string()
105/// }
106///
107/// let mut rocket = rocket::build().manage(MyManagedState(127));
108/// let state = State::get(&rocket).expect("managed `MyManagedState`");
109/// assert_eq!(handler(state), "127");
110///
111/// let managed = MyManagedState(77);
112/// assert_eq!(handler(State::from(&managed)), "77");
113/// ```
114#[repr(transparent)]
115#[derive(RefCast, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
116pub struct State<T: Send + Sync + 'static>(T);
117
118impl<T: Send + Sync + 'static> State<T> {
119 /// Returns the managed state value in `rocket` for the type `T` if it is
120 /// being managed by `rocket`. Otherwise, returns `None`.
121 ///
122 /// # Example
123 ///
124 /// ```rust
125 /// use rocket::State;
126 ///
127 /// #[derive(Debug, PartialEq)]
128 /// struct Managed(usize);
129 ///
130 /// #[derive(Debug, PartialEq)]
131 /// struct Unmanaged(usize);
132 ///
133 /// let rocket = rocket::build().manage(Managed(7));
134 ///
135 /// let state: Option<&State<Managed>> = State::get(&rocket);
136 /// assert_eq!(state.map(|s| s.inner()), Some(&Managed(7)));
137 ///
138 /// let state: Option<&State<Unmanaged>> = State::get(&rocket);
139 /// assert_eq!(state, None);
140 /// ```
141 #[inline(always)]
142 pub fn get<P: Phase>(rocket: &Rocket<P>) -> Option<&State<T>> {
143 rocket.state::<T>().map(State::ref_cast)
144 }
145
146 /// This exists because `State::from()` would otherwise be nothing. But we
147 /// want `State::from(&foo)` to give us `<&State>::from(&foo)`. Here it is.
148 #[doc(hidden)]
149 #[inline(always)]
150 pub fn from(value: &T) -> &State<T> {
151 State::ref_cast(value)
152 }
153
154 /// Borrow the inner value.
155 ///
156 /// Using this method is typically unnecessary as `State` implements
157 /// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will
158 /// automatically coerce a `State<T>` to an `&T` as required. This method
159 /// should only be used when a longer lifetime is required.
160 ///
161 /// # Example
162 ///
163 /// ```rust
164 /// use rocket::State;
165 ///
166 /// #[derive(Clone)]
167 /// struct MyConfig {
168 /// user_val: String
169 /// }
170 ///
171 /// fn handler1<'r>(config: &State<MyConfig>) -> String {
172 /// let config = config.inner().clone();
173 /// config.user_val
174 /// }
175 ///
176 /// // Use the `Deref` implementation which coerces implicitly
177 /// fn handler2(config: &State<MyConfig>) -> String {
178 /// config.user_val.clone()
179 /// }
180 /// ```
181 #[inline(always)]
182 pub fn inner(&self) -> &T {
183 &self.0
184 }
185}
186
187impl<'r, T: Send + Sync + 'static> From<&'r T> for &'r State<T> {
188 #[inline(always)]
189 fn from(reference: &'r T) -> Self {
190 State::ref_cast(reference)
191 }
192}
193
194#[crate::async_trait]
195impl<'r, T: Send + Sync + 'static> FromRequest<'r> for &'r State<T> {
196 type Error = ();
197
198 #[inline(always)]
199 async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
200 match State::get(req.rocket()) {
201 Some(state) => Outcome::Success(state),
202 None => {
203 error!(type_name = type_name::<T>(),
204 "retrieving unmanaged state\n\
205 state must be managed via `rocket.manage()`");
206
207 Outcome::Error((Status::InternalServerError, ()))
208 }
209 }
210 }
211}
212
213impl<T: Send + Sync + 'static> Sentinel for &State<T> {
214 fn abort(rocket: &Rocket<Ignite>) -> bool {
215 if rocket.state::<T>().is_none() {
216 error!(type_name = type_name::<T>(),
217 "unmanaged state detected\n\
218 ensure type is being managed via `rocket.manage()`");
219
220 return true;
221 }
222
223 false
224 }
225}
226
227impl<T: Send + Sync + fmt::Display + 'static> fmt::Display for State<T> {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 self.0.fmt(f)
230 }
231}
232
233impl<T: Send + Sync + 'static> Deref for State<T> {
234 type Target = T;
235
236 #[inline(always)]
237 fn deref(&self) -> &T {
238 &self.0
239 }
240}