rocket/shield/shield.rs
1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use crate::{Rocket, Request, Response, Orbit, Config};
5use crate::fairing::{Fairing, Info, Kind};
6use crate::http::{Header, uncased::UncasedStr};
7use crate::shield::{Frame, Hsts, NoSniff, Permission, Policy};
8use crate::trace::{Trace, TraceAll};
9
10/// A [`Fairing`] that injects browser security and privacy headers into all
11/// outgoing responses.
12///
13/// # Usage
14///
15/// To use `Shield`, first construct an instance of it. To use the default
16/// set of headers, construct with [`Shield::default()`](#method.default).
17/// For an instance with no preset headers, use [`Shield::new()`]. To
18/// enable an additional header, use [`enable()`](Shield::enable()), and to
19/// disable a header, use [`disable()`](Shield::disable()):
20///
21/// ```rust
22/// use rocket::shield::Shield;
23/// use rocket::shield::{XssFilter, ExpectCt};
24///
25/// // A `Shield` with the default headers:
26/// let shield = Shield::default();
27///
28/// // A `Shield` with the default headers minus `XssFilter`:
29/// let shield = Shield::default().disable::<XssFilter>();
30///
31/// // A `Shield` with the default headers plus `ExpectCt`.
32/// let shield = Shield::default().enable(ExpectCt::default());
33///
34/// // A `Shield` with only `XssFilter` and `ExpectCt`.
35/// let shield = Shield::default()
36/// .enable(XssFilter::default())
37/// .enable(ExpectCt::default());
38/// ```
39///
40/// Then, attach the instance of `Shield` to your application's instance of
41/// `Rocket`:
42///
43/// ```rust
44/// # extern crate rocket;
45/// # use rocket::shield::Shield;
46/// # let shield = Shield::default();
47/// rocket::build()
48/// // ...
49/// .attach(shield)
50/// # ;
51/// ```
52///
53/// The fairing will inject all enabled headers into all outgoing responses
54/// _unless_ the response already contains a header with the same name. If it
55/// does contain the header, a warning is emitted, and the header is not
56/// overwritten.
57///
58/// # TLS and HSTS
59///
60/// If TLS is configured and enabled when the application is launched in a
61/// non-debug profile, HSTS is automatically enabled with its default policy and
62/// a warning is logged. To get rid of this warning, explicitly
63/// [`Shield::enable()`] an [`Hsts`] policy.
64pub struct Shield {
65 /// Enabled policies where the key is the header name.
66 policies: HashMap<&'static UncasedStr, Header<'static>>,
67 /// Whether to enforce HSTS even though the user didn't enable it.
68 force_hsts: AtomicBool,
69}
70
71impl Clone for Shield {
72 fn clone(&self) -> Self {
73 Self {
74 policies: self.policies.clone(),
75 force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)),
76 }
77 }
78}
79
80impl Default for Shield {
81 /// Returns a new `Shield` instance. See the [table] for a description
82 /// of the policies used by default.
83 ///
84 /// [table]: ./#supported-headers
85 ///
86 /// # Example
87 ///
88 /// ```rust
89 /// # extern crate rocket;
90 /// use rocket::shield::Shield;
91 ///
92 /// let shield = Shield::default();
93 /// ```
94 fn default() -> Self {
95 Shield::new()
96 .enable(NoSniff::default())
97 .enable(Frame::default())
98 .enable(Permission::default())
99 }
100}
101
102impl Shield {
103 /// Returns an instance of `Shield` with no headers enabled.
104 ///
105 /// # Example
106 ///
107 /// ```rust
108 /// use rocket::shield::Shield;
109 ///
110 /// let shield = Shield::new();
111 /// ```
112 pub fn new() -> Self {
113 Shield {
114 policies: HashMap::new(),
115 force_hsts: AtomicBool::new(false),
116 }
117 }
118
119 /// Enables the policy header `policy`.
120 ///
121 /// If the policy was previously enabled, the configuration is replaced
122 /// with that of `policy`.
123 ///
124 /// # Example
125 ///
126 /// ```rust
127 /// use rocket::shield::Shield;
128 /// use rocket::shield::NoSniff;
129 ///
130 /// let shield = Shield::new().enable(NoSniff::default());
131 /// ```
132 pub fn enable<P: Policy>(mut self, policy: P) -> Self {
133 self.policies.insert(P::NAME.into(), policy.header());
134 self
135 }
136
137 /// Disables the policy header `policy`.
138 ///
139 /// # Example
140 ///
141 /// ```rust
142 /// use rocket::shield::Shield;
143 /// use rocket::shield::NoSniff;
144 ///
145 /// let shield = Shield::default().disable::<NoSniff>();
146 /// ```
147 pub fn disable<P: Policy>(mut self) -> Self {
148 self.policies.remove(UncasedStr::new(P::NAME));
149 self
150 }
151
152 /// Returns `true` if the policy `P` is enabled.
153 ///
154 /// # Example
155 ///
156 /// ```rust
157 /// use rocket::shield::Shield;
158 /// use rocket::shield::{Permission, NoSniff, Frame};
159 /// use rocket::shield::{Prefetch, ExpectCt, Referrer};
160 ///
161 /// let shield = Shield::default();
162 ///
163 /// assert!(shield.is_enabled::<NoSniff>());
164 /// assert!(shield.is_enabled::<Frame>());
165 /// assert!(shield.is_enabled::<Permission>());
166 ///
167 /// assert!(!shield.is_enabled::<Prefetch>());
168 /// assert!(!shield.is_enabled::<ExpectCt>());
169 /// assert!(!shield.is_enabled::<Referrer>());
170 /// ```
171 pub fn is_enabled<P: Policy>(&self) -> bool {
172 self.policies.contains_key(UncasedStr::new(P::NAME))
173 }
174}
175
176#[crate::async_trait]
177impl Fairing for Shield {
178 fn info(&self) -> Info {
179 Info {
180 name: "Shield",
181 kind: Kind::Liftoff | Kind::Response | Kind::Singleton,
182 }
183 }
184
185 async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
186 if self.policies.is_empty() {
187 return;
188 }
189
190 let force_hsts = rocket.endpoints().all(|v| v.is_tls())
191 && rocket.figment().profile() != Config::DEBUG_PROFILE
192 && !self.is_enabled::<Hsts>();
193
194 if force_hsts {
195 self.force_hsts.store(true, Ordering::Release);
196 }
197
198 span_info!("shield", policies = self.policies.len() => {
199 self.policies.values().trace_all_info();
200
201 if force_hsts {
202 warn!("Detected TLS-enabled liftoff without enabling HSTS.\n\
203 Shield has enabled a default HSTS policy.\n\
204 To remove this warning, configure an HSTS policy.");
205 }
206 })
207 }
208
209 async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
210 // Set all of the headers in `self.policies` in `response` as long as
211 // the header is not already in the response.
212 for header in self.policies.values() {
213 if response.headers().contains(header.name()) {
214 span_warn!("shield", "shield refusing to overwrite existing response header" => {
215 header.trace_warn();
216 });
217
218 continue
219 }
220
221 response.set_header(header.clone());
222 }
223 }
224}