diff options
author | David Pedersen <david.pdrsn@gmail.com> | 2022-08-17 17:13:31 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-17 15:13:31 +0000 |
commit | 423308de3c8f63cd50589ddc0b8fa414d28dbf27 (patch) | |
tree | 8d8f7ef54fc9540184ea90a32b538bcd85e7cfe2 /examples | |
parent | 90dbd52ee4d497a4e9871eeeeb156677f6a82c5f (diff) |
Add type safe state extractor (#1155)
* begin threading the state through
* Pass state to extractors
* make state extractor work
* make sure nesting with different states work
* impl Service for MethodRouter<()>
* Fix some of axum-macro's tests
* Implement more traits for `State`
* Update examples to use `State`
* consistent naming of request body param
* swap type params
* Default the state param to ()
* fix docs references
* Docs and handler state refactoring
* docs clean ups
* more consistent naming
* when does MethodRouter implement Service?
* add missing docs
* use `Router`'s default state type param
* changelog
* don't use default type param for FromRequest and RequestParts
probably safer for library authors so you don't accidentally forget
* fix examples
* minor docs tweaks
* clarify how to convert handlers into services
* group methods in one impl block
* make sure merged `MethodRouter`s can access state
* fix docs link
* test merge with same state type
* Document how to access state from middleware
* Port cookie extractors to use state to extract keys (#1250)
* Updates ECOSYSTEM with a new sample project (#1252)
* Avoid unhelpful compiler suggestion (#1251)
* fix docs typo
* document how library authors should access state
* Add `RequestParts::with_state`
* fix example
* apply suggestions from review
* add relevant changes to axum-extra and axum-core changelogs
* Add `route_service_with_tsr`
* fix trybuild expectations
* make sure `SpaRouter` works with routers that have state
* Change order of type params on FromRequest and RequestParts
* reverse order of `RequestParts::with_state` args to match type params
* Add `FromRef` trait (#1268)
* Add `FromRef` trait
* Remove unnecessary type params
* format
* fix docs link
* format examples
* Avoid unnecessary `MethodRouter`
* apply suggestions from review
Co-authored-by: Dani Pardo <dani.pardo@inmensys.com>
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
Diffstat (limited to 'examples')
22 files changed, 165 insertions, 121 deletions
diff --git a/examples/async-graphql/src/main.rs b/examples/async-graphql/src/main.rs new file mode 100644 index 00000000..a8d84cb9 --- /dev/null +++ b/examples/async-graphql/src/main.rs @@ -0,0 +1,45 @@ +//! Example async-graphql application. +//! +//! Run with +//! +//! ```not_rust +//! cd examples && cargo run -p example-async-graphql +//! ``` + +mod starwars; + +use async_graphql::{ + http::{playground_source, GraphQLPlaygroundConfig}, + EmptyMutation, EmptySubscription, Request, Response, Schema, +}; +use axum::{ + extract::State, + response::{Html, IntoResponse}, + routing::get, + Json, Router, +}; +use starwars::{QueryRoot, StarWars, StarWarsSchema}; + +async fn graphql_handler(schema: State<StarWarsSchema>, req: Json<Request>) -> Json<Response> { + schema.execute(req.0).await.into() +} + +async fn graphql_playground() -> impl IntoResponse { + Html(playground_source(GraphQLPlaygroundConfig::new("/"))) +} + +#[tokio::main] +async fn main() { + let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription) + .data(StarWars::new()) + .finish(); + + let app = Router::with_state(schema).route("/", get(graphql_playground).post(graphql_handler)); + + println!("Playground: http://localhost:3000"); + + axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) + .serve(app.into_make_service()) + .await + .unwrap(); +} diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 5107323e..092a7d96 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -9,7 +9,7 @@ use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, - Extension, + State, }, response::{Html, IntoResponse}, routing::get, @@ -44,10 +44,9 @@ async fn main() { let app_state = Arc::new(AppState { user_set, tx }); - let app = Router::new() + let app = Router::with_state(app_state) .route("/", get(index)) - .route("/websocket", get(websocket_handler)) - .layer(Extension(app_state)); + .route("/websocket", get(websocket_handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -59,7 +58,7 @@ async fn main() { async fn websocket_handler( ws: WebSocketUpgrade, - Extension(state): Extension<Arc<AppState>>, + State(state): State<Arc<AppState>>, ) -> impl IntoResponse { ws.on_upgrade(|socket| websocket(socket, state)) } diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 1fdd9022..be948375 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -80,17 +80,22 @@ async fn handler(_: PrintRequestBody, body: Bytes) { struct PrintRequestBody; #[async_trait] -impl FromRequest<BoxBody> for PrintRequestBody { +impl<S> FromRequest<S, BoxBody> for PrintRequestBody +where + S: Send + Clone, +{ type Rejection = Response; - async fn from_request(req: &mut RequestParts<BoxBody>) -> Result<Self, Self::Rejection> { + async fn from_request(req: &mut RequestParts<S, BoxBody>) -> Result<Self, Self::Rejection> { + let state = req.state().clone(); + let request = Request::from_request(req) .await .map_err(|err| err.into_response())?; let request = buffer_request_body(request).await?; - *req = RequestParts::new(request); + *req = RequestParts::with_state(state, request); Ok(Self) } diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index be5ef595..20e3b4d4 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -56,8 +56,9 @@ struct User { struct Json<T>(T); #[async_trait] -impl<B, T> FromRequest<B> for Json<T> +impl<S, B, T> FromRequest<S, B> for Json<T> where + S: Send, // these trait bounds are copied from `impl FromRequest for axum::Json` T: DeserializeOwned, B: axum::body::HttpBody + Send, @@ -66,7 +67,7 @@ where { type Rejection = (StatusCode, axum::Json<Value>); - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { match axum::Json::<T>::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index 4e268949..8330b95a 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -52,15 +52,16 @@ struct Params { struct Path<T>(T); #[async_trait] -impl<B, T> FromRequest<B> for Path<T> +impl<S, B, T> FromRequest<S, B> for Path<T> where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, B: Send, + S: Send, { type Rejection = (StatusCode, axum::Json<PathError>); - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { match axum::extract::Path::<T>::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs index d92b43bf..914ae181 100644 --- a/examples/error-handling-and-dependency-injection/src/main.rs +++ b/examples/error-handling-and-dependency-injection/src/main.rs @@ -9,7 +9,7 @@ use axum::{ async_trait, - extract::{Extension, Path}, + extract::{Path, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{get, post}, @@ -36,12 +36,9 @@ async fn main() { let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo; // Build our application with some routes - let app = Router::new() + let app = Router::with_state(user_repo) .route("/users/:id", get(users_show)) - .route("/users", post(users_create)) - // Add our `user_repo` to all request's extensions so handlers can access - // it. - .layer(Extension(user_repo)); + .route("/users", post(users_create)); // Run our application let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -59,7 +56,7 @@ async fn main() { /// so it can be returned from handlers directly. async fn users_show( Path(user_id): Path<Uuid>, - Extension(user_repo): Extension<DynUserRepo>, + State(user_repo): State<DynUserRepo>, ) -> Result<Json<User>, AppError> { let user = user_repo.find(user_id).await?; @@ -69,7 +66,7 @@ async fn users_show( /// Handler for `POST /users`. async fn users_create( Json(params): Json<CreateUser>, - Extension(user_repo): Extension<DynUserRepo>, + State(user_repo): State<DynUserRepo>, ) -> Result<Json<User>, AppError> { let user = user_repo.create(params).await?; diff --git a/examples/global-404-handler/src/main.rs b/examples/global-404-handler/src/main.rs index 385a0e21..a3a5ea15 100644 --- a/examples/global-404-handler/src/main.rs +++ b/examples/global-404-handler/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - handler::Handler, http::StatusCode, response::{Html, IntoResponse}, routing::get, @@ -27,7 +26,7 @@ async fn main() { let app = Router::new().route("/", get(handler)); // add a fallback service for handling routes to unknown paths - let app = app.fallback(handler_404.into_service()); + let app = app.fallback(handler_404); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 0ac4053e..8725581d 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -122,13 +122,14 @@ impl AuthBody { } #[async_trait] -impl<B> FromRequest<B> for Claims +impl<S, B> FromRequest<S, B> for Claims where + S: Send, B: Send, { type Rejection = AuthError; - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = TypedHeader::<Authorization<Bearer>>::from_request(req) diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 0ad5b7f6..c65ee75a 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -9,7 +9,7 @@ use axum::{ body::Bytes, error_handling::HandleErrorLayer, - extract::{ContentLengthLimit, Extension, Path}, + extract::{ContentLengthLimit, Path, State}, handler::Handler, http::StatusCode, response::IntoResponse, @@ -39,8 +39,10 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); + let shared_state = SharedState::default(); + // Build our application by composing routes - let app = Router::new() + let app = Router::with_state(Arc::clone(&shared_state)) .route( "/:key", // Add compression to `kv_get` @@ -50,7 +52,7 @@ async fn main() { ) .route("/keys", get(list_keys)) // Nest our admin routes under `/admin` - .nest("/admin", admin_routes()) + .nest("/admin", admin_routes(shared_state)) // Add middleware to all routes .layer( ServiceBuilder::new() @@ -60,7 +62,6 @@ async fn main() { .concurrency_limit(1024) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(Extension(SharedState::default())) .into_inner(), ); @@ -73,16 +74,16 @@ async fn main() { .unwrap(); } -type SharedState = Arc<RwLock<State>>; +type SharedState = Arc<RwLock<AppState>>; #[derive(Default)] -struct State { +struct AppState { db: HashMap<String, Bytes>, } async fn kv_get( Path(key): Path<String>, - Extension(state): Extension<SharedState>, + State(state): State<SharedState>, ) -> Result<Bytes, StatusCode> { let db = &state.read().unwrap().db; @@ -96,12 +97,12 @@ async fn kv_get( async fn kv_set( Path(key): Path<String>, ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb - Extension(state): Extension<SharedState>, + State(state): State<SharedState>, ) { state.write().unwrap().db.insert(key, bytes); } -async fn list_keys(Extension(state): Extension<SharedState>) -> String { +async fn list_keys(State(state): State<SharedState>) -> String { let db = &state.read().unwrap().db; db.keys() @@ -110,16 +111,16 @@ async fn list_keys(Extension(state): Extension<SharedState>) -> String { .join("\n") } -fn admin_routes() -> Router { - async fn delete_all_keys(Extension(state): Extension<SharedState>) { +fn admin_routes(state: SharedState) -> Router<SharedState> { + async fn delete_all_keys(State(state): State<SharedState>) { state.write().unwrap().db.clear(); } - async fn remove_key(Path(key): Path<String>, Extension(state): Extension<SharedState>) { + async fn remove_key(Path(key): Path<String>, State(state): State<SharedState>) { state.write().unwrap().db.remove(&key); } - Router::new() + Router::with_state(state) .route("/keys", delete(delete_all_keys)) .route("/key/:key", delete(remove_key)) // Require bearer auth for all admin routes diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 6357a7fc..a61113b9 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -12,7 +12,7 @@ use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, extract::{ - rejection::TypedHeaderRejectionReason, Extension, FromRequest, Query, RequestParts, + rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State, TypedHeader, }, http::{header::SET_COOKIE, HeaderMap}, @@ -42,17 +42,18 @@ async fn main() { // `MemoryStore` is just used as an example. Don't use this in production. let store = MemoryStore::new(); - let oauth_client = oauth_client(); + let app_state = AppState { + store, + oauth_client, + }; - let app = Router::new() + let app = Router::with_state(app_state) .route("/", get(index)) .route("/auth/discord", get(discord_auth)) .route("/auth/authorized", get(login_authorized)) .route("/protected", get(protected)) - .route("/logout", get(logout)) - .layer(Extension(store)) - .layer(Extension(oauth_client)); + .route("/logout", get(logout)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -63,6 +64,24 @@ async fn main() { .unwrap(); } +#[derive(Clone)] +struct AppState { + store: MemoryStore, + oauth_client: BasicClient, +} + +impl FromRef<AppState> for MemoryStore { + fn from_ref(state: &AppState) -> Self { + state.store.clone() + } +} + +impl FromRef<AppState> for BasicClient { + fn from_ref(state: &AppState) -> Self { + state.oauth_client.clone() + } +} + fn oauth_client() -> BasicClient { // Environment variables (* = required): // *"CLIENT_ID" "REPLACE_ME"; @@ -113,7 +132,7 @@ async fn index(user: Option<User>) -> impl IntoResponse { } } -async fn discord_auth(Extension(client): Extension<BasicClient>) -> impl IntoResponse { +async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse { let (auth_url, _csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) @@ -132,7 +151,7 @@ async fn protected(user: User) -> impl IntoResponse { } async fn logout( - Extension(store): Extension<MemoryStore>, + State(store): State<MemoryStore>, TypedHeader(cookies): TypedHeader<headers::Cookie>, ) -> impl IntoResponse { let cookie = cookies.get(COOKIE_NAME).unwrap(); @@ -156,8 +175,8 @@ struct AuthRequest { async fn login_authorized( Query(query): Query<AuthRequest>, - Extension(store): Extension<MemoryStore>, - Extension(oauth_client): Extension<BasicClient>, + State(store): State<MemoryStore>, + State(oauth_client): State<BasicClient>, ) -> impl IntoResponse { // Get an auth token let token = oauth_client @@ -205,17 +224,15 @@ impl IntoResponse for AuthRedirect { } #[async_trait] -impl<B> FromRequest<B> for User +impl<B> FromRequest<AppState, B> for User where B: Send, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - let Extension(store) = Extension::<MemoryStore>::from_request(req) - .await - .expect("`MemoryStore` extension is missing"); + async fn from_request(req: &mut RequestParts<AppState, B>) -> Result<Self, Self::Rejection> { + let store = req.state().clone().store; let cookies = TypedHeader::<headers::Cookie>::from_request(req) .await diff --git a/examples/reverse-proxy/src/main.rs b/examples/reverse-proxy/src/main.rs index a9d2a5c7..af74ea12 100644 --- a/examples/reverse-proxy/src/main.rs +++ b/examples/reverse-proxy/src/main.rs @@ -8,7 +8,7 @@ //! ``` use axum::{ - extract::Extension, + extract::State, http::{uri::Uri, Request, Response}, routing::get, Router, @@ -24,9 +24,7 @@ async fn main() { let client = Client::new(); - let app = Router::new() - .route("/", get(handler)) - .layer(Extension(client)); + let app = Router::with_state(client).route("/", get(handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 4000)); println!("reverse proxy listening on {}", addr); @@ -36,12 +34,7 @@ async fn main() { .unwrap(); } -async fn handler( - Extension(client): Extension<Client>, - // NOTE: Make sure to put the request extractor last because once the request - // is extracted, extensions can't be extracted anymore. - mut req: Request<Body>, -) -> Response<Body> { +async fn handler(State(client): State<Client>, mut req: Request<Body>) -> Response<Body> { let path = req.uri().path(); let path_query = req .uri() diff --git a/examples/routes-and-handlers-close-together/src/main.rs b/examples/routes-and-handlers-close-together/src/main.rs index 5e52ad7b..41aaa49d 100644 --- a/examples/routes-and-handlers-close-together/src/main.rs +++ b/examples/routes-and-handlers-close-together/src/main.rs @@ -49,6 +49,6 @@ fn post_foo() -> Router { route("/foo", post(handler)) } -fn route(path: &str, method_router: MethodRouter) -> Router { +fn route(path: &str, method_router: MethodRouter<()>) -> Router { Router::new().route(path, method_router) } diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index 3251122c..cd0d41a1 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -7,7 +7,7 @@ use async_session::{MemoryStore, Session, SessionStore as _}; use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts, TypedHeader}, + extract::{FromRequest, RequestParts, TypedHeader}, headers::Cookie, http::{ self, @@ -38,9 +38,7 @@ async fn main() { // `MemoryStore` just used as an example. Don't use this in production. let store = MemoryStore::new(); - let app = Router::new() - .route("/", get(handler)) - .layer(Extension(store)); + let app = Router::with_state(store).route("/", get(handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -82,20 +80,16 @@ enum UserIdFromSession { } #[async_trait] -impl<B> FromRequest<B> for UserIdFromSession +impl<B> FromRequest<MemoryStore, B> for UserIdFromSession where B: Send, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - let Extension(store) = Extension::<MemoryStore>::from_request(req) - .await - .expect("`MemoryStore` extension missing"); + async fn from_request(req: &mut RequestParts<MemoryStore, B>) -> Result<Self, Self::Rejection> { + let store = req.state().clone(); - let cookie = Option::<TypedHeader<Cookie>>::from_request(req) - .await - .unwrap(); + let cookie = req.extract::<Option<TypedHeader<Cookie>>>().await.unwrap(); let session_cookie = cookie .as_ref() diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 9d101618..6548cdeb 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -15,7 +15,7 @@ use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{FromRequest, RequestParts, State}, http::StatusCode, routing::get, Router, @@ -46,12 +46,10 @@ async fn main() { .expect("can connect to database"); // build our application with some routes - let app = Router::new() - .route( - "/", - get(using_connection_pool_extractor).post(using_connection_extractor), - ) - .layer(Extension(pool)); + let app = Router::with_state(pool).route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -62,9 +60,9 @@ async fn main() { .unwrap(); } -// we can extract the connection pool with `Extension` +// we can extract the connection pool with `State` async fn using_connection_pool_extractor( - Extension(pool): Extension<PgPool>, + State(pool): State<PgPool>, ) -> Result<String, (StatusCode, String)> { sqlx::query_scalar("select 'hello world from pg'") .fetch_one(&pool) @@ -77,16 +75,14 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>); #[async_trait] -impl<B> FromRequest<B> for DatabaseConnection +impl<B> FromRequest<PgPool, B> for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - let Extension(pool) = Extension::<PgPool>::from_request(req) - .await - .map_err(internal_error)?; + async fn from_request(req: &mut RequestParts<PgPool, B>) -> Result<Self, Self::Rejection> { + let pool = req.state().clone(); let conn = pool.acquire().await.map_err(internal_error)?; diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index 4dbfcb46..66799711 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -41,7 +41,7 @@ async fn main() { // build our application with a route let app = Router::new() - .fallback(static_files_service) + .fallback_service(static_files_service) .route("/sse", get(sse_handler)) .layer(TraceLayer::new_for_http()); diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs index 1862ecab..f7ac2bb9 100644 --- a/examples/static-file-server/src/main.rs +++ b/examples/static-file-server/src/main.rs @@ -34,7 +34,7 @@ async fn main() { // as the fallback to a `Router` let app: _ = Router::new() .route("/foo", get(|| async { "Hi from /foo" })) - .fallback(get_service(ServeDir::new(".")).handle_error(handle_error)) + .fallback_service(get_service(ServeDir::new(".")).handle_error(handle_error)) .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/tls-rustls/src/main.rs b/examples/tls-rustls/src/main.rs index 40008a96..6eec0e9c 100644 --- a/examples/tls-rustls/src/main.rs +++ b/examples/tls-rustls/src/main.rs @@ -6,7 +6,7 @@ use axum::{ extract::Host, - handler::Handler, + handler::HandlerWithoutStateExt, http::{StatusCode, Uri}, response::Redirect, routing::get, diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index 9a33416b..b82a308d 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -15,7 +15,7 @@ use axum::{ error_handling::HandleErrorLayer, - extract::{Extension, Path, Query}, + extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, routing::{get, patch}, @@ -46,7 +46,7 @@ async fn main() { let db = Db::default(); // Compose the routes - let app = Router::new() + let app = Router::with_state(db) .route("/todos", get(todos_index).post(todos_create)) .route("/todos/:id", patch(todos_update).delete(todos_delete)) // Add middleware to all routes @@ -64,7 +64,6 @@ async fn main() { })) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(Extension(db)) .into_inner(), ); @@ -85,7 +84,7 @@ pub struct Pagination { async fn todos_index( pagination: Option<Query<Pagination>>, - Extension(db): Extension<Db>, + State(db): State<Db>, ) -> impl IntoResponse { let todos = db.read().unwrap(); @@ -106,10 +105,7 @@ struct CreateTodo { text: String, } -async fn todos_create( - Json(input): Json<CreateTodo>, - Extension(db): Extension<Db>, -) -> impl IntoResponse { +async fn todos_create(Json(input): Json<CreateTodo>, State(db): State<Db>) -> impl IntoResponse { let todo = Todo { id: Uuid::new_v4(), text: input.text, @@ -130,7 +126,7 @@ struct UpdateTodo { async fn todos_update( Path(id): Path<Uuid>, Json(input): Json<UpdateTodo>, - Extension(db): Extension<Db>, + State(db): State<Db>, ) -> Result<impl IntoResponse, StatusCode> { let mut todo = db .read() @@ -152,7 +148,7 @@ async fn todos_update( Ok(Json(todo)) } -async fn todos_delete(Path(id): Path<Uuid>, Extension(db): Extension<Db>) -> impl IntoResponse { +async fn todos_delete(Path(id): Path<Uuid>, State(db): State<Db>) -> impl IntoResponse { if db.write().unwrap().remove(&id).is_some() { StatusCode::NO_CONTENT } else { diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index 66b03a8f..e0c60453 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -6,7 +6,7 @@ use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{FromRequest, RequestParts, State}, http::StatusCode, routing::get, Router, @@ -33,12 +33,10 @@ async fn main() { let pool = Pool::builder().build(manager).await.unwrap(); // build our application with some routes - let app = Router::new() - .route( - "/", - get(using_connection_pool_extractor).post(using_connection_extractor), - ) - .layer(Extension(pool)); + let app = Router::with_state(pool).route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -51,9 +49,8 @@ async fn main() { type ConnectionPool = Pool<PostgresConnectionManager<NoTls>>; -// we can extract the connection pool with `Extension` async fn using_connection_pool_extractor( - Extension(pool): Extension<ConnectionPool>, + State(pool): State<ConnectionPool>, ) -> Result<String, (StatusCode, String)> { let conn = pool.get().await.map_err(internal_error)?; @@ -71,16 +68,16 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager<NoTls>>); #[async_trait] -impl<B> FromRequest<B> for DatabaseConnection +impl<B> FromRequest<ConnectionPool, B> for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - let Extension(pool) = Extension::<ConnectionPool>::from_request(req) - .await - .map_err(internal_error)?; + async fn from_request( + req: &mut RequestParts<ConnectionPool, B>, + ) -> Result<Self, Self::Rejection> { + let pool = req.state().clone(); let conn = pool.get_owned().await.map_err(internal_error)?; diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index c8ce8c08..8682eb85 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -60,16 +60,17 @@ async fn handler(ValidatedForm(input): ValidatedForm<NameInput>) -> Html<String> pub struct ValidatedForm<T>(pub T); #[async_trait] -impl<T, B> FromRequest<B> for ValidatedForm<T> +impl<T, S, B> FromRequest<S, B> for ValidatedForm<T> where T: DeserializeOwned + Validate, + S: Send, B: http_body::Body + Send, B::Data: Send, B::Error: Into<BoxError>, { type Rejection = ServerError; - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { let Form(value) = Form::<T>::from_request(req).await?; value.validate()?; Ok(ValidatedForm(value)) diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index 48ade3c9..cf8e15f2 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -48,13 +48,14 @@ enum Version { } #[async_trait] -impl<B> FromRequest<B> for Version +impl<S, B> FromRequest<S, B> for Version where B: Send, + S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { let params = Path::<HashMap<String, String>>::from_request(req) .await .map_err(IntoResponse::into_response)?; diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index bbdcaa08..a317bfa4 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -37,7 +37,7 @@ async fn main() { // build our application with some routes let app = Router::new() - .fallback( + .fallback_service( get_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) .handle_error(|error: std::io::Error| async move { ( |