Skip to content

Commit 61cd326

Browse files
committed
Fix safety issues & comments
- The previous `on_request` was unsafe, now uses a separate `filter_request` method to ensure `error` only contains immutable borrows of `Request`. - Added full safety comments and arguements to `erased`. - Added extra logic to upgrade, to prevent upgrading if an error has been thrown.
1 parent 61a4b44 commit 61cd326

File tree

9 files changed

+270
-145
lines changed

9 files changed

+270
-145
lines changed

core/lib/src/erased.rs

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,40 @@ impl Drop for ErasedRequest {
3535
fn drop(&mut self) { }
3636
}
3737

38+
pub struct ErasedError<'r> {
39+
error: Option<Pin<Box<dyn TypedError<'r> + 'r>>>,
40+
}
41+
42+
impl<'r> ErasedError<'r> {
43+
pub fn new() -> Self {
44+
Self { error: None }
45+
}
46+
47+
pub fn write(&mut self, error: Option<Box<dyn TypedError<'r> + 'r>>) {
48+
// SAFETY: To meet the requirements of `Pin`, we never drop
49+
// the inner Box. This is enforced by only allowing writing
50+
// to the Option when it is None.
51+
assert!(self.error.is_none());
52+
if let Some(error) = error {
53+
self.error = Some(unsafe { Pin::new_unchecked(error) });
54+
}
55+
}
56+
57+
pub fn is_some(&self) -> bool {
58+
self.error.is_some()
59+
}
60+
61+
pub fn get(&'r self) -> Option<&'r dyn TypedError<'r>> {
62+
self.error.as_ref().map(|e| &**e)
63+
}
64+
}
65+
3866
// TODO: #[derive(Debug)]
3967
pub struct ErasedResponse {
4068
// XXX: SAFETY: This (dependent) field must come first due to drop order!
4169
response: Response<'static>,
42-
// XXX: SAFETY: This (dependent) field must come first due to drop order!
43-
error: Option<Box<dyn TypedError<'static> + 'static>>,
70+
// XXX: SAFETY: This (dependent) field must come second due to drop order!
71+
error: ErasedError<'static>,
4472
_request: Arc<ErasedRequest>,
4573
}
4674

