rocket/shield/shield.rs
1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use state::InitCell;
5use yansi::Paint;
6
7use crate::{Rocket, Request, Response, Orbit, Config};
8use crate::fairing::{Fairing, Info, Kind};
9use crate::http::{Header, uncased::UncasedStr};
10use crate::log::PaintExt;
11use crate::shield::*;
12
13/// A [`Fairing`] that injects browser security and privacy headers into all
14/// outgoing responses.
15///
16/// # Usage
17///
18/// To use `Shield`, first construct an instance of it. To use the default
19/// set of headers, construct with [`Shield::default()`](#method.default).
20/// For an instance with no preset headers, use [`Shield::new()`]. To
21/// enable an additional header, use [`enable()`](Shield::enable()), and to
22/// disable a header, use [`disable()`](Shield::disable()):
23///
24/// ```rust
25/// use rocket::shield::Shield;
26/// use rocket::shield::{XssFilter, ExpectCt};
27///
28/// // A `Shield` with the default headers:
29/// let shield = Shield::default();
30///
31/// // A `Shield` with the default headers minus `XssFilter`:
32/// let shield = Shield::default().disable::<XssFilter>();
33///
34/// // A `Shield` with the default headers plus `ExpectCt`.
35/// let shield = Shield::default().enable(ExpectCt::default());
36///
37/// // A `Shield` with only `XssFilter` and `ExpectCt`.
38/// let shield = Shield::default()
39/// .enable(XssFilter::default())
40/// .enable(ExpectCt::default());
41/// ```
42///
43/// Then, attach the instance of `Shield` to your application's instance of
44/// `Rocket`:
45///
46/// ```rust
47/// # extern crate rocket;
48/// # use rocket::shield::Shield;
49/// # let shield = Shield::default();
50/// rocket::build()
51/// // ...
52/// .attach(shield)
53/// # ;
54/// ```
55///
56/// The fairing will inject all enabled headers into all outgoing responses
57/// _unless_ the response already contains a header with the same name. If it
58/// does contain the header, a warning is emitted, and the header is not
59/// overwritten.
60///
61/// # TLS and HSTS
62///
63/// If TLS is configured and enabled when the application is launched in a
64/// non-debug profile, HSTS is automatically enabled with its default policy and
65/// a warning is logged.
66///
67/// To get rid of this warning, explicitly [`Shield::enable()`] an [`Hsts`]
68/// policy.
69pub struct Shield {
70 /// Enabled policies where the key is the header name.
71 policies: HashMap<&'static UncasedStr, Box<dyn SubPolicy>>,
72 /// Whether to enforce HSTS even though the user didn't enable it.
73 force_hsts: AtomicBool,
74 /// Headers pre-rendered at liftoff from the configured policies.
75 rendered: InitCell<Vec<Header<'static>>>,
76}
77
78impl Default for Shield {
79 /// Returns a new `Shield` instance. See the [table] for a description
80 /// of the policies used by default.
81 ///
82 /// [table]: ./#supported-headers
83 ///
84 /// # Example
85 ///
86 /// ```rust
87 /// # extern crate rocket;
88 /// use rocket::shield::Shield;
89 ///
90 /// let shield = Shield::default();
91 /// ```
92 fn default() -> Self {
93 Shield::new()
94 .enable(NoSniff::default())
95 .enable(Frame::default())
96 .enable(Permission::default())
97 }
98}
99
100impl Shield {
101 /// Returns an instance of `Shield` with no headers enabled.
102 ///
103 /// # Example
104 ///
105 /// ```rust
106 /// use rocket::shield::Shield;
107 ///
108 /// let shield = Shield::new();
109 /// ```
110 pub fn new() -> Self {
111 Shield {
112 policies: HashMap::new(),
113 force_hsts: AtomicBool::new(false),
114 rendered: InitCell::new(),
115 }
116 }
117
118 /// Enables the policy header `policy`.
119 ///
120 /// If the policy was previously enabled, the configuration is replaced
121 /// with that of `policy`.
122 ///
123 /// # Example
124 ///
125 /// ```rust
126 /// use rocket::shield::Shield;
127 /// use rocket::shield::NoSniff;
128 ///
129 /// let shield = Shield::new().enable(NoSniff::default());
130 /// ```
131 pub fn enable<P: Policy>(mut self, policy: P) -> Self {
132 self.rendered = InitCell::new();
133 self.policies.insert(P::NAME.into(), Box::new(policy));
134 self
135 }
136
137 /// Disables the policy header `policy`.
138 ///
139 /// # Example
140 ///
141 /// ```rust
142 /// use rocket::shield::Shield;
143 /// use rocket::shield::NoSniff;
144 ///
145 /// let shield = Shield::default().disable::<NoSniff>();
146 /// ```
147 pub fn disable<P: Policy>(mut self) -> Self {
148 self.rendered = InitCell::new();
149 self.policies.remove(UncasedStr::new(P::NAME));
150 self
151 }
152
153 /// Returns `true` if the policy `P` is enabled.
154 ///
155 /// # Example
156 ///
157 /// ```rust
158 /// use rocket::shield::Shield;
159 /// use rocket::shield::{Permission, NoSniff, Frame};
160 /// use rocket::shield::{Prefetch, ExpectCt, Referrer};
161 ///
162 /// let shield = Shield::default();
163 ///
164 /// assert!(shield.is_enabled::<NoSniff>());
165 /// assert!(shield.is_enabled::<Frame>());
166 /// assert!(shield.is_enabled::<Permission>());
167 ///
168 /// assert!(!shield.is_enabled::<Prefetch>());
169 /// assert!(!shield.is_enabled::<ExpectCt>());
170 /// assert!(!shield.is_enabled::<Referrer>());
171 /// ```
172 pub fn is_enabled<P: Policy>(&self) -> bool {
173 self.policies.contains_key(UncasedStr::new(P::NAME))
174 }
175
176 fn headers(&self) -> &[Header<'static>] {
177 self.rendered.get_or_init(|| {
178 let mut headers: Vec<_> = self.policies.values()
179 .map(|p| p.header())
180 .collect();
181
182 if self.force_hsts.load(Ordering::Acquire) {
183 headers.push(Policy::header(&Hsts::default()));
184 }
185
186 headers
187 })
188 }
189}
190
191#[crate::async_trait]
192impl Fairing for Shield {
193 fn info(&self) -> Info {
194 Info {
195 name: "Shield",
196 kind: Kind::Liftoff | Kind::Response | Kind::Singleton,
197 }
198 }
199
200 async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
201 let force_hsts = rocket.config().tls_enabled()
202 && rocket.figment().profile() != Config::DEBUG_PROFILE
203 && !self.is_enabled::<Hsts>();
204
205 if force_hsts {
206 self.force_hsts.store(true, Ordering::Release);
207 }
208
209 if !self.headers().is_empty() {
210 info!("{}{}:", "🛡️ ".emoji(), "Shield".magenta());
211
212 for header in self.headers() {
213 info_!("{}: {}", header.name(), header.value().primary());
214 }
215
216 if force_hsts {
217 warn_!("Detected TLS-enabled liftoff without enabling HSTS.");
218 warn_!("Shield has enabled a default HSTS policy.");
219 info_!("To remove this warning, configure an HSTS policy.");
220 }
221 }
222 }
223
224 async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
225 // Set all of the headers in `self.policies` in `response` as long as
226 // the header is not already in the response.
227 for header in self.headers() {
228 if response.headers().contains(header.name()) {
229 warn!("Shield: response contains a '{}' header.", header.name());
230 warn_!("Refusing to overwrite existing header.");
231 continue
232 }
233
234 response.set_header(header.clone());
235 }
236 }
237}