Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ macro_rules! panic_on_err {
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteId(u32);
pub(crate) struct RouteId(usize);

/// The router type for composing handlers and services.
///
Expand Down
84 changes: 30 additions & 54 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ use super::{
};

pub(super) struct PathRouter<S> {
routes: HashMap<RouteId, Endpoint<S>>,
routes: Vec<Endpoint<S>>,
node: Arc<Node>,
prev_route_id: RouteId,
v7_checks: bool,
}

Expand Down Expand Up @@ -71,11 +70,11 @@ where
) -> Result<(), Cow<'static, str>> {
validate_path(self.v7_checks, path)?;

let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
.node
.path_to_route_id
.get(path)
.and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
.and_then(|route_id| self.routes.get(route_id.0).map(|svc| (*route_id, svc)))
{
// if we're adding a new `MethodRouter` to a route that already has one just
// merge them. This makes `.route("/", get(_)).route("/", post(_))` work
Expand All @@ -84,15 +83,11 @@ where
.clone()
.merge_for_path(Some(path), method_router)?,
);
self.routes.insert(route_id, service);
return Ok(());
self.routes[route_id.0] = service;
} else {
Endpoint::MethodRouter(method_router)
};

let id = self.next_route_id();
self.set_node(path, id)?;
self.routes.insert(id, endpoint);
let endpoint = Endpoint::MethodRouter(method_router);
self.new_route(path, endpoint)?;
}

Ok(())
}
Expand All @@ -102,7 +97,7 @@ where
H: Handler<T, S>,
T: 'static,
{
for (_, endpoint) in self.routes.iter_mut() {
for endpoint in self.routes.iter_mut() {
if let Endpoint::MethodRouter(rt) = endpoint {
*rt = rt.clone().default_fallback(handler.clone());
}
Expand All @@ -129,9 +124,7 @@ where
) -> Result<(), Cow<'static, str>> {
validate_path(self.v7_checks, path)?;

let id = self.next_route_id();
self.set_node(path, id)?;
self.routes.insert(id, endpoint);
self.new_route(path, endpoint)?;

Ok(())
}
Expand All @@ -143,21 +136,28 @@ where
.map_err(|err| format!("Invalid route {path:?}: {err}"))
}

fn new_route(&mut self, path: &str, endpoint: Endpoint<S>) -> Result<(), String> {
let id = RouteId(self.routes.len());
self.set_node(path, id)?;
self.routes.push(endpoint);
Ok(())
}

pub(super) fn merge(&mut self, other: Self) -> Result<(), Cow<'static, str>> {
let Self {
routes,
node,
prev_route_id: _,
v7_checks,
} = other;

// If either of the two did not allow paths starting with `:` or `*`, do not allow them for the merged router either.
self.v7_checks |= v7_checks;

for (id, route) in routes {
for (id, route) in routes.into_iter().enumerate() {
let route_id = RouteId(id);
let path = node
.route_id_to_path
.get(&id)
.get(&route_id)
.expect("no path for route id. This is a bug in axum. Please file an issue");

match route {
Expand All @@ -179,15 +179,15 @@ where
let Self {
routes,
node,
prev_route_id: _,
// Ignore the configuration of the nested router
v7_checks: _,
} = router;

for (id, endpoint) in routes {
for (id, endpoint) in routes.into_iter().enumerate() {
let route_id = RouteId(id);
let inner_path = node
.route_id_to_path
.get(&id)
.get(&route_id)
.expect("no path for route id. This is a bug in axum. Please file an issue");

let path = path_for_nested_route(prefix, inner_path);
Expand Down Expand Up @@ -259,16 +259,12 @@ where
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let route = endpoint.layer(layer.clone());
(id, route)
})
.map(|endpoint| endpoint.layer(layer.clone()))
.collect();

Self {
routes,
node: self.node,
prev_route_id: self.prev_route_id,
v7_checks: self.v7_checks,
}
}
Expand All @@ -292,16 +288,12 @@ where
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let route = endpoint.layer(layer.clone());
(id, route)
})
.map(|endpoint| endpoint.layer(layer.clone()))
.collect();

Self {
routes,
node: self.node,
prev_route_id: self.prev_route_id,
v7_checks: self.v7_checks,
}
}
Expand All @@ -314,21 +306,17 @@ where
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let endpoint: Endpoint<S2> = match endpoint {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.with_state(state.clone()))
}
Endpoint::Route(route) => Endpoint::Route(route),
};
(id, endpoint)
.map(|endpoint| match endpoint {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.with_state(state.clone()))
}
Endpoint::Route(route) => Endpoint::Route(route),
})
.collect();

PathRouter {
routes,
node: self.node,
prev_route_id: self.prev_route_id,
v7_checks: self.v7_checks,
}
}
Expand Down Expand Up @@ -366,7 +354,7 @@ where

let endpoint = self
.routes
.get(&id)
.get(id.0)
.expect("no route for id. This is a bug in axum. Please file an issue");

let req = Request::from_parts(parts, body);
Expand All @@ -382,24 +370,13 @@ where
Err(MatchError::NotFound) => Err((Request::from_parts(parts, body), state)),
}
}

fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
.0
.checked_add(1)
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
self.prev_route_id = RouteId(next_id);
self.prev_route_id
}
}

impl<S> Default for PathRouter<S> {
fn default() -> Self {
Self {
routes: Default::default(),
node: Default::default(),
prev_route_id: RouteId(0),
v7_checks: true,
}
}
Expand All @@ -419,7 +396,6 @@ impl<S> Clone for PathRouter<S> {
Self {
routes: self.routes.clone(),
node: self.node.clone(),
prev_route_id: self.prev_route_id,
v7_checks: self.v7_checks,
}
}
Expand Down