@@ -71,8 +99,13 @@ impl ErasedRequest {
7199
let parts: Box<Parts> = Box::new(parts);
72100
let request: Request<'_> = {
73101
let rocket: &Rocket<Orbit> = &rocket;
102+
// SAFETY: The `Request` can borrow from `Rocket` because it has a stable
103+
// address (due to `Arc`) and it is kept alive by the containing
104+
// `ErasedRequest`. The `Request` is always dropped before the
105+
// `Arc<Rocket>` due to drop order.
74106
let rocket: &'static Rocket<Orbit> = unsafe { transmute(rocket) };
75107
let parts: &Parts = &parts;
108+
// SAFETY: Same as above, but for `Box<Parts>`.
76109
let parts: &'static Parts = unsafe { transmute(parts) };
77110
constructor(rocket, parts)
78111
};
@@ -92,46 +125,54 @@ impl ErasedRequest {
92125
&'r Rocket<Orbit>,
93126
&'r mut Request<'x>,
94127
&'r mut Data<'x>,
95-
&'r mut Option<Box<dyn TypedError<'r> + 'r>>,
128+
&'r mut ErasedError<'r>,
96129
) -> BoxFuture<'r, T>,
97130
dispatch: impl for<'r> FnOnce(
98131
T,
99132
&'r Rocket<Orbit>,
100133
&'r Request<'r>,
101134
Data<'r>,
102-
&'r mut Option<Box<dyn TypedError<'r> + 'r>>,
135+
&'r mut ErasedError<'r>,
103136
) -> BoxFuture<'r, Response<'r>>,
104137
) -> ErasedResponse
105138
where T: Send + Sync + 'static,
106139
D: for<'r> Into<RawStream<'r>>
107140
{
108-
let mut error_ptr: Option<Box<dyn TypedError<'static> + 'static>> = None;
109141
let mut data: Data<'_> = Data::from(raw_stream);
142+
// SAFETY: At this point, ErasedRequest contains a request, which is permitted
143+
// to borrow from `Rocket` and `Parts`. They both have stable addresses (due to
144+
// `Arc` and `Box`), and the Request will be dropped first (due to drop order).
145+
// SAFETY: Here, we place the `ErasedRequest` (i.e. the `Request`) behind an `Arc` (TODO: Why not Box?)
146+
// to ensure it has a stable address, and we again use drop order to ensure the `Request`
147+
// is dropped before the values that can borrow from it.
110148
let mut parent = Arc::new(self);
149+
// SAFETY: This error is permitted to borrow from the `Request` (as well as `Rocket` and `Parts`).
150+
let mut error = ErasedError { error: None };
111151
let token: T = {
112152
let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap();
113153
let rocket: &Rocket<Orbit> = &parent._rocket;
114154
let request: &mut Request<'_> = &mut parent.request;
115155
let data: &mut Data<'_> = &mut data;
116-
// SAFETY: TODO: Same as below
117-
preprocess(rocket, request, data, unsafe { transmute(&mut error_ptr) }).await
156+
// SAFETY: As below, `error` must be reborrowed with the correct lifetimes.
157+
preprocess(rocket, request, data, unsafe { transmute(&mut error) }).await
118158
};
119159

120160
let parent = parent;
121161
let response: Response<'_> = {
122162
let parent: &ErasedRequest = &parent;
163+
// SAFETY: This static reference is immediatly reborrowed for the correct lifetime.
164+
// The Response type is permitted to borrow from the `Request`, `Rocket`, `Parts`, and
165+
// `error`. All of these types have stable addresses, and will not be dropped until
166+
// after Response, due to drop order.
123167
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
124168
let rocket: &Rocket<Orbit> = &parent._rocket;
125169
let request: &Request<'_> = &parent.request;
126-
// SAFETY: TODO: error_ptr is transmuted into the same type, with the
127-
// same lifetime as the request.
128-
// It is kept alive by the erased response, so that the response
129-
// type can borrow from it
130-
dispatch(token, rocket, request, data, unsafe { transmute(&mut error_ptr)}).await
170+
// SAFETY: As above, `error` must be reborrowed with the correct lifetimes.
171+
dispatch(token, rocket, request, data, unsafe { transmute(&mut error) }).await
131172
};
132173

