summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfuture-highway <113635015+future-highway@users.noreply.github.com>2023-12-29 06:06:47 -0500
committerGitHub <noreply@github.com>2023-12-29 12:06:47 +0100
commit56159b0d4e0db46b45d471bc291ab270c1e1dd77 (patch)
treedc43830767cd38a64eeb9939ac934de5eadb27fb
parentc3db223532f1f1006d2e0b0c576d627d6e95cdcb (diff)
JsonDeserializer extractor for zero-copy deserialization (#2431)
-rw-r--r--axum-extra/CHANGELOG.md2
-rw-r--r--axum-extra/Cargo.toml2
-rw-r--r--axum-extra/src/extract/json_deserializer.rs446
-rw-r--r--axum-extra/src/extract/mod.rs9
-rw-r--r--axum-extra/src/lib.rs1
-rw-r--r--axum/src/json.rs4
6 files changed, 462 insertions, 2 deletions
diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md
index 7120f6bc..c2d0a1f6 100644
--- a/axum-extra/CHANGELOG.md
+++ b/axum-extra/CHANGELOG.md
@@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning].
# Unreleased
- **change:** Update version of multer used internally for multipart ([#2433])
+- **added:** `JsonDeserializer` extractor ([#2431])
[#2433]: https://github.com/tokio-rs/axum/pull/2433
+[#2431]: https://github.com/tokio-rs/axum/pull/2431
# 0.9.0 (27. November, 2023)
diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml
index 04929dad..d2fa8993 100644
--- a/axum-extra/Cargo.toml
+++ b/axum-extra/Cargo.toml
@@ -21,6 +21,7 @@ cookie-signed = ["cookie", "cookie?/signed"]
cookie-key-expansion = ["cookie", "cookie?/key-expansion"]
erased-json = ["dep:serde_json"]
form = ["dep:serde_html_form"]
+json-deserializer = ["dep:serde_json", "dep:serde_path_to_error"]
json-lines = [
"dep:serde_json",
"dep:tokio-util",
@@ -60,6 +61,7 @@ percent-encoding = { version = "2.1", optional = true }
prost = { version = "0.12", optional = true }
serde_html_form = { version = "0.2.0", optional = true }
serde_json = { version = "1.0.71", optional = true }
+serde_path_to_error = { version = "0.1.8", optional = true }
tokio = { version = "1.19", optional = true }
tokio-stream = { version = "0.1.9", optional = true }
tokio-util = { version = "0.7", optional = true }
diff --git a/axum-extra/src/extract/json_deserializer.rs b/axum-extra/src/extract/json_deserializer.rs
new file mode 100644
index 00000000..0a307987
--- /dev/null
+++ b/axum-extra/src/extract/json_deserializer.rs
@@ -0,0 +1,446 @@
+use axum::async_trait;
+use axum::extract::{FromRequest, Request};
+use axum_core::__composite_rejection as composite_rejection;
+use axum_core::__define_rejection as define_rejection;
+use axum_core::extract::rejection::BytesRejection;
+use bytes::Bytes;
+use http::{header, HeaderMap};
+use serde::Deserialize;
+use std::marker::PhantomData;
+
+/// JSON Extractor for zero-copy deserialization.
+///
+/// Deserialize request bodies into some type that implements [`serde::Deserialize<'de>`].
+/// Parsing JSON is delayed until [`deserialize`](JsonDeserializer::deserialize) is called.
+/// If the type implements [`serde::de::DeserializeOwned`], the [`Json`](axum::Json) extractor should
+/// be preferred.
+///
+/// The request will be rejected (and a [`JsonDeserializerRejection`] will be returned) if:
+///
+/// - The request doesn't have a `Content-Type: application/json` (or similar) header.
+/// - Buffering the request body fails.
+///
+/// Additionally, a [`JsonDeserializerRejection`] will be returned when calling `deserialize` if:
+///
+/// - The body doesn't contain syntactically valid JSON.
+/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target
+/// type.
+/// - Attempting to deserialize escaped JSON into a type that must be borrowed (e.g. `&'a str`).
+///
+/// ⚠️ `serde` will implicitly try to borrow for `&str` and `&[u8]` types, but will error if the
+/// input contains escaped characters. Use `Cow<'a, str>` or `Cow<'a, [u8]>`, with the
+/// `#[serde(borrow)]` attribute, to allow serde to fall back to an owned type when encountering
+/// escaped characters.
+///
+/// ⚠️ Since buffering the request body consumes it, the `JsonDeserializer` extractor must be
+/// *last* if there are multiple extractors in a handler.
+/// See ["the order of extractors"][order-of-extractors]
+///
+/// [order-of-extractors]: axum::extract#the-order-of-extractors
+///
+/// See [`JsonDeserializerRejection`] for more details.
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use axum::{
+/// routing::post,
+/// Router,
+/// response::{IntoResponse, Response}
+/// };
+/// use axum_extra::extract::JsonDeserializer;
+/// use serde::Deserialize;
+/// use std::borrow::Cow;
+/// use http::StatusCode;
+///
+/// #[derive(Deserialize)]
+/// struct Data<'a> {
+/// #[serde(borrow)]
+/// borrow_text: Cow<'a, str>,
+/// #[serde(borrow)]
+/// borrow_bytes: Cow<'a, [u8]>,
+/// borrow_dangerous: &'a str,
+/// not_borrowed: String,
+/// }
+///
+/// async fn upload(deserializer: JsonDeserializer<Data<'_>>) -> Response {
+/// let data = match deserializer.deserialize() {
+/// Ok(data) => data,
+/// Err(e) => return e.into_response(),
+/// };
+///
+/// // `data` is a `Data` with borrowed data from `deserializer`,
+/// // which owns the request body (`Bytes`).
+///
+/// StatusCode::OK.into_response()
+/// }
+///
+/// let app = Router::new().route("/upload", post(upload));
+/// # let _: Router = app;
+/// ```
+#[derive(Debug, Clone, Default)]
+#[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
+pub struct JsonDeserializer<T> {
+ // Buffered request body. `deserialize` parses directly from this buffer, so the
+ // values it returns may borrow from it.
+ bytes: Bytes,
+ // Zero-sized marker recording the target type `T`; no `T` value is stored here.
+ _marker: PhantomData<T>,
+}
+
+// NOTE(review): `deserialize` is generic over any `'de`; confirm whether the `'static`
+// bound below is intentional or could be relaxed.
+#[async_trait]
+impl<T, S> FromRequest<S> for JsonDeserializer<T>
+where
+    T: Deserialize<'static>,
+    S: Send + Sync,
+{
+    type Rejection = JsonDeserializerRejection;
+
+    /// Buffers the request body for later deserialization.
+    ///
+    /// The `Content-Type` header is checked first: a request without a JSON content type is
+    /// rejected with [`MissingJsonContentType`] before the body is read. No JSON parsing
+    /// happens here; that is deferred to [`JsonDeserializer::deserialize`].
+    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
+        if !json_content_type(req.headers()) {
+            return Err(MissingJsonContentType.into());
+        }
+
+        let body = Bytes::from_request(req, state).await?;
+
+        Ok(Self {
+            bytes: body,
+            _marker: PhantomData,
+        })
+    }
+}
+
+impl<'de, 'a: 'de, T> JsonDeserializer<T>
+where
+    T: Deserialize<'de>,
+{
+    /// Deserialize the request body into the target type.
+    ///
+    /// The returned value may borrow from the buffered body, which is why `self` is
+    /// borrowed for `'a` with `'a: 'de`. See [`JsonDeserializer`] for more details.
+    ///
+    /// # Errors
+    ///
+    /// - [`JsonSyntaxError`] if the body is not syntactically valid JSON, ends early, or
+    ///   contains anything other than whitespace after the JSON value.
+    /// - [`JsonDataError`] if the JSON is valid but cannot be deserialized into `T`.
+    pub fn deserialize(&'a self) -> Result<T, JsonDeserializerRejection> {
+        use serde_json::error::Category;
+
+        let mut deserializer = serde_json::Deserializer::from_slice(&self.bytes);
+
+        let value = serde_path_to_error::deserialize(&mut deserializer).map_err(|err| {
+            match err.inner().classify() {
+                Category::Data => JsonDeserializerRejection::from(JsonDataError::from_err(err)),
+                Category::Syntax | Category::Eof => JsonSyntaxError::from_err(err).into(),
+                Category::Io => {
+                    if cfg!(debug_assertions) {
+                        // we don't use `serde_json::from_reader` and instead always buffer
+                        // bodies first, so we shouldn't encounter any IO errors
+                        unreachable!()
+                    } else {
+                        JsonSyntaxError::from_err(err).into()
+                    }
+                }
+            }
+        })?;
+
+        // Deserializing a single value does not look at what follows it. Without this check
+        // a body such as `{"a": 1} garbage` would be accepted as valid.
+        deserializer
+            .end()
+            .map_err(|err| JsonDeserializerRejection::from(JsonSyntaxError::from_err(err)))?;
+
+        Ok(value)
+    }
+}
+
+define_rejection! {
+ #[status = UNPROCESSABLE_ENTITY]
+ #[body = "Failed to deserialize the JSON body into the target type"]
+ #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
+ /// Rejection type for [`JsonDeserializer`].
+ ///
+ /// This rejection is used if the request body is syntactically valid JSON but couldn't be
+ /// deserialized into the target type. It is produced by
+ /// [`JsonDeserializer::deserialize`], not while extracting the request.
+ pub struct JsonDataError(Error);
+}
+
+define_rejection! {
+ #[status = BAD_REQUEST]
+ #[body = "Failed to parse the request body as JSON"]
+ #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
+ /// Rejection type for [`JsonDeserializer`].
+ ///
+ /// This rejection is used if the request body didn't contain syntactically valid JSON.
+ /// It is produced by [`JsonDeserializer::deserialize`], not while extracting the request.
+ pub struct JsonSyntaxError(Error);
+}
+
+define_rejection! {
+ #[status = UNSUPPORTED_MEDIA_TYPE]
+ #[body = "Expected request with `Content-Type: application/json`"]
+ #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
+ /// Rejection type for [`JsonDeserializer`] used if the `Content-Type`
+ /// header is missing, cannot be parsed as a media type, or does not name a
+ /// JSON media type (`application/json` or `application/*+json`).
+ pub struct MissingJsonContentType;
+}
+
+composite_rejection! {
+ /// Rejection used for [`JsonDeserializer`].
+ ///
+ /// Contains one variant for each way the [`JsonDeserializer`] extractor
+ /// can fail.
+ ///
+ /// `MissingJsonContentType` and `BytesRejection` are produced while extracting the
+ /// request; `JsonDataError` and `JsonSyntaxError` are produced by
+ /// [`JsonDeserializer::deserialize`].
+ #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
+ pub enum JsonDeserializerRejection {
+ JsonDataError,
+ JsonSyntaxError,
+ MissingJsonContentType,
+ BytesRejection,
+ }
+}
+
+/// Reports whether the `Content-Type` header names a JSON media type.
+///
+/// Returns `true` only for `application/json` and `application/*+json` (any parameters such
+/// as `charset` are ignored). A missing header, a header value that is not valid visible
+/// ASCII, or a value that does not parse as a media type all yield `false`.
+fn json_content_type(headers: &HeaderMap) -> bool {
+    headers
+        .get(header::CONTENT_TYPE)
+        .and_then(|raw| raw.to_str().ok())
+        .and_then(|text| text.parse::<mime::Mime>().ok())
+        .map_or(false, |media_type| {
+            let has_json_subtype = media_type.subtype() == "json";
+            let has_json_suffix = media_type.suffix().map_or(false, |name| name == "json");
+
+            media_type.type_() == "application" && (has_json_subtype || has_json_suffix)
+        })
+}
+
+#[cfg(test)]
+mod tests {
+ // Each test mounts a handler on `POST /`, sends one or more requests through
+ // `TestClient`, and checks the status code and/or response body.
+ use super::*;
+ use crate::test_helpers::*;
+ use axum::{
+ response::{IntoResponse, Response},
+ routing::post,
+ Router,
+ };
+ use http::StatusCode;
+ use serde::Deserialize;
+ use serde_json::{json, Value};
+ use std::borrow::Cow;
+
+ // An unescaped string is borrowed from the body (`Cow::Borrowed`) and echoed back.
+ #[tokio::test]
+ async fn deserialize_body() {
+ #[derive(Debug, Deserialize)]
+ struct Input<'a> {
+ #[serde(borrow)]
+ foo: Cow<'a, str>,
+ }
+
+ async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
+ match deserializer.deserialize() {
+ Ok(input) => {
+ assert!(matches!(input.foo, Cow::Borrowed(_)));
+ input.foo.into_owned().into_response()
+ }
+ Err(e) => e.into_response(),
+ }
+ }
+
+ let app = Router::new().route("/", post(handler));
+
+ let client = TestClient::new(app);
+ let res = client.post("/").json(&json!({ "foo": "bar" })).send().await;
+ let body = res.text().await;
+
+ assert_eq!(body, "bar");
+ }
+
+ // A string containing escapes cannot be borrowed, so a `Cow` field must come back as
+ // `Cow::Owned` with the escapes resolved.
+ #[tokio::test]
+ async fn deserialize_body_escaped_to_cow() {
+ #[derive(Debug, Deserialize)]
+ struct Input<'a> {
+ #[serde(borrow)]
+ foo: Cow<'a, str>,
+ }
+
+ async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
+ match deserializer.deserialize() {
+ Ok(Input { foo }) => {
+ let Cow::Owned(foo) = foo else {
+ panic!("Deserializer is expected to fallback to Cow::Owned when encountering escaped characters")
+ };
+
+ foo.into_response()
+ }
+ Err(e) => e.into_response(),
+ }
+ }
+
+ let app = Router::new().route("/", post(handler));
+
+ let client = TestClient::new(app);
+
+ // The escaped characters prevent serde_json from borrowing.
+ let res = client
+ .post("/")
+ .json(&json!({ "foo": "\"bar\"" }))
+ .send()
+ .await;
+
+ let body = res.text().await;
+
+ assert_eq!(body, r#""bar""#);
+ }
+
+ // A `&str` field works for unescaped input, but escaped input is rejected with
+ // 422 and a message that includes the path of the failing field.
+ #[tokio::test]
+ async fn deserialize_body_escaped_to_str() {
+ #[derive(Debug, Deserialize)]
+ struct Input<'a> {
+ // Explicit `#[serde(borrow)]` attribute is not required for `&str` or `&[u8]`.
+ // See: https://serde.rs/lifetimes.html#borrowing-data-in-a-derived-impl
+ foo: &'a str,
+ }
+
+ async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
+ match deserializer.deserialize() {
+ Ok(Input { foo }) => foo.to_owned().into_response(),
+ Err(e) => e.into_response(),
+ }
+ }
+
+ let app = Router::new().route("/", post(handler));
+
+ let client = TestClient::new(app);
+
+ let res = client
+ .post("/")
+ .json(&json!({ "foo": "good" }))
+ .send()
+ .await;
+ let body = res.text().await;
+ assert_eq!(body, "good");
+
+ let res = client
+ .post("/")
+ .json(&json!({ "foo": "\"bad\"" }))
+ .send()
+ .await;
+ assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
+ let body_text = res.text().await;
+ assert_eq!(
+ body_text,
+ "Failed to deserialize the JSON body into the target type: foo: invalid type: string \"\\\"bad\\\"\", expected a borrowed string at line 1 column 16"
+ );
+ }
+
+ // The test sets no content-type header on the request; extraction must fail with
+ // 415 before the handler runs.
+ #[tokio::test]
+ async fn consume_body_to_json_requires_json_content_type() {
+ #[derive(Debug, Deserialize)]
+ struct Input<'a> {
+ #[allow(dead_code)]
+ foo: Cow<'a, str>,
+ }
+
+ async fn handler(_deserializer: JsonDeserializer<Input<'_>>) -> Response {
+ panic!("This handler should not be called")
+ }
+
+ let app = Router::new().route("/", post(handler));
+
+ let client = TestClient::new(app);
+ let res = client.post("/").body(r#"{ "foo": "bar" }"#).send().await;
+
+ let status = res.status();
+
+ assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
+ }
+
+ // Checks which content-type values pass extraction. Only extraction is exercised:
+ // the handler never calls `deserialize`.
+ #[tokio::test]
+ async fn json_content_types() {
+ async fn valid_json_content_type(content_type: &str) -> bool {
+ println!("testing {content_type:?}");
+
+ async fn handler(_deserializer: JsonDeserializer<Value>) -> Response {
+ StatusCode::OK.into_response()
+ }
+
+ let app = Router::new().route("/", post(handler));
+
+ let res = TestClient::new(app)
+ .post("/")
+ .header("content-type", content_type)
+ .body("{}")
+ .send()
+ .await;
+
+ res.status() == StatusCode::OK
+ }
+
+ assert!(valid_json_content_type("application/json").await);
+ assert!(valid_json_content_type("application/json; charset=utf-8").await);
+ assert!(valid_json_content_type("application/json;charset=utf-8").await);
+ assert!(valid_json_content_type("application/cloudevents+json").await);
+ assert!(!valid_json_content_type("text/json").await);
+ }
+
+ // A truncated document (`{`) is reported as a syntax error with 400.
+ #[tokio::test]
+ async fn invalid_json_syntax() {
+ async fn handler(deserializer: JsonDeserializer<Value>) -> Response {
+ match deserializer.deserialize() {
+ Ok(_) => panic!("Should have matched `Err`"),
+ Err(e) => e.into_response(),
+ }
+ }
+
+ let app = Router::new().route("/", post(handler));
+
+ let client = TestClient::new(app);
+ let res = client
+ .post("/")
+ .body("{")
+ .header("content-type", "application/json")
+ .send()
+ .await;
+
+ assert_eq!(res.status(), StatusCode::BAD_REQUEST);
+ }
+
+ // Nested fixture types used by `invalid_json_data` to check the error path (`b[0]`).
+ #[derive(Deserialize)]
+ struct Foo {
+ #[allow(dead_code)]
+ a: i32,
+ #[allow(dead_code)]
+ b: Vec<Bar>,
+ }
+
+ #[derive(Deserialize)]
+ struct Bar {
+ #[allow(dead_code)]
+ x: i32,
+ #[allow(dead_code)]
+ y: i32,
+ }
+
+ // Valid JSON that is missing a required nested field yields 422, and the message
+ // names the path to the offending element.
+ #[tokio::test]
+ async fn invalid_json_data() {
+ async fn handler(deserializer: JsonDeserializer<Foo>) -> Response {
+ match deserializer.deserialize() {
+ Ok(_) => panic!("Should have matched `Err`"),
+ Err(e) => e.into_response(),
+ }
+ }
+
+ let app = Router::new().route("/", post(handler));
+
+ let client = TestClient::new(app);
+ let res = client
+ .post("/")
+ .body("{\"a\": 1, \"b\": [{\"x\": 2}]}")
+ .header("content-type", "application/json")
+ .send()
+ .await;
+
+ assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
+ let body_text = res.text().await;
+ assert_eq!(
+ body_text,
+ "Failed to deserialize the JSON body into the target type: b[0]: missing field `y` at line 1 column 23"
+ );
+ }
+}
diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs
index 8435fc84..1f9974de 100644
--- a/axum-extra/src/extract/mod.rs
+++ b/axum-extra/src/extract/mod.rs
@@ -10,6 +10,9 @@ mod form;
#[cfg(feature = "cookie")]
pub mod cookie;
+#[cfg(feature = "json-deserializer")]
+mod json_deserializer;
+
#[cfg(feature = "query")]
mod query;
@@ -36,6 +39,12 @@ pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejecti
#[cfg(feature = "multipart")]
pub use self::multipart::Multipart;
+#[cfg(feature = "json-deserializer")]
+pub use self::json_deserializer::{
+ JsonDataError, JsonDeserializer, JsonDeserializerRejection, JsonSyntaxError,
+ MissingJsonContentType,
+};
+
#[cfg(feature = "json-lines")]
#[doc(no_inline)]
pub use crate::json_lines::JsonLines;
diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs
index 12aa2801..eb93b0a3 100644
--- a/axum-extra/src/lib.rs
+++ b/axum-extra/src/lib.rs
@@ -16,6 +16,7 @@
//! `cookie-key-expansion` | Enables the `Key::derive_from` method | No
//! `erased-json` | Enables the `ErasedJson` response | No
//! `form` | Enables the `Form` extractor | No
+//! `json-deserializer` | Enables the `JsonDeserializer` extractor | No
//! `json-lines` | Enables the `JsonLines` extractor and response | No
//! `multipart` | Enables the `Multipart` extractor | No
//! `protobuf` | Enables the `Protobuf` extractor and response | No
diff --git a/axum/src/json.rs b/axum/src/json.rs
index ebff242d..e96be5b8 100644
--- a/axum/src/json.rs
+++ b/axum/src/json.rs
@@ -12,12 +12,12 @@ use serde::{de::DeserializeOwned, Serialize};
/// JSON Extractor / Response.
///
/// When used as an extractor, it can deserialize request bodies into some type that
-/// implements [`serde::Deserialize`]. The request will be rejected (and a [`JsonRejection`] will
+/// implements [`serde::de::DeserializeOwned`]. The request will be rejected (and a [`JsonRejection`] will
/// be returned) if:
///
/// - The request doesn't have a `Content-Type: application/json` (or similar) header.
/// - The body doesn't contain syntactically valid JSON.
-/// - The body contains syntactically valid JSON but it couldn't be deserialized into the target
+/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target
/// type.
/// - Buffering the request body fails.
///