Skip to content

Commit bf7c5fc

Browse files
mladedavjplatte
authored andcommitted
axum/routing: Merge fallbacks with the rest of the router
1 parent 64563fb commit bf7c5fc

File tree

5 files changed

+252
-104
lines changed

5 files changed

+252
-104
lines changed

axum/src/docs/routing/route.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ documentation for more details.
3636
It is not possible to create segments that only match some types like numbers or
3737
regular expression. You must handle that manually in your handlers.
3838

39-
[`MatchedPath`](crate::extract::MatchedPath) can be used to extract the matched
40-
path rather than the actual path.
39+
[`MatchedPath`] can be used to extract the matched path rather than the actual path.
4140

4241
# Wildcards
4342

axum/src/routing/mod.rs

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
44
#[cfg(feature = "tokio")]
55
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6+
#[cfg(feature = "matched-path")]
7+
use crate::extract::MatchedPath;
68
use crate::{
79
body::{Body, HttpBody},
810
boxed::BoxedIntoRoute,
@@ -20,7 +22,8 @@ use std::{
2022
sync::Arc,
2123
task::{Context, Poll},
2224
};
23-
use tower_layer::Layer;
25+
use tower::service_fn;
26+
use tower_layer::{layer_fn, Layer};
2427
use tower_service::Service;
2528

2629
pub mod future;
@@ -78,8 +81,7 @@ impl<S> Clone for Router<S> {
7881
}
7982

8083
struct RouterInner<S> {
81-
path_router: PathRouter<S, false>,
82-
fallback_router: PathRouter<S, true>,
84+
path_router: PathRouter<S>,
8385
default_fallback: bool,
8486
catch_all_fallback: Fallback<S>,
8587
}
@@ -97,7 +99,6 @@ impl<S> fmt::Debug for Router<S> {
9799
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98100
f.debug_struct("Router")
99101
.field("path_router", &self.inner.path_router)
100-
.field("fallback_router", &self.inner.fallback_router)
101102
.field("default_fallback", &self.inner.default_fallback)
102103
.field("catch_all_fallback", &self.inner.catch_all_fallback)
103104
.finish()
@@ -147,7 +148,6 @@ where
147148
Self {
148149
inner: Arc::new(RouterInner {
149150
path_router: Default::default(),
150-
fallback_router: PathRouter::new_fallback(),
151151
default_fallback: true,
152152
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
153153
}),
@@ -159,7 +159,6 @@ where
159159
Ok(inner) => inner,
160160
Err(arc) => RouterInner {
161161
path_router: arc.path_router.clone(),
162-
fallback_router: arc.fallback_router.clone(),
163162
default_fallback: arc.default_fallback,
164163
catch_all_fallback: arc.catch_all_fallback.clone(),
165164
},
@@ -213,8 +212,7 @@ where
213212

214213
let RouterInner {
215214
path_router,
216-
fallback_router,
217-
default_fallback,
215+
default_fallback: _,
218216
// we don't need to inherit the catch-all fallback. It is only used for CONNECT
219217
// requests with an empty path. If we were to inherit the catch-all fallback
220218
// it would end up matching `/{path}/*` which doesn't match empty paths.
@@ -223,10 +221,6 @@ where
223221

224222
tap_inner!(self, mut this => {
225223
panic_on_err!(this.path_router.nest(path, path_router));
226-
227-
if !default_fallback {
228-
panic_on_err!(this.fallback_router.nest(path, fallback_router));
229-
}
230224
})
231225
}
232226

@@ -253,43 +247,33 @@ where
253247
where
254248
R: Into<Router<S>>,
255249
{
256-
const PANIC_MSG: &str =
257-
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";
258-
259250
let other: Router<S> = other.into();
260251
let RouterInner {
261252
path_router,
262-
fallback_router: mut other_fallback,
263253
default_fallback,
264254
catch_all_fallback,
265255
} = other.into_inner();
266256

267257
map_inner!(self, mut this => {
268-
panic_on_err!(this.path_router.merge(path_router));
269-
270258
match (this.default_fallback, default_fallback) {
271259
// both have the default fallback
272260
// use the one from other
273-
(true, true) => {
274-
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
275-
}
261+
(true, true) => {}
276262
// this has default fallback, other has a custom fallback
277263
(true, false) => {
278-
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
279264
this.default_fallback = false;
280265
}
281266
// this has a custom fallback, other has a default
282267
(false, true) => {
283-
let fallback_router = std::mem::take(&mut this.fallback_router);
284-
other_fallback.merge(fallback_router).expect(PANIC_MSG);
285-
this.fallback_router = other_fallback;
286268
}
287269
// both have a custom fallback, not allowed
288270
(false, false) => {
289271
panic!("Cannot merge two `Router`s that both have a fallback")
290272
}
291273
};
292274

275+
panic_on_err!(this.path_router.merge(path_router));
276+
293277
this.catch_all_fallback = this
294278
.catch_all_fallback
295279
.merge(catch_all_fallback)
@@ -310,7 +294,6 @@ where
310294
{
311295
map_inner!(self, this => RouterInner {
312296
path_router: this.path_router.layer(layer.clone()),
313-
fallback_router: this.fallback_router.layer(layer.clone()),
314297
default_fallback: this.default_fallback,
315298
catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
316299
})
@@ -328,7 +311,6 @@ where
328311
{
329312
map_inner!(self, this => RouterInner {
330313
path_router: this.path_router.route_layer(layer),
331-
fallback_router: this.fallback_router,
332314
default_fallback: this.default_fallback,
333315
catch_all_fallback: this.catch_all_fallback,
334316
})
@@ -397,8 +379,51 @@ where
397379
}
398380

399381
fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
382+
// TODO make this better, get rid of the `unwrap`s.
383+
// We need the returned `Service` to be `Clone` and the function inside `service_fn` to be
384+
// `FnMut` so instead of just using the owned service, we do this trick with `Option`. We
385+
// know this will be called just once so it's fine. We're doing that so that we avoid one
386+
// clone inside `oneshot_inner` so that the `Router` and subsequently the `State` is not
387+
// cloned too much.
400388
tap_inner!(self, mut this => {
401-
this.fallback_router.set_fallback(endpoint);
389+
_ = this.path_router.route_endpoint(
390+
"/",
391+
endpoint.clone().layer(
392+
layer_fn(
393+
|service: Route| {
394+
let mut service = Some(service);
395+
service_fn(
396+
#[cfg_attr(not(feature = "matched-path"), allow(unused_mut))]
397+
move |mut request: Request| {
398+
#[cfg(feature = "matched-path")]
399+
request.extensions_mut().remove::<MatchedPath>();
400+
service.take().unwrap().oneshot_inner_owned(request)
401+
}
402+
)
403+
}
404+
)
405+
)
406+
);
407+
408+
_ = this.path_router.route_endpoint(
409+
FALLBACK_PARAM_PATH,
410+
endpoint.layer(
411+
layer_fn(
412+
|service: Route| {
413+
let mut service = Some(service);
414+
service_fn(
415+
#[cfg_attr(not(feature = "matched-path"), allow(unused_mut))]
416+
move |mut request: Request| {
417+
#[cfg(feature = "matched-path")]
418+
request.extensions_mut().remove::<MatchedPath>();
419+
service.take().unwrap().oneshot_inner_owned(request)
420+
}
421+
)
422+
}
423+
)
424+
)
425+
);
426+
402427
this.default_fallback = false;
403428
})
404429
}
@@ -407,7 +432,6 @@ where
407432
pub fn with_state<S2>(self, state: S) -> Router<S2> {
408433
map_inner!(self, this => RouterInner {
409434
path_router: this.path_router.with_state(state.clone()),
410-
fallback_router: this.fallback_router.with_state(state.clone()),
411435
default_fallback: this.default_fallback,
412436
catch_all_fallback: this.catch_all_fallback.with_state(state),
413437
})
@@ -419,11 +443,6 @@ where
419443
Err((req, state)) => (req, state),
420444
};
421445

422-
let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
423-
Ok(future) => return future,
424-
Err((req, state)) => (req, state),
425-
};
426-
427446
self.inner
428447
.catch_all_fallback
429448
.clone()

axum/src/routing/path_router.rs

Lines changed: 20 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,17 @@ use tower_layer::Layer;
99
use tower_service::Service;
1010

1111
use super::{
12-
future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
13-
MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
12+
future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
13+
RouteId, NEST_TAIL_PARAM,
1414
};
1515

16-
pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
16+
pub(super) struct PathRouter<S> {
1717
routes: HashMap<RouteId, Endpoint<S>>,
1818
node: Arc<Node>,
1919
prev_route_id: RouteId,
2020
v7_checks: bool,
2121
}
2222

23-
impl<S> PathRouter<S, true>
24-
where
25-
S: Clone + Send + Sync + 'static,
26-
{
27-
pub(super) fn new_fallback() -> Self {
28-
let mut this = Self::default();
29-
this.set_fallback(Endpoint::Route(Route::new(NotFound)));
30-
this
31-
}
32-
33-
pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
34-
self.replace_endpoint("/", endpoint.clone());
35-
self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
36-
}
37-
}
38-
3923
fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> {
4024
if path.is_empty() {
4125
return Err("Paths must start with a `/`. Use \"/\" for root routes");
@@ -72,7 +56,7 @@ fn validate_v07_paths(path: &str) -> Result<(), &'static str> {
7256
.unwrap_or(Ok(()))
7357
}
7458

75-
impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
59+
impl<S> PathRouter<S>
7660
where
7761
S: Clone + Send + Sync + 'static,
7862
{
@@ -159,10 +143,7 @@ where
159143
.map_err(|err| format!("Invalid route {path:?}: {err}"))
160144
}
161145

162-
pub(super) fn merge(
163-
&mut self,
164-
other: PathRouter<S, IS_FALLBACK>,
165-
) -> Result<(), Cow<'static, str>> {
146+
pub(super) fn merge(&mut self, other: PathRouter<S>) -> Result<(), Cow<'static, str>> {
166147
let PathRouter {
167148
routes,
168149
node,
@@ -179,24 +160,9 @@ where
179160
.get(&id)
180161
.expect("no path for route id. This is a bug in axum. Please file an issue");
181162

182-
if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
183-
// when merging two routers it doesn't matter if you do `a.merge(b)` or
184-
// `b.merge(a)`. This must also be true for fallbacks.
185-
//
186-
// However all fallback routers will have routes for `/` and `/*` so when merging
187-
// we have to ignore the top level fallbacks on one side otherwise we get
188-
// conflicts.
189-
//
190-
// `Router::merge` makes sure that when merging fallbacks `other` always has the
191-
// fallback we want to keep. It panics if both routers have a custom fallback. Thus
192-
// it is always okay to ignore one fallback and `Router::merge` also makes sure the
193-
// one we can ignore is that of `self`.
194-
self.replace_endpoint(path, route);
195-
} else {
196-
match route {
197-
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
198-
Endpoint::Route(route) => self.route_service(path, route)?,
199-
}
163+
match route {
164+
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
165+
Endpoint::Route(route) => self.route_service(path, route)?,
200166
}
201167
}
202168

@@ -206,7 +172,7 @@ where
206172
pub(super) fn nest(
207173
&mut self,
208174
path_to_nest_at: &str,
209-
router: PathRouter<S, IS_FALLBACK>,
175+
router: PathRouter<S>,
210176
) -> Result<(), Cow<'static, str>> {
211177
let prefix = validate_nest_path(self.v7_checks, path_to_nest_at);
212178

@@ -282,7 +248,7 @@ where
282248
Ok(())
283249
}
284250

285-
pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
251+
pub(super) fn layer<L>(self, layer: L) -> PathRouter<S>
286252
where
287253
L: Layer<Route> + Clone + Send + Sync + 'static,
288254
L::Service: Service<Request> + Clone + Send + Sync + 'static,
@@ -344,7 +310,7 @@ where
344310
!self.routes.is_empty()
345311
}
346312

347-
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
313+
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2> {
348314
let routes = self
349315
.routes
350316
.into_iter()
@@ -389,14 +355,12 @@ where
389355
Ok(match_) => {
390356
let id = *match_.value;
391357

392-
if !IS_FALLBACK {
393-
#[cfg(feature = "matched-path")]
394-
crate::extract::matched_path::set_matched_path_for_request(
395-
id,
396-
&self.node.route_id_to_path,
397-
&mut parts.extensions,
398-
);
399-
}
358+
#[cfg(feature = "matched-path")]
359+
crate::extract::matched_path::set_matched_path_for_request(
360+
id,
361+
&self.node.route_id_to_path,
362+
&mut parts.extensions,
363+
);
400364

401365
url_params::insert_url_params(&mut parts.extensions, match_.params);
402366

@@ -419,18 +383,6 @@ where
419383
}
420384
}
421385

422-
pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
423-
match self.node.at(path) {
424-
Ok(match_) => {
425-
let id = *match_.value;
426-
self.routes.insert(id, endpoint);
427-
}
428-
Err(_) => self
429-
.route_endpoint(path, endpoint)
430-
.expect("path wasn't matched so endpoint shouldn't exist"),
431-
}
432-
}
433-
434386
fn next_route_id(&mut self) -> RouteId {
435387
let next_id = self
436388
.prev_route_id
@@ -442,7 +394,7 @@ where
442394
}
443395
}
444396

445-
impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
397+
impl<S> Default for PathRouter<S> {
446398
fn default() -> Self {
447399
Self {
448400
routes: Default::default(),
@@ -453,7 +405,7 @@ impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
453405
}
454406
}
455407

456-
impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
408+
impl<S> fmt::Debug for PathRouter<S> {
457409
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458410
f.debug_struct("PathRouter")
459411
.field("routes", &self.routes)
@@ -462,7 +414,7 @@ impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
462414
}
463415
}
464416

465-
impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
417+
impl<S> Clone for PathRouter<S> {
466418
fn clone(&self) -> Self {
467419
Self {
468420
routes: self.routes.clone(),

0 commit comments

Comments
 (0)