// rocket/shield/shield.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use state::InitCell;
5use yansi::Paint;
6
7use crate::{Rocket, Request, Response, Orbit, Config};
8use crate::fairing::{Fairing, Info, Kind};
9use crate::http::{Header, uncased::UncasedStr};
10use crate::log::PaintExt;
11use crate::shield::*;
12
13/// A [`Fairing`] that injects browser security and privacy headers into all
14/// outgoing responses.
15///
16/// # Usage
17///
18/// To use `Shield`, first construct an instance of it. To use the default
19/// set of headers, construct with [`Shield::default()`](#method.default).
20/// For an instance with no preset headers, use [`Shield::new()`]. To
21/// enable an additional header, use [`enable()`](Shield::enable()), and to
22/// disable a header, use [`disable()`](Shield::disable()):
23///
24/// ```rust
25/// use rocket::shield::Shield;
26/// use rocket::shield::{XssFilter, ExpectCt};
27///
28/// // A `Shield` with the default headers:
29/// let shield = Shield::default();
30///
31/// // A `Shield` with the default headers minus `XssFilter`:
32/// let shield = Shield::default().disable::<XssFilter>();
33///
34/// // A `Shield` with the default headers plus `ExpectCt`.
35/// let shield = Shield::default().enable(ExpectCt::default());
36///
37/// // A `Shield` with only `XssFilter` and `ExpectCt`.
38/// let shield = Shield::default()
39///     .enable(XssFilter::default())
40///     .enable(ExpectCt::default());
41/// ```
42///
43/// Then, attach the instance of `Shield` to your application's instance of
44/// `Rocket`:
45///
46/// ```rust
47/// # extern crate rocket;
48/// # use rocket::shield::Shield;
49/// # let shield = Shield::default();
50/// rocket::build()
51///     // ...
52///     .attach(shield)
53/// # ;
54/// ```
55///
56/// The fairing will inject all enabled headers into all outgoing responses
57/// _unless_ the response already contains a header with the same name. If it
58/// does contain the header, a warning is emitted, and the header is not
59/// overwritten.
60///
61/// # TLS and HSTS
62///
63/// If TLS is configured and enabled when the application is launched in a
64/// non-debug profile, HSTS is automatically enabled with its default policy and
65/// a warning is logged.
66///
67/// To get rid of this warning, explicitly [`Shield::enable()`] an [`Hsts`]
68/// policy.
69pub struct Shield {
70    /// Enabled policies where the key is the header name.
71    policies: HashMap<&'static UncasedStr, Box<dyn SubPolicy>>,
72    /// Whether to enforce HSTS even though the user didn't enable it.
73    force_hsts: AtomicBool,
74    /// Headers pre-rendered at liftoff from the configured policies.
75    rendered: InitCell<Vec<Header<'static>>>,
76}
77
78impl Default for Shield {
79    /// Returns a new `Shield` instance. See the [table] for a description
80    /// of the policies used by default.
81    ///
82    /// [table]: ./#supported-headers
83    ///
84    /// # Example
85    ///
86    /// ```rust
87    /// # extern crate rocket;
88    /// use rocket::shield::Shield;
89    ///
90    /// let shield = Shield::default();
91    /// ```
92    fn default() -> Self {
93        Shield::new()
94            .enable(NoSniff::default())
95            .enable(Frame::default())
96            .enable(Permission::default())
97    }
98}
99
100impl Shield {
101    /// Returns an instance of `Shield` with no headers enabled.
102    ///
103    /// # Example
104    ///
105    /// ```rust
106    /// use rocket::shield::Shield;
107    ///
108    /// let shield = Shield::new();
109    /// ```
110    pub fn new() -> Self {
111        Shield {
112            policies: HashMap::new(),
113            force_hsts: AtomicBool::new(false),
114            rendered: InitCell::new(),
115        }
116    }
117
118    /// Enables the policy header `policy`.
119    ///
120    /// If the policy was previously enabled, the configuration is replaced
121    /// with that of `policy`.
122    ///
123    /// # Example
124    ///
125    /// ```rust
126    /// use rocket::shield::Shield;
127    /// use rocket::shield::NoSniff;
128    ///
129    /// let shield = Shield::new().enable(NoSniff::default());
130    /// ```
131    pub fn enable<P: Policy>(mut self, policy: P) -> Self {
132        self.rendered = InitCell::new();
133        self.policies.insert(P::NAME.into(), Box::new(policy));
134        self
135    }
136
137    /// Disables the policy header `policy`.
138    ///
139    /// # Example
140    ///
141    /// ```rust
142    /// use rocket::shield::Shield;
143    /// use rocket::shield::NoSniff;
144    ///
145    /// let shield = Shield::default().disable::<NoSniff>();
146    /// ```
147    pub fn disable<P: Policy>(mut self) -> Self {
148        self.rendered = InitCell::new();
149        self.policies.remove(UncasedStr::new(P::NAME));
150        self
151    }
152
153    /// Returns `true` if the policy `P` is enabled.
154    ///
155    /// # Example
156    ///
157    /// ```rust
158    /// use rocket::shield::Shield;
159    /// use rocket::shield::{Permission, NoSniff, Frame};
160    /// use rocket::shield::{Prefetch, ExpectCt, Referrer};
161    ///
162    /// let shield = Shield::default();
163    ///
164    /// assert!(shield.is_enabled::<NoSniff>());
165    /// assert!(shield.is_enabled::<Frame>());
166    /// assert!(shield.is_enabled::<Permission>());
167    ///
168    /// assert!(!shield.is_enabled::<Prefetch>());
169    /// assert!(!shield.is_enabled::<ExpectCt>());
170    /// assert!(!shield.is_enabled::<Referrer>());
171    /// ```
172    pub fn is_enabled<P: Policy>(&self) -> bool {
173        self.policies.contains_key(UncasedStr::new(P::NAME))
174    }
175
176    fn headers(&self) -> &[Header<'static>] {
177        self.rendered.get_or_init(|| {
178            let mut headers: Vec<_> = self.policies.values()
179                .map(|p| p.header())
180                .collect();
181
182            if self.force_hsts.load(Ordering::Acquire) {
183                headers.push(Policy::header(&Hsts::default()));
184            }
185
186            headers
187        })
188    }
189}
190
191#[crate::async_trait]
192impl Fairing for Shield {
193    fn info(&self) -> Info {
194        Info {
195            name: "Shield",
196            kind: Kind::Liftoff | Kind::Response | Kind::Singleton,
197        }
198    }
199
200    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
201        let force_hsts = rocket.config().tls_enabled()
202            && rocket.figment().profile() != Config::DEBUG_PROFILE
203            && !self.is_enabled::<Hsts>();
204
205        if force_hsts {
206            self.force_hsts.store(true, Ordering::Release);
207        }
208
209        if !self.headers().is_empty() {
210            info!("{}{}:", "🛡️ ".emoji(), "Shield".magenta());
211
212            for header in self.headers() {
213                info_!("{}: {}", header.name(), header.value().primary());
214            }
215
216            if force_hsts {
217                warn_!("Detected TLS-enabled liftoff without enabling HSTS.");
218                warn_!("Shield has enabled a default HSTS policy.");
219                info_!("To remove this warning, configure an HSTS policy.");
220            }
221        }
222    }
223
224    async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
225        // Set all of the headers in `self.policies` in `response` as long as
226        // the header is not already in the response.
227        for header in self.headers() {
228            if response.headers().contains(header.name()) {
229                warn!("Shield: response contains a '{}' header.", header.name());
230                warn_!("Refusing to overwrite existing header.");
231                continue
232            }
233
234            response.set_header(header.clone());
235        }
236    }
237}