summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Pedersen <david.pdrsn@gmail.com>2023-04-25 15:44:29 +0200
committerGitHub <noreply@github.com>2023-04-25 15:44:29 +0200
commitd22c1688302d631b08cd16cba8779a0d8b713c3b (patch)
tree6a6acc54fefff10b024ff2a6e3237cb14f9133fe
parent5f51b5b0569f7488bd35b5829226e7a9f3d3994d (diff)
Fix fallback panic on CONNECT requests (#1958)
-rw-r--r--axum/CHANGELOG.md4
-rw-r--r--axum/src/routing/method_routing.rs10
-rw-r--r--axum/src/routing/mod.rs75
-rw-r--r--axum/src/routing/path_router.rs21
-rw-r--r--axum/src/routing/tests/mod.rs46
5 files changed, 117 insertions, 39 deletions
diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md
index c0f92e36..5eff31dc 100644
--- a/axum/CHANGELOG.md
+++ b/axum/CHANGELOG.md
@@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
-- None.
+- **fixed:** Fix fallbacks causing a panic on `CONNECT` requests ([#1958])
+
+[#1958]: https://github.com/tokio-rs/axum/pull/1958
# 0.6.16 (18. April, 2023)
diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs
index cd942902..98683573 100644
--- a/axum/src/routing/method_routing.rs
+++ b/axum/src/routing/method_routing.rs
@@ -1119,15 +1119,7 @@ where
call!(req, method, DELETE, delete);
call!(req, method, TRACE, trace);
- let future = match fallback {
- Fallback::Default(route) | Fallback::Service(route) => {
- RouteFuture::from_future(route.oneshot_inner(req))
- }
- Fallback::BoxedHandler(handler) => {
- let mut route = handler.clone().into_route(state);
- RouteFuture::from_future(route.oneshot_inner(req))
- }
- };
+ let future = fallback.call_with_state(req, state);
match allow_header {
AllowHeader::None => future.allow_header(Bytes::new()),
diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs
index a1d8c715..54dc772e 100644
--- a/axum/src/routing/mod.rs
+++ b/axum/src/routing/mod.rs
@@ -60,6 +60,7 @@ pub struct Router<S = (), B = Body> {
path_router: PathRouter<S, B, false>,
fallback_router: PathRouter<S, B, true>,
default_fallback: bool,
+ catch_all_fallback: Fallback<S, B>,
}
impl<S, B> Clone for Router<S, B> {
@@ -68,6 +69,7 @@ impl<S, B> Clone for Router<S, B> {
path_router: self.path_router.clone(),
fallback_router: self.fallback_router.clone(),
default_fallback: self.default_fallback,
+ catch_all_fallback: self.catch_all_fallback.clone(),
}
}
}
@@ -88,6 +90,7 @@ impl<S, B> fmt::Debug for Router<S, B> {
.field("path_router", &self.path_router)
.field("fallback_router", &self.fallback_router)
.field("default_fallback", &self.default_fallback)
+ .field("catch_all_fallback", &self.catch_all_fallback)
.finish()
}
}
@@ -106,14 +109,12 @@ where
/// Unless you add additional routes this will respond with `404 Not Found` to
/// all requests.
pub fn new() -> Self {
- let mut this = Self {
+ Self {
path_router: Default::default(),
- fallback_router: Default::default(),
+ fallback_router: PathRouter::new_fallback(),
default_fallback: true,
- };
- this = this.fallback_service(NotFound);
- this.default_fallback = true;
- this
+ catch_all_fallback: Fallback::Default(Route::new(NotFound)),
+ }
}
#[doc = include_str!("../docs/routing/route.md")]
@@ -151,6 +152,10 @@ where
path_router,
fallback_router,
default_fallback,
+ // we don't need to inherit the catch-all fallback. It is only used for CONNECT
+ // requests with an empty path. If we were to inherit the catch-all fallback
+ // it would end up matching `/{path}/*` which doesn't match empty paths.
+ catch_all_fallback: _,
} = router;
panic_on_err!(self.path_router.nest(path, path_router));
@@ -184,6 +189,7 @@ where
path_router,
fallback_router: other_fallback,
default_fallback,
+ catch_all_fallback,
} = other.into();
panic_on_err!(self.path_router.merge(path_router));
@@ -208,6 +214,11 @@ where
}
};
+ self.catch_all_fallback = self
+ .catch_all_fallback
+ .merge(catch_all_fallback)
+ .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
+
self
}
@@ -223,8 +234,9 @@ where
{
Router {
path_router: self.path_router.layer(layer.clone()),
- fallback_router: self.fallback_router.layer(layer),
+ fallback_router: self.fallback_router.layer(layer.clone()),
default_fallback: self.default_fallback,
+ catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)),
}
}
@@ -242,36 +254,38 @@ where
path_router: self.path_router.route_layer(layer),
fallback_router: self.fallback_router,
default_fallback: self.default_fallback,
+ catch_all_fallback: self.catch_all_fallback,
}
}
#[track_caller]
#[doc = include_str!("../docs/routing/fallback.md")]
- pub fn fallback<H, T>(self, handler: H) -> Self
+ pub fn fallback<H, T>(mut self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: 'static,
{
- let endpoint = Endpoint::MethodRouter(any(handler));
- self.fallback_endpoint(endpoint)
+ self.catch_all_fallback =
+ Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
+ self.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
}
/// Add a fallback [`Service`] to the router.
///
/// See [`Router::fallback`] for more details.
- pub fn fallback_service<T>(self, service: T) -> Self
+ pub fn fallback_service<T>(mut self, service: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
- self.fallback_endpoint(Endpoint::Route(Route::new(service)))
+ let route = Route::new(service);
+ self.catch_all_fallback = Fallback::Service(route.clone());
+ self.fallback_endpoint(Endpoint::Route(route))
}
fn fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self {
- self.fallback_router.replace_endpoint("/", endpoint.clone());
- self.fallback_router
- .replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint);
+ self.fallback_router.set_fallback(endpoint);
self.default_fallback = false;
self
}
@@ -280,8 +294,9 @@ where
pub fn with_state<S2>(self, state: S) -> Router<S2, B> {
Router {
path_router: self.path_router.with_state(state.clone()),
- fallback_router: self.fallback_router.with_state(state),
+ fallback_router: self.fallback_router.with_state(state.clone()),
default_fallback: self.default_fallback,
+ catch_all_fallback: self.catch_all_fallback.with_state(state),
}
}
@@ -307,19 +322,17 @@ where
.map(|SuperFallback(path_router)| path_router.into_inner());
if let Some(mut super_fallback) = super_fallback {
- return super_fallback
- .call_with_state(req, state)
- .unwrap_or_else(|_| unreachable!());
+ match super_fallback.call_with_state(req, state) {
+ Ok(future) => return future,
+ Err((req, state)) => {
+ return self.catch_all_fallback.call_with_state(req, state);
+ }
+ }
}
match self.fallback_router.call_with_state(req, state) {
Ok(future) => future,
- Err((_req, _state)) => {
- unreachable!(
- "the default fallback added in `Router::new` \
- matches everything"
- )
- }
+ Err((req, state)) => self.catch_all_fallback.call_with_state(req, state),
}
}
}
@@ -428,6 +441,18 @@ where
Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
}
}
+
+ fn call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E> {
+ match self {
+ Fallback::Default(route) | Fallback::Service(route) => {
+ RouteFuture::from_future(route.oneshot_inner(req))
+ }
+ Fallback::BoxedHandler(handler) => {
+ let mut route = handler.clone().into_route(state);
+ RouteFuture::from_future(route.oneshot_inner(req))
+ }
+ }
+ }
}
impl<S, B, E> Clone for Fallback<S, B, E> {
diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs
index ca618e75..e05a7997 100644
--- a/axum/src/routing/path_router.rs
+++ b/axum/src/routing/path_router.rs
@@ -7,8 +7,8 @@ use tower_layer::Layer;
use tower_service::Service;
use super::{
- future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
- RouteId, NEST_TAIL_PARAM,
+ future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
+ MethodRouter, Route, RouteId, FALLBACK_PARAM, NEST_TAIL_PARAM,
};
pub(super) struct PathRouter<S, B, const IS_FALLBACK: bool> {
@@ -17,6 +17,23 @@ pub(super) struct PathRouter<S, B, const IS_FALLBACK: bool> {
prev_route_id: RouteId,
}
+impl<S, B> PathRouter<S, B, true>
+where
+ B: HttpBody + Send + 'static,
+ S: Clone + Send + Sync + 'static,
+{
+ pub(super) fn new_fallback() -> Self {
+ let mut this = Self::default();
+ this.set_fallback(Endpoint::Route(Route::new(NotFound)));
+ this
+ }
+
+ pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S, B>) {
+ self.replace_endpoint("/", endpoint.clone());
+ self.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint);
+ }
+}
+
impl<S, B, const IS_FALLBACK: bool> PathRouter<S, B, IS_FALLBACK>
where
B: HttpBody + Send + 'static,
diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs
index 950d601a..ad7d5067 100644
--- a/axum/src/routing/tests/mod.rs
+++ b/axum/src/routing/tests/mod.rs
@@ -15,7 +15,11 @@ use crate::{
BoxError, Extension, Json, Router,
};
use futures_util::stream::StreamExt;
-use http::{header::ALLOW, header::CONTENT_LENGTH, HeaderMap, Request, Response, StatusCode, Uri};
+use http::{
+ header::CONTENT_LENGTH,
+ header::{ALLOW, HOST},
+ HeaderMap, Method, Request, Response, StatusCode, Uri,
+};
use hyper::Body;
use serde::Deserialize;
use serde_json::json;
@@ -26,7 +30,9 @@ use std::{
task::{Context, Poll},
time::Duration,
};
-use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder};
+use tower::{
+ service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
+};
use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer};
use tower_service::Service;
@@ -984,3 +990,39 @@ async fn logging_rejections() {
])
)
}
+
+// https://github.com/tokio-rs/axum/issues/1955
+#[crate::test]
+async fn connect_going_to_custom_fallback() {
+ let app = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") });
+
+ let req = Request::builder()
+ .uri("example.com:443")
+ .method(Method::CONNECT)
+ .header(HOST, "example.com:443")
+ .body(Body::empty())
+ .unwrap();
+
+ let res = app.oneshot(req).await.unwrap();
+ assert_eq!(res.status(), StatusCode::NOT_FOUND);
+ let text = String::from_utf8(hyper::body::to_bytes(res).await.unwrap().to_vec()).unwrap();
+ assert_eq!(text, "custom fallback");
+}
+
+// https://github.com/tokio-rs/axum/issues/1955
+#[crate::test]
+async fn connect_going_to_default_fallback() {
+ let app = Router::new();
+
+ let req = Request::builder()
+ .uri("example.com:443")
+ .method(Method::CONNECT)
+ .header(HOST, "example.com:443")
+ .body(Body::empty())
+ .unwrap();
+
+ let res = app.oneshot(req).await.unwrap();
+ assert_eq!(res.status(), StatusCode::NOT_FOUND);
+ let body = hyper::body::to_bytes(res).await.unwrap();
+ assert!(body.is_empty());
+}