From d22c1688302d631b08cd16cba8779a0d8b713c3b Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 25 Apr 2023 15:44:29 +0200 Subject: Fix fallback panic on CONNECT requests (#1958) --- axum/CHANGELOG.md | 4 +- axum/src/routing/method_routing.rs | 10 +---- axum/src/routing/mod.rs | 75 +++++++++++++++++++++++++------------- axum/src/routing/path_router.rs | 21 ++++++++++- axum/src/routing/tests/mod.rs | 46 ++++++++++++++++++++++- 5 files changed, 117 insertions(+), 39 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index c0f92e36..5eff31dc 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **fixed:** Fix fallbacks causing a panic on `CONNECT` requests ([#1958]) + +[#1958]: https://github.com/tokio-rs/axum/pull/1958 # 0.6.16 (18. April, 2023) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index cd942902..98683573 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1119,15 +1119,7 @@ where call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); - let future = match fallback { - Fallback::Default(route) | Fallback::Service(route) => { - RouteFuture::from_future(route.oneshot_inner(req)) - } - Fallback::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); - RouteFuture::from_future(route.oneshot_inner(req)) - } - }; + let future = fallback.call_with_state(req, state); match allow_header { AllowHeader::None => future.allow_header(Bytes::new()), diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index a1d8c715..54dc772e 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -60,6 +60,7 @@ pub struct Router { path_router: PathRouter, fallback_router: PathRouter, default_fallback: bool, + catch_all_fallback: Fallback, } impl Clone for Router { @@ -68,6 +69,7 @@ impl Clone for Router { path_router: self.path_router.clone(), fallback_router: self.fallback_router.clone(), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.clone(), } } } @@ -88,6 +90,7 @@ impl fmt::Debug for Router { .field("path_router", &self.path_router) .field("fallback_router", &self.fallback_router) .field("default_fallback", &self.default_fallback) + .field("catch_all_fallback", &self.catch_all_fallback) .finish() } } @@ -106,14 +109,12 @@ where /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. pub fn new() -> Self { - let mut this = Self { + Self { path_router: Default::default(), - fallback_router: Default::default(), + fallback_router: PathRouter::new_fallback(), default_fallback: true, - }; - this = this.fallback_service(NotFound); - this.default_fallback = true; - this + catch_all_fallback: Fallback::Default(Route::new(NotFound)), + } } #[doc = include_str!("../docs/routing/route.md")] @@ -151,6 +152,10 @@ where path_router, fallback_router, default_fallback, + // we don't need to inherit the catch-all fallback. It is only used for CONNECT + // requests with an empty path. If we were to inherit the catch-all fallback + // it would end up matching `/{path}/*` which doesn't match empty paths. + catch_all_fallback: _, } = router; panic_on_err!(self.path_router.nest(path, path_router)); @@ -184,6 +189,7 @@ where path_router, fallback_router: other_fallback, default_fallback, + catch_all_fallback, } = other.into(); panic_on_err!(self.path_router.merge(path_router)); @@ -208,6 +214,11 @@ where } }; + self.catch_all_fallback = self + .catch_all_fallback + .merge(catch_all_fallback) + .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); + self } @@ -223,8 +234,9 @@ where { Router { path_router: self.path_router.layer(layer.clone()), - fallback_router: self.fallback_router.layer(layer), + fallback_router: self.fallback_router.layer(layer.clone()), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)), } } @@ -242,36 +254,38 @@ where path_router: self.path_router.route_layer(layer), fallback_router: self.fallback_router, default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback, } } #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback(self, handler: H) -> Self + pub fn fallback(mut self, handler: H) -> Self where H: Handler, T: 'static, { - let endpoint = Endpoint::MethodRouter(any(handler)); - self.fallback_endpoint(endpoint) + self.catch_all_fallback = + Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); + self.fallback_endpoint(Endpoint::MethodRouter(any(handler))) } /// Add a fallback [`Service`] to the router. /// /// See [`Router::fallback`] for more details. - pub fn fallback_service(self, service: T) -> Self + pub fn fallback_service(mut self, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - self.fallback_endpoint(Endpoint::Route(Route::new(service))) + let route = Route::new(service); + self.catch_all_fallback = Fallback::Service(route.clone()); + self.fallback_endpoint(Endpoint::Route(route)) } fn fallback_endpoint(mut self, endpoint: Endpoint) -> Self { - self.fallback_router.replace_endpoint("/", endpoint.clone()); - self.fallback_router - .replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint); + self.fallback_router.set_fallback(endpoint); self.default_fallback = false; self } @@ -280,8 +294,9 @@ where pub fn with_state(self, state: S) -> Router { Router { path_router: self.path_router.with_state(state.clone()), - fallback_router: self.fallback_router.with_state(state), + fallback_router: self.fallback_router.with_state(state.clone()), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.with_state(state), } } @@ -307,19 +322,17 @@ where .map(|SuperFallback(path_router)| path_router.into_inner()); if let Some(mut super_fallback) = super_fallback { - return super_fallback - .call_with_state(req, state) - .unwrap_or_else(|_| unreachable!()); + match super_fallback.call_with_state(req, state) { + Ok(future) => return future, + Err((req, state)) => { + return self.catch_all_fallback.call_with_state(req, state); + } + } } match self.fallback_router.call_with_state(req, state) { Ok(future) => future, - Err((_req, _state)) => { - unreachable!( - "the default fallback added in `Router::new` \ - matches everything" - ) - } + Err((req, state)) => self.catch_all_fallback.call_with_state(req, state), } } } @@ -428,6 +441,18 @@ where Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)), } } + + fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + match self { + Fallback::Default(route) | Fallback::Service(route) => { + RouteFuture::from_future(route.oneshot_inner(req)) + } + Fallback::BoxedHandler(handler) => { + let mut route = handler.clone().into_route(state); + RouteFuture::from_future(route.oneshot_inner(req)) + } + } + } } impl Clone for Fallback { diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index ca618e75..e05a7997 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -7,8 +7,8 @@ use tower_layer::Layer; use tower_service::Service; use super::{ - future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route, - RouteId, NEST_TAIL_PARAM, + future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint, + MethodRouter, Route, RouteId, FALLBACK_PARAM, NEST_TAIL_PARAM, }; pub(super) struct PathRouter { @@ -17,6 +17,23 @@ pub(super) struct PathRouter { prev_route_id: RouteId, } +impl PathRouter +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + pub(super) fn new_fallback() -> Self { + let mut this = Self::default(); + this.set_fallback(Endpoint::Route(Route::new(NotFound))); + this + } + + pub(super) fn set_fallback(&mut self, endpoint: Endpoint) { + self.replace_endpoint("/", endpoint.clone()); + self.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint); + } +} + impl PathRouter where B: HttpBody + Send + 'static, diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 950d601a..ad7d5067 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -15,7 +15,11 @@ use crate::{ BoxError, Extension, Json, Router, }; use futures_util::stream::StreamExt; -use http::{header::ALLOW, header::CONTENT_LENGTH, HeaderMap, Request, Response, StatusCode, Uri}; +use http::{ + header::CONTENT_LENGTH, + header::{ALLOW, HOST}, + HeaderMap, Method, Request, Response, StatusCode, Uri, +}; use hyper::Body; use serde::Deserialize; use serde_json::json; @@ -26,7 +30,9 @@ use std::{ task::{Context, Poll}, time::Duration, }; -use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder}; +use tower::{ + service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt, +}; use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer}; use tower_service::Service; @@ -984,3 +990,39 @@ async fn logging_rejections() { ]) ) } + +// https://github.com/tokio-rs/axum/issues/1955 +#[crate::test] +async fn connect_going_to_custom_fallback() { + let app = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") }); + + let req = Request::builder() + .uri("example.com:443") + .method(Method::CONNECT) + .header(HOST, "example.com:443") + .body(Body::empty()) + .unwrap(); + + let res = app.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + let text = String::from_utf8(hyper::body::to_bytes(res).await.unwrap().to_vec()).unwrap(); + assert_eq!(text, "custom fallback"); +} + +// https://github.com/tokio-rs/axum/issues/1955 +#[crate::test] +async fn connect_going_to_default_fallback() { + let app = Router::new(); + + let req = Request::builder() + .uri("example.com:443") + .method(Method::CONNECT) + .header(HOST, "example.com:443") + .body(Body::empty()) + .unwrap(); + + let res = app.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + let body = hyper::body::to_bytes(res).await.unwrap(); + assert!(body.is_empty()); +} -- cgit v1.2.3-70-g09d2