summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorDavid Pedersen <david.pdrsn@gmail.com>2022-08-17 17:13:31 +0200
committerGitHub <noreply@github.com>2022-08-17 15:13:31 +0000
commit423308de3c8f63cd50589ddc0b8fa414d28dbf27 (patch)
tree8d8f7ef54fc9540184ea90a32b538bcd85e7cfe2 /examples
parent90dbd52ee4d497a4e9871eeeeb156677f6a82c5f (diff)
Add type safe state extractor (#1155)
* begin threading the state through * Pass state to extractors * make state extractor work * make sure nesting with different states work * impl Service for MethodRouter<()> * Fix some of axum-macro's tests * Implement more traits for `State` * Update examples to use `State` * consistent naming of request body param * swap type params * Default the state param to () * fix docs references * Docs and handler state refactoring * docs clean ups * more consistent naming * when does MethodRouter implement Service? * add missing docs * use `Router`'s default state type param * changelog * don't use default type param for FromRequest and RequestParts probably safer for library authors so you don't accidentally forget * fix examples * minor docs tweaks * clarify how to convert handlers into services * group methods in one impl block * make sure merged `MethodRouter`s can access state * fix docs link * test merge with same state type * Document how to access state from middleware * Port cookie extractors to use state to extract keys (#1250) * Updates ECOSYSTEM with a new sample project (#1252) * Avoid unhelpful compiler suggestion (#1251) * fix docs typo * document how library authors should access state * Add `RequestParts::with_state` * fix example * apply suggestions from review * add relevant changes to axum-extra and axum-core changelogs * Add `route_service_with_tsr` * fix trybuild expectations * make sure `SpaRouter` works with routers that have state * Change order of type params on FromRequest and RequestParts * reverse order of `RequestParts::with_state` args to match type params * Add `FromRef` trait (#1268) * Add `FromRef` trait * Remove unnecessary type params * format * fix docs link * format examples * Avoid unnecessary `MethodRouter` * apply suggestions from review Co-authored-by: Dani Pardo <dani.pardo@inmensys.com> Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
Diffstat (limited to 'examples')
-rw-r--r--examples/async-graphql/src/main.rs45
-rw-r--r--examples/chat/src/main.rs9
-rw-r--r--examples/consume-body-in-extractor-or-middleware/src/main.rs11
-rw-r--r--examples/customize-extractor-error/src/main.rs5
-rw-r--r--examples/customize-path-rejection/src/main.rs5
-rw-r--r--examples/error-handling-and-dependency-injection/src/main.rs13
-rw-r--r--examples/global-404-handler/src/main.rs3
-rw-r--r--examples/jwt/src/main.rs5
-rw-r--r--examples/key-value-store/src/main.rs27
-rw-r--r--examples/oauth/src/main.rs47
-rw-r--r--examples/reverse-proxy/src/main.rs13
-rw-r--r--examples/routes-and-handlers-close-together/src/main.rs2
-rw-r--r--examples/sessions/src/main.rs18
-rw-r--r--examples/sqlx-postgres/src/main.rs24
-rw-r--r--examples/sse/src/main.rs2
-rw-r--r--examples/static-file-server/src/main.rs2
-rw-r--r--examples/tls-rustls/src/main.rs2
-rw-r--r--examples/todos/src/main.rs16
-rw-r--r--examples/tokio-postgres/src/main.rs25
-rw-r--r--examples/validator/src/main.rs5
-rw-r--r--examples/versioning/src/main.rs5
-rw-r--r--examples/websockets/src/main.rs2
22 files changed, 165 insertions, 121 deletions
diff --git a/examples/async-graphql/src/main.rs b/examples/async-graphql/src/main.rs
new file mode 100644
index 00000000..a8d84cb9
--- /dev/null
+++ b/examples/async-graphql/src/main.rs
@@ -0,0 +1,45 @@
+//! Example async-graphql application.
+//!
+//! Run with
+//!
+//! ```not_rust
+//! cd examples && cargo run -p example-async-graphql
+//! ```
+
+mod starwars;
+
+use async_graphql::{
+ http::{playground_source, GraphQLPlaygroundConfig},
+ EmptyMutation, EmptySubscription, Request, Response, Schema,
+};
+use axum::{
+ extract::State,
+ response::{Html, IntoResponse},
+ routing::get,
+ Json, Router,
+};
+use starwars::{QueryRoot, StarWars, StarWarsSchema};
+
+async fn graphql_handler(schema: State<StarWarsSchema>, req: Json<Request>) -> Json<Response> {
+ schema.execute(req.0).await.into()
+}
+
+async fn graphql_playground() -> impl IntoResponse {
+ Html(playground_source(GraphQLPlaygroundConfig::new("/")))
+}
+
+#[tokio::main]
+async fn main() {
+ let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription)
+ .data(StarWars::new())
+ .finish();
+
+ let app = Router::with_state(schema).route("/", get(graphql_playground).post(graphql_handler));
+
+ println!("Playground: http://localhost:3000");
+
+ axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
+ .serve(app.into_make_service())
+ .await
+ .unwrap();
+}
diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs
index 5107323e..092a7d96 100644
--- a/examples/chat/src/main.rs
+++ b/examples/chat/src/main.rs
@@ -9,7 +9,7 @@
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
- Extension,
+ State,
},
response::{Html, IntoResponse},
routing::get,
@@ -44,10 +44,9 @@ async fn main() {
let app_state = Arc::new(AppState { user_set, tx });
- let app = Router::new()
+ let app = Router::with_state(app_state)
.route("/", get(index))
- .route("/websocket", get(websocket_handler))
- .layer(Extension(app_state));
+ .route("/websocket", get(websocket_handler));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
@@ -59,7 +58,7 @@ async fn main() {
async fn websocket_handler(
ws: WebSocketUpgrade,
- Extension(state): Extension<Arc<AppState>>,
+ State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket(socket, state))
}
diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs
index 1fdd9022..be948375 100644
--- a/examples/consume-body-in-extractor-or-middleware/src/main.rs
+++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs
@@ -80,17 +80,22 @@ async fn handler(_: PrintRequestBody, body: Bytes) {
struct PrintRequestBody;
#[async_trait]
-impl FromRequest<BoxBody> for PrintRequestBody {
+impl<S> FromRequest<S, BoxBody> for PrintRequestBody
+where
+ S: Send + Clone,
+{
type Rejection = Response;
- async fn from_request(req: &mut RequestParts<BoxBody>) -> Result<Self, Self::Rejection> {
+ async fn from_request(req: &mut RequestParts<S, BoxBody>) -> Result<Self, Self::Rejection> {
+ let state = req.state().clone();
+
let request = Request::from_request(req)
.await
.map_err(|err| err.into_response())?;
let request = buffer_request_body(request).await?;
- *req = RequestParts::new(request);
+ *req = RequestParts::with_state(state, request);
Ok(Self)
}
diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs
index be5ef595..20e3b4d4 100644
--- a/examples/customize-extractor-error/src/main.rs
+++ b/examples/customize-extractor-error/src/main.rs
@@ -56,8 +56,9 @@ struct User {
struct Json<T>(T);
#[async_trait]
-impl<B, T> FromRequest<B> for Json<T>
+impl<S, B, T> FromRequest<S, B> for Json<T>
where
+ S: Send,
// these trait bounds are copied from `impl FromRequest for axum::Json`
T: DeserializeOwned,
B: axum::body::HttpBody + Send,
@@ -66,7 +67,7 @@ where
{
type Rejection = (StatusCode, axum::Json<Value>);
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match axum::Json::<T>::from_request(req).await {
Ok(value) => Ok(Self(value.0)),
Err(rejection) => {
diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs
index 4e268949..8330b95a 100644
--- a/examples/customize-path-rejection/src/main.rs
+++ b/examples/customize-path-rejection/src/main.rs
@@ -52,15 +52,16 @@ struct Params {
struct Path<T>(T);
#[async_trait]
-impl<B, T> FromRequest<B> for Path<T>
+impl<S, B, T> FromRequest<S, B> for Path<T>
where
// these trait bounds are copied from `impl FromRequest for axum::extract::path::Path`
T: DeserializeOwned + Send,
B: Send,
+ S: Send,
{
type Rejection = (StatusCode, axum::Json<PathError>);
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match axum::extract::Path::<T>::from_request(req).await {
Ok(value) => Ok(Self(value.0)),
Err(rejection) => {
diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs
index d92b43bf..914ae181 100644
--- a/examples/error-handling-and-dependency-injection/src/main.rs
+++ b/examples/error-handling-and-dependency-injection/src/main.rs
@@ -9,7 +9,7 @@
use axum::{
async_trait,
- extract::{Extension, Path},
+ extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
@@ -36,12 +36,9 @@ async fn main() {
let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo;
// Build our application with some routes
- let app = Router::new()
+ let app = Router::with_state(user_repo)
.route("/users/:id", get(users_show))
- .route("/users", post(users_create))
- // Add our `user_repo` to all request's extensions so handlers can access
- // it.
- .layer(Extension(user_repo));
+ .route("/users", post(users_create));
// Run our application
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@@ -59,7 +56,7 @@ async fn main() {
/// so it can be returned from handlers directly.
async fn users_show(
Path(user_id): Path<Uuid>,
- Extension(user_repo): Extension<DynUserRepo>,
+ State(user_repo): State<DynUserRepo>,
) -> Result<Json<User>, AppError> {
let user = user_repo.find(user_id).await?;
@@ -69,7 +66,7 @@ async fn users_show(
/// Handler for `POST /users`.
async fn users_create(
Json(params): Json<CreateUser>,
- Extension(user_repo): Extension<DynUserRepo>,
+ State(user_repo): State<DynUserRepo>,
) -> Result<Json<User>, AppError> {
let user = user_repo.create(params).await?;
diff --git a/examples/global-404-handler/src/main.rs b/examples/global-404-handler/src/main.rs
index 385a0e21..a3a5ea15 100644
--- a/examples/global-404-handler/src/main.rs
+++ b/examples/global-404-handler/src/main.rs
@@ -5,7 +5,6 @@
//! ```
use axum::{
- handler::Handler,
http::StatusCode,
response::{Html, IntoResponse},
routing::get,
@@ -27,7 +26,7 @@ async fn main() {
let app = Router::new().route("/", get(handler));
// add a fallback service for handling routes to unknown paths
- let app = app.fallback(handler_404.into_service());
+ let app = app.fallback(handler_404);
// run it
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs
index 0ac4053e..8725581d 100644
--- a/examples/jwt/src/main.rs
+++ b/examples/jwt/src/main.rs
@@ -122,13 +122,14 @@ impl AuthBody {
}
#[async_trait]
-impl<B> FromRequest<B> for Claims
+impl<S, B> FromRequest<S, B> for Claims
where
+ S: Send,
B: Send,
{
type Rejection = AuthError;
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) =
TypedHeader::<Authorization<Bearer>>::from_request(req)
diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs
index 0ad5b7f6..c65ee75a 100644
--- a/examples/key-value-store/src/main.rs
+++ b/examples/key-value-store/src/main.rs
@@ -9,7 +9,7 @@
use axum::{
body::Bytes,
error_handling::HandleErrorLayer,
- extract::{ContentLengthLimit, Extension, Path},
+ extract::{ContentLengthLimit, Path, State},
handler::Handler,
http::StatusCode,
response::IntoResponse,
@@ -39,8 +39,10 @@ async fn main() {
.with(tracing_subscriber::fmt::layer())
.init();
+ let shared_state = SharedState::default();
+
// Build our application by composing routes
- let app = Router::new()
+ let app = Router::with_state(Arc::clone(&shared_state))
.route(
"/:key",
// Add compression to `kv_get`
@@ -50,7 +52,7 @@ async fn main() {
)
.route("/keys", get(list_keys))
// Nest our admin routes under `/admin`
- .nest("/admin", admin_routes())
+ .nest("/admin", admin_routes(shared_state))
// Add middleware to all routes
.layer(
ServiceBuilder::new()
@@ -60,7 +62,6 @@ async fn main() {
.concurrency_limit(1024)
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
- .layer(Extension(SharedState::default()))
.into_inner(),
);
@@ -73,16 +74,16 @@ async fn main() {
.unwrap();
}
-type SharedState = Arc<RwLock<State>>;
+type SharedState = Arc<RwLock<AppState>>;
#[derive(Default)]
-struct State {
+struct AppState {
db: HashMap<String, Bytes>,
}
async fn kv_get(
Path(key): Path<String>,
- Extension(state): Extension<SharedState>,
+ State(state): State<SharedState>,
) -> Result<Bytes, StatusCode> {
let db = &state.read().unwrap().db;
@@ -96,12 +97,12 @@ async fn kv_get(
async fn kv_set(
Path(key): Path<String>,
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
- Extension(state): Extension<SharedState>,
+ State(state): State<SharedState>,
) {
state.write().unwrap().db.insert(key, bytes);
}
-async fn list_keys(Extension(state): Extension<SharedState>) -> String {
+async fn list_keys(State(state): State<SharedState>) -> String {
let db = &state.read().unwrap().db;
db.keys()
@@ -110,16 +111,16 @@ async fn list_keys(Extension(state): Extension<SharedState>) -> String {
.join("\n")
}
-fn admin_routes() -> Router {
- async fn delete_all_keys(Extension(state): Extension<SharedState>) {
+fn admin_routes(state: SharedState) -> Router<SharedState> {
+ async fn delete_all_keys(State(state): State<SharedState>) {
state.write().unwrap().db.clear();
}
- async fn remove_key(Path(key): Path<String>, Extension(state): Extension<SharedState>) {
+ async fn remove_key(Path(key): Path<String>, State(state): State<SharedState>) {
state.write().unwrap().db.remove(&key);
}
- Router::new()
+ Router::with_state(state)
.route("/keys", delete(delete_all_keys))
.route("/key/:key", delete(remove_key))
// Require bearer auth for all admin routes
diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs
index 6357a7fc..a61113b9 100644
--- a/examples/oauth/src/main.rs
+++ b/examples/oauth/src/main.rs
@@ -12,7 +12,7 @@ use async_session::{MemoryStore, Session, SessionStore};
use axum::{
async_trait,
extract::{
- rejection::TypedHeaderRejectionReason, Extension, FromRequest, Query, RequestParts,
+ rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State,
TypedHeader,
},
http::{header::SET_COOKIE, HeaderMap},
@@ -42,17 +42,18 @@ async fn main() {
// `MemoryStore` is just used as an example. Don't use this in production.
let store = MemoryStore::new();
-
let oauth_client = oauth_client();
+ let app_state = AppState {
+ store,
+ oauth_client,
+ };
- let app = Router::new()
+ let app = Router::with_state(app_state)
.route("/", get(index))
.route("/auth/discord", get(discord_auth))
.route("/auth/authorized", get(login_authorized))
.route("/protected", get(protected))
- .route("/logout", get(logout))
- .layer(Extension(store))
- .layer(Extension(oauth_client));
+ .route("/logout", get(logout));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
@@ -63,6 +64,24 @@ async fn main() {
.unwrap();
}
+#[derive(Clone)]
+struct AppState {
+ store: MemoryStore,
+ oauth_client: BasicClient,
+}
+
+impl FromRef<AppState> for MemoryStore {
+ fn from_ref(state: &AppState) -> Self {
+ state.store.clone()
+ }
+}
+
+impl FromRef<AppState> for BasicClient {
+ fn from_ref(state: &AppState) -> Self {
+ state.oauth_client.clone()
+ }
+}
+
fn oauth_client() -> BasicClient {
// Environment variables (* = required):
// *"CLIENT_ID" "REPLACE_ME";
@@ -113,7 +132,7 @@ async fn index(user: Option<User>) -> impl IntoResponse {
}
}
-async fn discord_auth(Extension(client): Extension<BasicClient>) -> impl IntoResponse {
+async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse {
let (auth_url, _csrf_token) = client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("identify".to_string()))
@@ -132,7 +151,7 @@ async fn protected(user: User) -> impl IntoResponse {
}
async fn logout(
- Extension(store): Extension<MemoryStore>,
+ State(store): State<MemoryStore>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> impl IntoResponse {
let cookie = cookies.get(COOKIE_NAME).unwrap();
@@ -156,8 +175,8 @@ struct AuthRequest {
async fn login_authorized(
Query(query): Query<AuthRequest>,
- Extension(store): Extension<MemoryStore>,
- Extension(oauth_client): Extension<BasicClient>,
+ State(store): State<MemoryStore>,
+ State(oauth_client): State<BasicClient>,
) -> impl IntoResponse {
// Get an auth token
let token = oauth_client
@@ -205,17 +224,15 @@ impl IntoResponse for AuthRedirect {
}
#[async_trait]
-impl<B> FromRequest<B> for User
+impl<B> FromRequest<AppState, B> for User
where
B: Send,
{
// If anything goes wrong or no session is found, redirect to the auth page
type Rejection = AuthRedirect;
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
- let Extension(store) = Extension::<MemoryStore>::from_request(req)
- .await
- .expect("`MemoryStore` extension is missing");
+ async fn from_request(req: &mut RequestParts<AppState, B>) -> Result<Self, Self::Rejection> {
+ let store = req.state().clone().store;
let cookies = TypedHeader::<headers::Cookie>::from_request(req)
.await
diff --git a/examples/reverse-proxy/src/main.rs b/examples/reverse-proxy/src/main.rs
index a9d2a5c7..af74ea12 100644
--- a/examples/reverse-proxy/src/main.rs
+++ b/examples/reverse-proxy/src/main.rs
@@ -8,7 +8,7 @@
//! ```
use axum::{
- extract::Extension,
+ extract::State,
http::{uri::Uri, Request, Response},
routing::get,
Router,
@@ -24,9 +24,7 @@ async fn main() {
let client = Client::new();
- let app = Router::new()
- .route("/", get(handler))
- .layer(Extension(client));
+ let app = Router::with_state(client).route("/", get(handler));
let addr = SocketAddr::from(([127, 0, 0, 1], 4000));
println!("reverse proxy listening on {}", addr);
@@ -36,12 +34,7 @@ async fn main() {
.unwrap();
}
-async fn handler(
- Extension(client): Extension<Client>,
- // NOTE: Make sure to put the request extractor last because once the request
- // is extracted, extensions can't be extracted anymore.
- mut req: Request<Body>,
-) -> Response<Body> {
+async fn handler(State(client): State<Client>, mut req: Request<Body>) -> Response<Body> {
let path = req.uri().path();
let path_query = req
.uri()
diff --git a/examples/routes-and-handlers-close-together/src/main.rs b/examples/routes-and-handlers-close-together/src/main.rs
index 5e52ad7b..41aaa49d 100644
--- a/examples/routes-and-handlers-close-together/src/main.rs
+++ b/examples/routes-and-handlers-close-together/src/main.rs
@@ -49,6 +49,6 @@ fn post_foo() -> Router {
route("/foo", post(handler))
}
-fn route(path: &str, method_router: MethodRouter) -> Router {
+fn route(path: &str, method_router: MethodRouter<()>) -> Router {
Router::new().route(path, method_router)
}
diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs
index 3251122c..cd0d41a1 100644
--- a/examples/sessions/src/main.rs
+++ b/examples/sessions/src/main.rs
@@ -7,7 +7,7 @@
use async_session::{MemoryStore, Session, SessionStore as _};
use axum::{
async_trait,
- extract::{Extension, FromRequest, RequestParts, TypedHeader},
+ extract::{FromRequest, RequestParts, TypedHeader},
headers::Cookie,
http::{
self,
@@ -38,9 +38,7 @@ async fn main() {
// `MemoryStore` just used as an example. Don't use this in production.
let store = MemoryStore::new();
- let app = Router::new()
- .route("/", get(handler))
- .layer(Extension(store));
+ let app = Router::with_state(store).route("/", get(handler));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
@@ -82,20 +80,16 @@ enum UserIdFromSession {
}
#[async_trait]
-impl<B> FromRequest<B> for UserIdFromSession
+impl<B> FromRequest<MemoryStore, B> for UserIdFromSession
where
B: Send,
{
type Rejection = (StatusCode, &'static str);
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
- let Extension(store) = Extension::<MemoryStore>::from_request(req)
- .await
- .expect("`MemoryStore` extension missing");
+ async fn from_request(req: &mut RequestParts<MemoryStore, B>) -> Result<Self, Self::Rejection> {
+ let store = req.state().clone();
- let cookie = Option::<TypedHeader<Cookie>>::from_request(req)
- .await
- .unwrap();
+ let cookie = req.extract::<Option<TypedHeader<Cookie>>>().await.unwrap();
let session_cookie = cookie
.as_ref()
diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs
index 9d101618..6548cdeb 100644
--- a/examples/sqlx-postgres/src/main.rs
+++ b/examples/sqlx-postgres/src/main.rs
@@ -15,7 +15,7 @@
use axum::{
async_trait,
- extract::{Extension, FromRequest, RequestParts},
+ extract::{FromRequest, RequestParts, State},
http::StatusCode,
routing::get,
Router,
@@ -46,12 +46,10 @@ async fn main() {
.expect("can connect to database");
// build our application with some routes
- let app = Router::new()
- .route(
- "/",
- get(using_connection_pool_extractor).post(using_connection_extractor),
- )
- .layer(Extension(pool));
+ let app = Router::with_state(pool).route(
+ "/",
+ get(using_connection_pool_extractor).post(using_connection_extractor),
+ );
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@@ -62,9 +60,9 @@ async fn main() {
.unwrap();
}
-// we can extract the connection pool with `Extension`
+// we can extract the connection pool with `State`
async fn using_connection_pool_extractor(
- Extension(pool): Extension<PgPool>,
+ State(pool): State<PgPool>,
) -> Result<String, (StatusCode, String)> {
sqlx::query_scalar("select 'hello world from pg'")
.fetch_one(&pool)
@@ -77,16 +75,14 @@ async fn using_connection_pool_extractor(
struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>);
#[async_trait]
-impl<B> FromRequest<B> for DatabaseConnection
+impl<B> FromRequest<PgPool, B> for DatabaseConnection
where
B: Send,
{
type Rejection = (StatusCode, String);
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
- let Extension(pool) = Extension::<PgPool>::from_request(req)
- .await
- .map_err(internal_error)?;
+ async fn from_request(req: &mut RequestParts<PgPool, B>) -> Result<Self, Self::Rejection> {
+ let pool = req.state().clone();
let conn = pool.acquire().await.map_err(internal_error)?;
diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs
index 4dbfcb46..66799711 100644
--- a/examples/sse/src/main.rs
+++ b/examples/sse/src/main.rs
@@ -41,7 +41,7 @@ async fn main() {
// build our application with a route
let app = Router::new()
- .fallback(static_files_service)
+ .fallback_service(static_files_service)
.route("/sse", get(sse_handler))
.layer(TraceLayer::new_for_http());
diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs
index 1862ecab..f7ac2bb9 100644
--- a/examples/static-file-server/src/main.rs
+++ b/examples/static-file-server/src/main.rs
@@ -34,7 +34,7 @@ async fn main() {
// as the fallback to a `Router`
let app: _ = Router::new()
.route("/foo", get(|| async { "Hi from /foo" }))
- .fallback(get_service(ServeDir::new(".")).handle_error(handle_error))
+ .fallback_service(get_service(ServeDir::new(".")).handle_error(handle_error))
.layer(TraceLayer::new_for_http());
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
diff --git a/examples/tls-rustls/src/main.rs b/examples/tls-rustls/src/main.rs
index 40008a96..6eec0e9c 100644
--- a/examples/tls-rustls/src/main.rs
+++ b/examples/tls-rustls/src/main.rs
@@ -6,7 +6,7 @@
use axum::{
extract::Host,
- handler::Handler,
+ handler::HandlerWithoutStateExt,
http::{StatusCode, Uri},
response::Redirect,
routing::get,
diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs
index 9a33416b..b82a308d 100644
--- a/examples/todos/src/main.rs
+++ b/examples/todos/src/main.rs
@@ -15,7 +15,7 @@
use axum::{
error_handling::HandleErrorLayer,
- extract::{Extension, Path, Query},
+ extract::{Path, Query, State},
http::StatusCode,
response::IntoResponse,
routing::{get, patch},
@@ -46,7 +46,7 @@ async fn main() {
let db = Db::default();
// Compose the routes
- let app = Router::new()
+ let app = Router::with_state(db)
.route("/todos", get(todos_index).post(todos_create))
.route("/todos/:id", patch(todos_update).delete(todos_delete))
// Add middleware to all routes
@@ -64,7 +64,6 @@ async fn main() {
}))
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
- .layer(Extension(db))
.into_inner(),
);
@@ -85,7 +84,7 @@ pub struct Pagination {
async fn todos_index(
pagination: Option<Query<Pagination>>,
- Extension(db): Extension<Db>,
+ State(db): State<Db>,
) -> impl IntoResponse {
let todos = db.read().unwrap();
@@ -106,10 +105,7 @@ struct CreateTodo {
text: String,
}
-async fn todos_create(
- Json(input): Json<CreateTodo>,
- Extension(db): Extension<Db>,
-) -> impl IntoResponse {
+async fn todos_create(Json(input): Json<CreateTodo>, State(db): State<Db>) -> impl IntoResponse {
let todo = Todo {
id: Uuid::new_v4(),
text: input.text,
@@ -130,7 +126,7 @@ struct UpdateTodo {
async fn todos_update(
Path(id): Path<Uuid>,
Json(input): Json<UpdateTodo>,
- Extension(db): Extension<Db>,
+ State(db): State<Db>,
) -> Result<impl IntoResponse, StatusCode> {
let mut todo = db
.read()
@@ -152,7 +148,7 @@ async fn todos_update(
Ok(Json(todo))
}
-async fn todos_delete(Path(id): Path<Uuid>, Extension(db): Extension<Db>) -> impl IntoResponse {
+async fn todos_delete(Path(id): Path<Uuid>, State(db): State<Db>) -> impl IntoResponse {
if db.write().unwrap().remove(&id).is_some() {
StatusCode::NO_CONTENT
} else {
diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs
index 66b03a8f..e0c60453 100644
--- a/examples/tokio-postgres/src/main.rs
+++ b/examples/tokio-postgres/src/main.rs
@@ -6,7 +6,7 @@
use axum::{
async_trait,
- extract::{Extension, FromRequest, RequestParts},
+ extract::{FromRequest, RequestParts, State},
http::StatusCode,
routing::get,
Router,
@@ -33,12 +33,10 @@ async fn main() {
let pool = Pool::builder().build(manager).await.unwrap();
// build our application with some routes
- let app = Router::new()
- .route(
- "/",
- get(using_connection_pool_extractor).post(using_connection_extractor),
- )
- .layer(Extension(pool));
+ let app = Router::with_state(pool).route(
+ "/",
+ get(using_connection_pool_extractor).post(using_connection_extractor),
+ );
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@@ -51,9 +49,8 @@ async fn main() {
type ConnectionPool = Pool<PostgresConnectionManager<NoTls>>;
-// we can extract the connection pool with `Extension`
async fn using_connection_pool_extractor(
- Extension(pool): Extension<ConnectionPool>,
+ State(pool): State<ConnectionPool>,
) -> Result<String, (StatusCode, String)> {
let conn = pool.get().await.map_err(internal_error)?;
@@ -71,16 +68,16 @@ async fn using_connection_pool_extractor(
struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager<NoTls>>);
#[async_trait]
-impl<B> FromRequest<B> for DatabaseConnection
+impl<B> FromRequest<ConnectionPool, B> for DatabaseConnection
where
B: Send,
{
type Rejection = (StatusCode, String);
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
- let Extension(pool) = Extension::<ConnectionPool>::from_request(req)
- .await
- .map_err(internal_error)?;
+ async fn from_request(
+ req: &mut RequestParts<ConnectionPool, B>,
+ ) -> Result<Self, Self::Rejection> {
+ let pool = req.state().clone();
let conn = pool.get_owned().await.map_err(internal_error)?;
diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs
index c8ce8c08..8682eb85 100644
--- a/examples/validator/src/main.rs
+++ b/examples/validator/src/main.rs
@@ -60,16 +60,17 @@ async fn handler(ValidatedForm(input): ValidatedForm<NameInput>) -> Html<String>
pub struct ValidatedForm<T>(pub T);
#[async_trait]
-impl<T, B> FromRequest<B> for ValidatedForm<T>
+impl<T, S, B> FromRequest<S, B> for ValidatedForm<T>
where
T: DeserializeOwned + Validate,
+ S: Send,
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,
{
type Rejection = ServerError;
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let Form(value) = Form::<T>::from_request(req).await?;
value.validate()?;
Ok(ValidatedForm(value))
diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs
index 48ade3c9..cf8e15f2 100644
--- a/examples/versioning/src/main.rs
+++ b/examples/versioning/src/main.rs
@@ -48,13 +48,14 @@ enum Version {
}
#[async_trait]
-impl<B> FromRequest<B> for Version
+impl<S, B> FromRequest<S, B> for Version
where
B: Send,
+ S: Send,
{
type Rejection = Response;
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let params = Path::<HashMap<String, String>>::from_request(req)
.await
.map_err(IntoResponse::into_response)?;
diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs
index bbdcaa08..a317bfa4 100644
--- a/examples/websockets/src/main.rs
+++ b/examples/websockets/src/main.rs
@@ -37,7 +37,7 @@ async fn main() {
// build our application with some routes
let app = Router::new()
- .fallback(
+ .fallback_service(
get_service(ServeDir::new(assets_dir).append_index_html_on_directories(true))
.handle_error(|error: std::io::Error| async move {
(