diff options
author | David Pedersen <david.pdrsn@gmail.com> | 2023-04-25 15:44:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-25 15:44:29 +0200 |
commit | d22c1688302d631b08cd16cba8779a0d8b713c3b (patch) | |
tree | 6a6acc54fefff10b024ff2a6e3237cb14f9133fe | |
parent | 5f51b5b0569f7488bd35b5829226e7a9f3d3994d (diff) |
Fix fallback panic on CONNECT requests (#1958)
-rw-r--r-- | axum/CHANGELOG.md | 4 | ||||
-rw-r--r-- | axum/src/routing/method_routing.rs | 10 | ||||
-rw-r--r-- | axum/src/routing/mod.rs | 75 | ||||
-rw-r--r-- | axum/src/routing/path_router.rs | 21 | ||||
-rw-r--r-- | axum/src/routing/tests/mod.rs | 46 |
5 files changed, 117 insertions, 39 deletions
diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index c0f92e36..5eff31dc 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **fixed:** Fix fallbacks causing a panic on `CONNECT` requests ([#1958]) + +[#1958]: https://github.com/tokio-rs/axum/pull/1958 # 0.6.16 (18. April, 2023) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index cd942902..98683573 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1119,15 +1119,7 @@ where call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); - let future = match fallback { - Fallback::Default(route) | Fallback::Service(route) => { - RouteFuture::from_future(route.oneshot_inner(req)) - } - Fallback::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); - RouteFuture::from_future(route.oneshot_inner(req)) - } - }; + let future = fallback.call_with_state(req, state); match allow_header { AllowHeader::None => future.allow_header(Bytes::new()), diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index a1d8c715..54dc772e 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -60,6 +60,7 @@ pub struct Router<S = (), B = Body> { path_router: PathRouter<S, B, false>, fallback_router: PathRouter<S, B, true>, default_fallback: bool, + catch_all_fallback: Fallback<S, B>, } impl<S, B> Clone for Router<S, B> { @@ -68,6 +69,7 @@ impl<S, B> Clone for Router<S, B> { path_router: self.path_router.clone(), fallback_router: self.fallback_router.clone(), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.clone(), } } } @@ -88,6 +90,7 @@ impl<S, B> fmt::Debug for Router<S, B> { .field("path_router", &self.path_router) .field("fallback_router", &self.fallback_router) .field("default_fallback", &self.default_fallback) + .field("catch_all_fallback", &self.catch_all_fallback) .finish() } } @@ -106,14 +109,12 @@ where /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. pub fn new() -> Self { - let mut this = Self { + Self { path_router: Default::default(), - fallback_router: Default::default(), + fallback_router: PathRouter::new_fallback(), default_fallback: true, - }; - this = this.fallback_service(NotFound); - this.default_fallback = true; - this + catch_all_fallback: Fallback::Default(Route::new(NotFound)), + } } #[doc = include_str!("../docs/routing/route.md")] @@ -151,6 +152,10 @@ where path_router, fallback_router, default_fallback, + // we don't need to inherit the catch-all fallback. It is only used for CONNECT + // requests with an empty path. If we were to inherit the catch-all fallback + // it would end up matching `/{path}/*` which doesn't match empty paths. + catch_all_fallback: _, } = router; panic_on_err!(self.path_router.nest(path, path_router)); @@ -184,6 +189,7 @@ where path_router, fallback_router: other_fallback, default_fallback, + catch_all_fallback, } = other.into(); panic_on_err!(self.path_router.merge(path_router)); @@ -208,6 +214,11 @@ where } }; + self.catch_all_fallback = self + .catch_all_fallback + .merge(catch_all_fallback) + .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); + self } @@ -223,8 +234,9 @@ where { Router { path_router: self.path_router.layer(layer.clone()), - fallback_router: self.fallback_router.layer(layer), + fallback_router: self.fallback_router.layer(layer.clone()), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)), } } @@ -242,36 +254,38 @@ where path_router: self.path_router.route_layer(layer), fallback_router: self.fallback_router, default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback, } } #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback<H, T>(self, handler: H) -> Self + pub fn fallback<H, T>(mut self, handler: H) -> Self where H: Handler<T, S, B>, T: 'static, { - let endpoint = Endpoint::MethodRouter(any(handler)); - self.fallback_endpoint(endpoint) + self.catch_all_fallback = + Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); + self.fallback_endpoint(Endpoint::MethodRouter(any(handler))) } /// Add a fallback [`Service`] to the router. /// /// See [`Router::fallback`] for more details. - pub fn fallback_service<T>(self, service: T) -> Self + pub fn fallback_service<T>(mut self, service: T) -> Self where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - self.fallback_endpoint(Endpoint::Route(Route::new(service))) + let route = Route::new(service); + self.catch_all_fallback = Fallback::Service(route.clone()); + self.fallback_endpoint(Endpoint::Route(route)) } fn fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self { - self.fallback_router.replace_endpoint("/", endpoint.clone()); - self.fallback_router - .replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint); + self.fallback_router.set_fallback(endpoint); self.default_fallback = false; self } @@ -280,8 +294,9 @@ where pub fn with_state<S2>(self, state: S) -> Router<S2, B> { Router { path_router: self.path_router.with_state(state.clone()), - fallback_router: self.fallback_router.with_state(state), + fallback_router: self.fallback_router.with_state(state.clone()), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.with_state(state), } } @@ -307,19 +322,17 @@ where .map(|SuperFallback(path_router)| path_router.into_inner()); if let Some(mut super_fallback) = super_fallback { - return super_fallback - .call_with_state(req, state) - .unwrap_or_else(|_| unreachable!()); + match super_fallback.call_with_state(req, state) { + Ok(future) => return future, + Err((req, state)) => { + return self.catch_all_fallback.call_with_state(req, state); + } + } } match self.fallback_router.call_with_state(req, state) { Ok(future) => future, - Err((_req, _state)) => { - unreachable!( - "the default fallback added in `Router::new` \ - matches everything" - ) - } + Err((req, state)) => self.catch_all_fallback.call_with_state(req, state), } } } @@ -428,6 +441,18 @@ where Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)), } } + + fn call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E> { + match self { + Fallback::Default(route) | Fallback::Service(route) => { + RouteFuture::from_future(route.oneshot_inner(req)) + } + Fallback::BoxedHandler(handler) => { + let mut route = handler.clone().into_route(state); + RouteFuture::from_future(route.oneshot_inner(req)) + } + } + } } impl<S, B, E> Clone for Fallback<S, B, E> { diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index ca618e75..e05a7997 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -7,8 +7,8 @@ use tower_layer::Layer; use tower_service::Service; use super::{ - future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route, - RouteId, NEST_TAIL_PARAM, + future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint, + MethodRouter, Route, RouteId, FALLBACK_PARAM, NEST_TAIL_PARAM, }; pub(super) struct PathRouter<S, B, const IS_FALLBACK: bool> { @@ -17,6 +17,23 @@ pub(super) struct PathRouter<S, B, const IS_FALLBACK: bool> { prev_route_id: RouteId, } +impl<S, B> PathRouter<S, B, true> +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + pub(super) fn new_fallback() -> Self { + let mut this = Self::default(); + this.set_fallback(Endpoint::Route(Route::new(NotFound))); + this + } + + pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S, B>) { + self.replace_endpoint("/", endpoint.clone()); + self.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint); + } +} + impl<S, B, const IS_FALLBACK: bool> PathRouter<S, B, IS_FALLBACK> where B: HttpBody + Send + 'static, diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 950d601a..ad7d5067 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -15,7 +15,11 @@ use crate::{ BoxError, Extension, Json, Router, }; use futures_util::stream::StreamExt; -use http::{header::ALLOW, header::CONTENT_LENGTH, HeaderMap, Request, Response, StatusCode, Uri}; +use http::{ + header::CONTENT_LENGTH, + header::{ALLOW, HOST}, + HeaderMap, Method, Request, Response, StatusCode, Uri, +}; use hyper::Body; use serde::Deserialize; use serde_json::json; @@ -26,7 +30,9 @@ use std::{ task::{Context, Poll}, time::Duration, }; -use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder}; +use tower::{ + service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt, +}; use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer}; use tower_service::Service; @@ -984,3 +990,39 @@ async fn logging_rejections() { ]) ) } + +// https://github.com/tokio-rs/axum/issues/1955 +#[crate::test] +async fn connect_going_to_custom_fallback() { + let app = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") }); + + let req = Request::builder() + .uri("example.com:443") + .method(Method::CONNECT) + .header(HOST, "example.com:443") + .body(Body::empty()) + .unwrap(); + + let res = app.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + let text = String::from_utf8(hyper::body::to_bytes(res).await.unwrap().to_vec()).unwrap(); + assert_eq!(text, "custom fallback"); +} + +// https://github.com/tokio-rs/axum/issues/1955 +#[crate::test] +async fn connect_going_to_default_fallback() { + let app = Router::new(); + + let req = Request::builder() + .uri("example.com:443") + .method(Method::CONNECT) + .header(HOST, "example.com:443") + .body(Body::empty()) + .unwrap(); + + let res = app.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + let body = hyper::body::to_bytes(res).await.unwrap(); + assert!(body.is_empty()); +} |