133174
ErasedResponse {
134-
error: error_ptr,
175+
error,
135176
_request: parent,
136177
response,
137178
}
@@ -159,9 +200,21 @@ impl ErasedResponse {
159200
&'a mut Response<'r>,
160201
) -> Option<(T, Box<dyn IoHandler + 'r>)>
161202
) -> Option<(T, ErasedIoHandler)> {
203+
// SAFETY: If an error has been thrown, the `IoHandler` could
204+
// technically borrow from it, so we must ensure that this is
205+
// not the case. This could be handled safely by changing `error`
206+
// to be an `Arc` internally, and cloning the Arc to get a copy
207+
// (like `ErasedRequest`), however it's unclear this is actually
208+
// useful, and we can avoid paying the cost of an `Arc`
209+
if self.error.is_some() {
210+
warn!("Attempting to upgrade after throwing a typed error is not supported");
211+
return None;
212+
}
162213
let parent: Arc<ErasedRequest> = self._request.clone();
163214
let io: Option<(T, Box<dyn IoHandler + '_>)> = {
164215
let parent: &ErasedRequest = &parent;
216+
// SAFETY: As in other cases, the request is kept alive by the `Erased...`
217+
// type.
165218
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
166219
let request: &Request<'_> = &parent.request;
167220
constructor(request, &mut self.response)

core/lib/src/fairing/ad_hoc.rs

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ enum AdHocKind {
6262

6363
/// An ad-hoc **request** fairing. Called when a request is received.
6464
Request(Box<dyn for<'a, 'b> Fn(&'a mut Request<'_>, &'b mut Data<'_>)
65+
-> BoxFuture<'a, ()> + Send + Sync + 'static>),
66+
67+
/// An ad-hoc **request_filter** fairing. Called when a request is received.
68+
RequestFilter(Box<dyn for<'a, 'b> Fn(&'a Request<'_>, &'b Data<'_>)
6569
-> BoxFuture<'a, Result<(), Box<dyn TypedError<'a> + 'a>>> + Send + Sync + 'static>),
6670

6771
/// An ad-hoc **response** fairing. Called when a response is ready to be
@@ -156,11 +160,35 @@ impl AdHoc {
156160
/// ```
157161
pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
158162
where F: for<'a, 'b> Fn(&'a mut Request<'_>, &'b mut Data<'_>)
159-
-> BoxFuture<'a, Result<(), Box<dyn TypedError<'a> + 'a>>>
163+
-> BoxFuture<'a, ()>
160164
{
161165
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
162166
}
163167

168+
/// Constructs an `AdHoc` request fairing named `name`. The function `f`
169+
/// will be called and the returned `Future` will be `await`ed by Rocket
170+
/// when a new request is received.
171+
///
172+
/// # Example
173+
///
174+
/// ```rust
175+
/// use rocket::fairing::AdHoc;
176+
///
177+
/// // The no-op request fairing.
178+
/// let fairing = AdHoc::on_request("Dummy", |req, data| {
179+
/// Box::pin(async move {
180+
/// // do something with the request and data...
181+
/// # let (_, _) = (req, data);
182+
/// })
183+
/// });
184+
/// ```
185+
pub fn filter_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
186+
where F: for<'a, 'b> Fn(&'a Request<'_>, &'b Data<'_>)
187+
-> BoxFuture<'a, Result<(), Box<dyn TypedError<'a> + 'a>>>
188+
{
189+
AdHoc { name, kind: AdHocKind::RequestFilter(Box::new(f)) }
190+
}
191+
164192
// FIXME(rustc): We'd like to allow passing `async fn` to these methods...
165193
// https://github.com/rust-lang/rust/issues/64552#issuecomment-666084589
166194

@@ -379,12 +407,10 @@ impl AdHoc {
379407
let _ = self.routes(rocket);
380408
}
381409

382-
async fn on_request<'r>(&self, req: &'r mut Request<'_>, _: &mut Data<'_>)
383-
-> Result<(), Box<dyn TypedError<'r> + 'r>>
384-
{
410+
async fn on_request<'r>(&self, req: &'r mut Request<'_>, _: &mut Data<'_>) {
385411
// If the URI has no trailing slash, it routes as before.
386412
if req.uri().is_normalized_nontrailing() {
387-
return Ok(());
413+
return;
388414
}
389415

390416
// Otherwise, check if there's a route that matches the request
@@ -397,7 +423,6 @@ impl AdHoc {
397423
"incoming request URI normalized for compatibility");
398424
req.set_uri(normalized);
399425
}
400-
Ok(())
401426
}
402427
}
403428

@@ -412,6 +437,7 @@ impl Fairing for AdHoc {
412437
AdHocKind::Ignite(_) => Kind::Ignite,
413438
AdHocKind::Liftoff(_) => Kind::Liftoff,
414439
AdHocKind::Request(_) => Kind::Request,
440+
AdHocKind::RequestFilter(_) => Kind::RequestFilter,
415441
AdHocKind::Response(_) => Kind::Response,
416442
AdHocKind::Shutdown(_) => Kind::Shutdown,
417443
};
@@ -432,10 +458,16 @@ impl Fairing for AdHoc {
432458
}
433459
}
434460

