rocket/shield/
shield.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use crate::{Rocket, Request, Response, Orbit, Config};
5use crate::fairing::{Fairing, Info, Kind};
6use crate::http::{Header, uncased::UncasedStr};
7use crate::shield::{Frame, Hsts, NoSniff, Permission, Policy};
8use crate::trace::{Trace, TraceAll};
9
10/// A [`Fairing`] that injects browser security and privacy headers into all
11/// outgoing responses.
12///
13/// # Usage
14///
15/// To use `Shield`, first construct an instance of it. To use the default
16/// set of headers, construct with [`Shield::default()`](#method.default).
17/// For an instance with no preset headers, use [`Shield::new()`]. To
18/// enable an additional header, use [`enable()`](Shield::enable()), and to
19/// disable a header, use [`disable()`](Shield::disable()):
20///
21/// ```rust
22/// use rocket::shield::Shield;
23/// use rocket::shield::{XssFilter, ExpectCt};
24///
25/// // A `Shield` with the default headers:
26/// let shield = Shield::default();
27///
28/// // A `Shield` with the default headers minus `XssFilter`:
29/// let shield = Shield::default().disable::<XssFilter>();
30///
31/// // A `Shield` with the default headers plus `ExpectCt`.
32/// let shield = Shield::default().enable(ExpectCt::default());
33///
34/// // A `Shield` with only `XssFilter` and `ExpectCt`.
35/// let shield = Shield::default()
36///     .enable(XssFilter::default())
37///     .enable(ExpectCt::default());
38/// ```
39///
40/// Then, attach the instance of `Shield` to your application's instance of
41/// `Rocket`:
42///
43/// ```rust
44/// # extern crate rocket;
45/// # use rocket::shield::Shield;
46/// # let shield = Shield::default();
47/// rocket::build()
48///     // ...
49///     .attach(shield)
50/// # ;
51/// ```
52///
53/// The fairing will inject all enabled headers into all outgoing responses
54/// _unless_ the response already contains a header with the same name. If it
55/// does contain the header, a warning is emitted, and the header is not
56/// overwritten.
57///
58/// # TLS and HSTS
59///
60/// If TLS is configured and enabled when the application is launched in a
61/// non-debug profile, HSTS is automatically enabled with its default policy and
62/// a warning is logged. To get rid of this warning, explicitly
63/// [`Shield::enable()`] an [`Hsts`] policy.
64pub struct Shield {
65    /// Enabled policies where the key is the header name.
66    policies: HashMap<&'static UncasedStr, Header<'static>>,
67    /// Whether to enforce HSTS even though the user didn't enable it.
68    force_hsts: AtomicBool,
69}
70
71impl Clone for Shield {
72    fn clone(&self) -> Self {
73        Self {
74            policies: self.policies.clone(),
75            force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)),
76        }
77    }
78}
79
80impl Default for Shield {
81    /// Returns a new `Shield` instance. See the [table] for a description
82    /// of the policies used by default.
83    ///
84    /// [table]: ./#supported-headers
85    ///
86    /// # Example
87    ///
88    /// ```rust
89    /// # extern crate rocket;
90    /// use rocket::shield::Shield;
91    ///
92    /// let shield = Shield::default();
93    /// ```
94    fn default() -> Self {
95        Shield::new()
96            .enable(NoSniff::default())
97            .enable(Frame::default())
98            .enable(Permission::default())
99    }
100}
101
102impl Shield {
103    /// Returns an instance of `Shield` with no headers enabled.
104    ///
105    /// # Example
106    ///
107    /// ```rust
108    /// use rocket::shield::Shield;
109    ///
110    /// let shield = Shield::new();
111    /// ```
112    pub fn new() -> Self {
113        Shield {
114            policies: HashMap::new(),
115            force_hsts: AtomicBool::new(false),
116        }
117    }
118
119    /// Enables the policy header `policy`.
120    ///
121    /// If the policy was previously enabled, the configuration is replaced
122    /// with that of `policy`.
123    ///
124    /// # Example
125    ///
126    /// ```rust
127    /// use rocket::shield::Shield;
128    /// use rocket::shield::NoSniff;
129    ///
130    /// let shield = Shield::new().enable(NoSniff::default());
131    /// ```
132    pub fn enable<P: Policy>(mut self, policy: P) -> Self {
133        self.policies.insert(P::NAME.into(), policy.header());
134        self
135    }
136
137    /// Disables the policy header `policy`.
138    ///
139    /// # Example
140    ///
141    /// ```rust
142    /// use rocket::shield::Shield;
143    /// use rocket::shield::NoSniff;
144    ///
145    /// let shield = Shield::default().disable::<NoSniff>();
146    /// ```
147    pub fn disable<P: Policy>(mut self) -> Self {
148        self.policies.remove(UncasedStr::new(P::NAME));
149        self
150    }
151
152    /// Returns `true` if the policy `P` is enabled.
153    ///
154    /// # Example
155    ///
156    /// ```rust
157    /// use rocket::shield::Shield;
158    /// use rocket::shield::{Permission, NoSniff, Frame};
159    /// use rocket::shield::{Prefetch, ExpectCt, Referrer};
160    ///
161    /// let shield = Shield::default();
162    ///
163    /// assert!(shield.is_enabled::<NoSniff>());
164    /// assert!(shield.is_enabled::<Frame>());
165    /// assert!(shield.is_enabled::<Permission>());
166    ///
167    /// assert!(!shield.is_enabled::<Prefetch>());
168    /// assert!(!shield.is_enabled::<ExpectCt>());
169    /// assert!(!shield.is_enabled::<Referrer>());
170    /// ```
171    pub fn is_enabled<P: Policy>(&self) -> bool {
172        self.policies.contains_key(UncasedStr::new(P::NAME))
173    }
174}
175
176#[crate::async_trait]
177impl Fairing for Shield {
178    fn info(&self) -> Info {
179        Info {
180            name: "Shield",
181            kind: Kind::Liftoff | Kind::Response | Kind::Singleton,
182        }
183    }
184
185    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
186        if self.policies.is_empty() {
187            return;
188        }
189
190        let force_hsts = rocket.endpoints().all(|v| v.is_tls())
191            && rocket.figment().profile() != Config::DEBUG_PROFILE
192            && !self.is_enabled::<Hsts>();
193
194        if force_hsts {
195            self.force_hsts.store(true, Ordering::Release);
196        }
197
198        span_info!("shield", policies = self.policies.len() => {
199            self.policies.values().trace_all_info();
200
201            if force_hsts {
202                warn!("Detected TLS-enabled liftoff without enabling HSTS.\n\
203                    Shield has enabled a default HSTS policy.\n\
204                    To remove this warning, configure an HSTS policy.");
205            }
206        })
207    }
208
209    async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
210        // Set all of the headers in `self.policies` in `response` as long as
211        // the header is not already in the response.
212        for header in self.policies.values() {
213            if response.headers().contains(header.name()) {
214                span_warn!("shield", "shield refusing to overwrite existing response header" => {
215                    header.trace_warn();
216                });
217
218                continue
219            }
220
221            response.set_header(header.clone());
222        }
223    }
224}