rocket/
state.rs

1use std::fmt;
2use std::ops::Deref;
3use std::any::type_name;
4
5use ref_cast::RefCast;
6use yansi::Paint;
7
8use crate::{Phase, Rocket, Ignite, Sentinel};
9use crate::request::{self, FromRequest, Request};
10use crate::outcome::Outcome;
11use crate::http::Status;
12
13/// Request guard to retrieve managed state.
14///
/// A reference `&State<T>` type is a request guard which retrieves the managed
/// state for some type `T`. A value for the given type must previously
17/// have been registered to be managed by Rocket via [`Rocket::manage()`]. The
18/// type being managed must be thread safe and sendable across thread
19/// boundaries as multiple handlers in multiple threads may be accessing the
20/// value at once. In other words, it must implement [`Send`] + [`Sync`] +
21/// `'static`.
22///
23/// # Example
24///
25/// Imagine you have some configuration struct of the type `MyConfig` that you'd
26/// like to initialize at start-up and later access it in several handlers. The
27/// following example does just this:
28///
29/// ```rust,no_run
30/// # #[macro_use] extern crate rocket;
31/// use rocket::State;
32///
33/// // In a real application, this would likely be more complex.
34/// struct MyConfig {
35///     user_val: String
36/// }
37///
38/// #[get("/")]
39/// fn index(state: &State<MyConfig>) -> String {
40///     format!("The config value is: {}", state.user_val)
41/// }
42///
43/// #[get("/raw")]
44/// fn raw_config_value(state: &State<MyConfig>) -> &str {
45///     &state.user_val
46/// }
47///
48/// #[launch]
49/// fn rocket() -> _ {
50///     rocket::build()
51///         .mount("/", routes![index, raw_config_value])
52///         .manage(MyConfig { user_val: "user input".to_string() })
53/// }
54/// ```
55///
56/// # Within Request Guards
57///
58/// Because `State` is itself a request guard, managed state can be retrieved
59/// from another request guard's implementation using either
60/// [`Request::guard()`] or [`Rocket::state()`]. In the following code example,
61/// the `Item` request guard retrieves `MyConfig` from managed state:
62///
63/// ```rust
64/// use rocket::State;
65/// use rocket::request::{self, Request, FromRequest};
66/// use rocket::outcome::IntoOutcome;
67/// use rocket::http::Status;
68///
69/// # struct MyConfig { user_val: String };
70/// struct Item<'r>(&'r str);
71///
72/// #[rocket::async_trait]
73/// impl<'r> FromRequest<'r> for Item<'r> {
74///     type Error = ();
75///
76///     async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
///         // Using `State` as a request guard. `Deref` yields an `&'r MyConfig`.
78///         let outcome = request.guard::<&State<MyConfig>>().await
79///             .map(|my_config| Item(&my_config.user_val));
80///
81///         // Or alternatively, using `Rocket::state()`:
82///         let outcome = request.rocket().state::<MyConfig>()
83///             .map(|my_config| Item(&my_config.user_val))
84///             .or_forward(Status::InternalServerError);
85///
86///         outcome
87///     }
88/// }
89/// ```
90///
91/// # Testing with `State`
92///
93/// When unit testing your application, you may find it necessary to manually
/// construct a `&State<T>` to pass to your functions. To do so, use the
95/// [`State::get()`] static method or the `From<&T>` implementation:
96///
97/// ```rust
98/// # #[macro_use] extern crate rocket;
99/// use rocket::State;
100///
101/// struct MyManagedState(usize);
102///
103/// #[get("/")]
104/// fn handler(state: &State<MyManagedState>) -> String {
105///     state.0.to_string()
106/// }
107///
108/// let mut rocket = rocket::build().manage(MyManagedState(127));
109/// let state = State::get(&rocket).expect("managed `MyManagedState`");
110/// assert_eq!(handler(state), "127");
111///
112/// let managed = MyManagedState(77);
113/// assert_eq!(handler(State::from(&managed)), "77");
114/// ```
115#[repr(transparent)]
116#[derive(RefCast, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
117pub struct State<T: Send + Sync + 'static>(T);
118
119impl<T: Send + Sync + 'static> State<T> {
120    /// Returns the managed state value in `rocket` for the type `T` if it is
121    /// being managed by `rocket`. Otherwise, returns `None`.
122    ///
123    /// # Example
124    ///
125    /// ```rust
126    /// use rocket::State;
127    ///
128    /// #[derive(Debug, PartialEq)]
129    /// struct Managed(usize);
130    ///
131    /// #[derive(Debug, PartialEq)]
132    /// struct Unmanaged(usize);
133    ///
134    /// let rocket = rocket::build().manage(Managed(7));
135    ///
136    /// let state: Option<&State<Managed>> = State::get(&rocket);
137    /// assert_eq!(state.map(|s| s.inner()), Some(&Managed(7)));
138    ///
139    /// let state: Option<&State<Unmanaged>> = State::get(&rocket);
140    /// assert_eq!(state, None);
141    /// ```
142    #[inline(always)]
143    pub fn get<P: Phase>(rocket: &Rocket<P>) -> Option<&State<T>> {
144        rocket.state::<T>().map(State::ref_cast)
145    }
146
147    /// This exists because `State::from()` would otherwise be nothing. But we
148    /// want `State::from(&foo)` to give us `<&State>::from(&foo)`. Here it is.
149    #[doc(hidden)]
150    #[inline(always)]
151    pub fn from(value: &T) -> &State<T> {
152        State::ref_cast(value)
153    }
154
155    /// Borrow the inner value.
156    ///
157    /// Using this method is typically unnecessary as `State` implements
158    /// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will
159    /// automatically coerce a `State<T>` to an `&T` as required. This method
160    /// should only be used when a longer lifetime is required.
161    ///
162    /// # Example
163    ///
164    /// ```rust
165    /// use rocket::State;
166    ///
167    /// #[derive(Clone)]
168    /// struct MyConfig {
169    ///     user_val: String
170    /// }
171    ///
172    /// fn handler1<'r>(config: &State<MyConfig>) -> String {
173    ///     let config = config.inner().clone();
174    ///     config.user_val
175    /// }
176    ///
177    /// // Use the `Deref` implementation which coerces implicitly
178    /// fn handler2(config: &State<MyConfig>) -> String {
179    ///     config.user_val.clone()
180    /// }
181    /// ```
182    #[inline(always)]
183    pub fn inner(&self) -> &T {
184        &self.0
185    }
186}
187
188impl<'r, T: Send + Sync + 'static> From<&'r T> for &'r State<T> {
189    #[inline(always)]
190    fn from(reference: &'r T) -> Self {
191        State::ref_cast(reference)
192    }
193}
194
195#[crate::async_trait]
196impl<'r, T: Send + Sync + 'static> FromRequest<'r> for &'r State<T> {
197    type Error = ();
198
199    #[inline(always)]
200    async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
201        match State::get(req.rocket()) {
202            Some(state) => Outcome::Success(state),
203            None => {
204                error_!("Attempted to retrieve unmanaged state `{}`!", type_name::<T>());
205                Outcome::Error((Status::InternalServerError, ()))
206            }
207        }
208    }
209}
210
211impl<T: Send + Sync + 'static> Sentinel for &State<T> {
212    fn abort(rocket: &Rocket<Ignite>) -> bool {
213        if rocket.state::<T>().is_none() {
214            let type_name = type_name::<T>();
215            error!("launching with unmanaged `{}` state.", type_name.primary().bold());
216            info_!("Using `State` requires managing it with `.manage()`.");
217            return true;
218        }
219
220        false
221    }
222}
223
224impl<T: Send + Sync + fmt::Display + 'static> fmt::Display for State<T> {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        self.0.fmt(f)
227    }
228}
229
230impl<T: Send + Sync + 'static> Deref for State<T> {
231    type Target = T;
232
233    #[inline(always)]
234    fn deref(&self) -> &T {
235        &self.0
236    }
237}