435-
async fn on_request<'r>(&self, req: &'r mut Request<'_>, data: &mut Data<'_>)
461+
async fn on_request<'r>(&self, req: &'r mut Request<'_>, data: &mut Data<'_>) {
462+
if let AdHocKind::Request(ref f) = self.kind {
463+
f(req, data).await
464+
}
465+
}
466+
467+
async fn filter_request<'r>(&self, req: &'r Request<'_>, data: &Data<'_>)
436468
-> Result<(), Box<dyn TypedError<'r> + 'r>>
437469
{
438-
if let AdHocKind::Request(ref f) = self.kind {
470+
if let AdHocKind::RequestFilter(ref f) = self.kind {
439471
f(req, data).await
440472
} else {
441473
Ok(())

core/lib/src/fairing/fairings.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::collections::HashSet;
2-
use std::mem::transmute;
32

4-
use crate::catcher::TypedError;
3+
use crate::erased::ErasedError;
54
use crate::{Rocket, Request, Response, Data, Build, Orbit};
65
use crate::fairing::{Fairing, Info, Kind};
76

@@ -17,6 +16,7 @@ pub struct Fairings {
1716
ignite: Vec<usize>,
1817
liftoff: Vec<usize>,
1918
request: Vec<usize>,
19+
filter_request: Vec<usize>,
2020
response: Vec<usize>,
2121
shutdown: Vec<usize>,
2222
}
@@ -45,6 +45,7 @@ impl Fairings {
4545
self.ignite.iter()
4646
.chain(self.liftoff.iter())
4747
.chain(self.request.iter())
48+
.chain(self.filter_request.iter())
4849
.chain(self.response.iter())
4950
.chain(self.shutdown.iter())
5051
}
@@ -106,6 +107,7 @@ impl Fairings {
106107
if this_info.kind.is(Kind::Ignite) { self.ignite.push(index); }
107108
if this_info.kind.is(Kind::Liftoff) { self.liftoff.push(index); }
108109
if this_info.kind.is(Kind::Request) { self.request.push(index); }
110+
if this_info.kind.is(Kind::RequestFilter) { self.filter_request.push(index); }
109111
if this_info.kind.is(Kind::Response) { self.response.push(index); }
110112
if this_info.kind.is(Kind::Shutdown) { self.shutdown.push(index); }
111113
}
@@ -153,17 +155,26 @@ impl Fairings {
153155
&self,
154156
req: &'r mut Request<'_>,
155157
data: &mut Data<'_>,
156-
error: &mut Option<Box<dyn TypedError<'_> + '_>>,
157158
) {
158159
for fairing in iter!(self.request) {
159-
// invoke_fairing(fairing, req, data, error)?;
160-
match fairing.on_request(req, data).await {
160+
fairing.on_request(req, data).await;
161+
}
162+
}
163+
164+
#[inline(always)]
165+
pub async fn handle_filter<'r>(
166+
&self,
167+
req: &'r Request<'_>,
168+
data: &Data<'_>,
169+
error: &mut ErasedError<'r>,
170+
) {
171+
for fairing in iter!(self.filter_request) {
172+
match fairing.filter_request(req, data).await {
161173
Ok(()) => (),
162174
Err(e) => {
163-
// TODO: Safety arguement
164-
// Generally, error is None at the start (hence no borrows),
165-
// and we always return immediatly with this value.
166-
*error = Some(unsafe { transmute(e) });
175+
// SAFETY: `e` can only contain *immutable* borrows of
176+
// `req`.
177+
error.write(Some(e));
167178
return;
168179
},
169180
}

core/lib/src/fairing/info_kind.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,18 @@ impl Kind {
6464
/// `Kind` flag representing a request for a 'request' callback.
6565
pub const Request: Kind = Kind(1 << 2);
6666

67+
/// `Kind` flag representing a request for a 'filter_request' callback.
68+
pub const RequestFilter: Kind = Kind(1 << 3);
69+
6770
/// `Kind` flag representing a request for a 'response' callback.
68-
pub const Response: Kind = Kind(1 << 3);
71+
pub const Response: Kind = Kind(1 << 4);
6972

7073
/// `Kind` flag representing a request for a 'shutdown' callback.
71-
pub const Shutdown: Kind = Kind(1 << 4);
74+
pub const Shutdown: Kind = Kind(1 << 5);
7275

7376
/// `Kind` flag representing a
7477
/// [singleton](crate::fairing::Fairing#singletons) fairing.
75-
pub const Singleton: Kind = Kind(1 << 5);
78+
pub const Singleton: Kind = Kind(1 << 6);
7679

7780
/// Returns `true` if `self` is a superset of `other`. In other words,
7881
/// returns `true` if all of the kinds in `other` are also in `self`.

core/lib/src/fairing/mod.rs

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,18 @@ pub type Result<T = Rocket<Build>, E = Rocket<Build>> = std::result::Result<T, E
150150
/// [`Request`] and [`Data`] structures but has not routed the request. A
151151
/// request callback can modify the request at will and [`Data::peek()`]
152152
/// into the incoming data. It may not, however, abort or respond directly
153-
/// to the request; these issues are better handled via [request guards] or
154-
/// via response callbacks. Any modifications to a request are persisted and
155-
/// can potentially alter how a request is routed.
153+
/// to the request; these issues are better handled via [request guards],
154+
/// Request filters, or via response callbacks. Any modifications to a
155+
/// request are persisted and can potentially alter how a request is routed.
156+
///
157+
/// * **<a name="filter_request">Request filter</a> (`filter_request`)**
158+
///
159+
/// A request callback, represented by the [`Fairing::filter_request()`] method,
160+
/// called after `on_request` callbacks have run, but before any handlers have
161+
/// been attempted. This type of fairing can choose to prematurly reject requests,
162+
/// skipping handlers all together, and moving it straight to error handling. This
163+
/// should generally only be used to apply filter that apply to the entire server,
164+
/// e.g. CORS processing.
156165
///
157166
/// * **<a name="response">Response</a> (`on_response`)**
158167
///
@@ -502,7 +511,26 @@ pub trait Fairing: Send + Sync + Any + 'static {
502511
/// ## Default Implementation
503512
///
504513
/// The default implementation of this method does nothing.
505-
async fn on_request<'r>(&self, _req: &'r mut Request<'_>, _data: &mut Data<'_>)
514+
async fn on_request<'r>(&self, _req: &'r mut Request<'_>, _data: &mut Data<'_>) { }
515+
516+
/// The request filter callback.
517+
///
518+
/// See [Fairing Callbacks](#filter_request) for complete semantics.
519+
///
520+
/// This method is called when a new request is received if `Kind::RequestFilter`
521+
/// is in the `kind` field of the `Info` structure for this fairing. The
522+
/// `&Request` parameter is the incoming request, and the `&Data`
523+
/// parameter is the incoming data in the request.
524+
///
525+
/// If this method returns `Ok`, the request routed as normal (assuming no other
526+
/// fairing filters it). Otherwise, the request is routed to an error handler
527+
/// based on the error type returned.
528+
///
529+
/// ## Default Implementation
530+
///
531+
/// The default implementation of this method does not filter any request,
532+
/// by always returning `Ok(())`
533+
async fn filter_request<'r>(&self, _req: &'r Request<'_>, _data: &Data<'_>)
506534
-> Result<(), Box<dyn TypedError<'r> + 'r>>
507535
{ Ok(()) }
508536

@@ -554,10 +582,15 @@ impl<T: Fairing + ?Sized> Fairing for std::sync::Arc<T> {
554582
}
555583

556584
#[inline]
557-
async fn on_request<'r>(&self, req: &'r mut Request<'_>, data: &mut Data<'_>)
585+
async fn on_request<'r>(&self, req: &'r mut Request<'_>, data: &mut Data<'_>) {
586+
(self as &T).on_request(req, data).await
587+
}
588+
589+
#[inline]
590+
async fn filter_request<'r>(&self, req: &'r Request<'_>, data: &Data<'_>)
558591
-> Result<(), Box<dyn TypedError<'r> + 'r>>
559592
{
560-
(self as &T).on_request(req, data).await
593+
(self as &T).filter_request(req, data).await
561594
}
562595

563596
#[inline]

0 commit comments

Comments
 (0)