diff --git a/api/src/data.rs b/api/src/data.rs index f612239e..a78bb3c4 100644 --- a/api/src/data.rs +++ b/api/src/data.rs @@ -98,26 +98,112 @@ pub struct S2FormatHeader { #[cfg(feature = "axum")] pub mod extract { + use std::borrow::Cow; + use axum::{ - extract::{ - FromRequest, OptionalFromRequest, Request, - rejection::{BytesRejection, JsonRejection}, - }, + extract::{FromRequest, OptionalFromRequest, Request, rejection::BytesRejection}, response::{IntoResponse, Response}, }; use bytes::Bytes; use serde::de::DeserializeOwned; + /// Rejection type for JSON extraction, owned by s2-api. + #[derive(Debug)] + #[non_exhaustive] + pub enum JsonExtractionRejection { + SyntaxError { + status: http::StatusCode, + message: Cow<'static, str>, + }, + DataError { + status: http::StatusCode, + message: Cow<'static, str>, + }, + MissingContentType, + Other { + status: http::StatusCode, + message: Cow<'static, str>, + }, + } + + const MISSING_CONTENT_TYPE_MSG: &str = "Expected request with `Content-Type: application/json`"; + + impl JsonExtractionRejection { + pub fn body_text(&self) -> &str { + match self { + Self::SyntaxError { message, .. } + | Self::DataError { message, .. } + | Self::Other { message, .. } => message, + Self::MissingContentType => MISSING_CONTENT_TYPE_MSG, + } + } + + pub fn status(&self) -> http::StatusCode { + match self { + Self::SyntaxError { status, .. } + | Self::DataError { status, .. } + | Self::Other { status, .. } => *status, + Self::MissingContentType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE, + } + } + } + + impl std::fmt::Display for JsonExtractionRejection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.body_text()) + } + } + + impl std::error::Error for JsonExtractionRejection {} + + impl IntoResponse for JsonExtractionRejection { + fn into_response(self) -> Response { + let status = self.status(); + match self { + Self::SyntaxError { message, .. } + | Self::DataError { message, .. } + | Self::Other { message, .. } => match message { + Cow::Borrowed(s) => (status, s).into_response(), + Cow::Owned(s) => (status, s).into_response(), + }, + Self::MissingContentType => (status, MISSING_CONTENT_TYPE_MSG).into_response(), + } + } + } + + // TODO: remove when we stop delegating to axum::Json. + impl From for JsonExtractionRejection { + fn from(rej: axum::extract::rejection::JsonRejection) -> Self { + use axum::extract::rejection::JsonRejection::*; + match rej { + JsonDataError(e) => Self::DataError { + status: e.status(), + message: e.body_text().into(), + }, + JsonSyntaxError(e) => Self::SyntaxError { + status: e.status(), + message: e.body_text().into(), + }, + MissingJsonContentType(_) => Self::MissingContentType, + other => Self::Other { + status: other.status(), + message: other.body_text().into(), + }, + } + } + } + impl FromRequest for super::Json where S: Send + Sync, T: DeserializeOwned, { - type Rejection = JsonRejection; + type Rejection = JsonExtractionRejection; async fn from_request(req: Request, state: &S) -> Result { - let axum::Json(value) = - as FromRequest>::from_request(req, state).await?; + let axum::Json(value) = as FromRequest>::from_request(req, state) + .await + .map_err(JsonExtractionRejection::from)?; Ok(Self(value)) } } @@ -127,7 +213,7 @@ pub mod extract { S: Send + Sync, T: DeserializeOwned, { - type Rejection = JsonRejection; + type Rejection = JsonExtractionRejection; async fn from_request(req: Request, state: &S) -> Result, Self::Rejection> { let Some(ctype) = req.headers().get(http::header::CONTENT_TYPE) else { @@ -137,15 +223,20 @@ pub mod extract { .as_ref() .is_some_and(crate::mime::is_json) { - Err(JsonRejection::MissingJsonContentType(Default::default()))?; + return Err(JsonExtractionRejection::MissingContentType); } - let bytes = Bytes::from_request(req, state) - .await - .map_err(JsonRejection::BytesRejection)?; + let bytes = Bytes::from_request(req, state).await.map_err(|e| { + JsonExtractionRejection::Other { + status: e.status(), + message: e.body_text().into(), + } + })?; if bytes.is_empty() { return Ok(None); } - let value = axum::Json::::from_bytes(&bytes)?.0; + let value = axum::Json::::from_bytes(&bytes) + .map_err(JsonExtractionRejection::from)? + .0; Ok(Some(Self(value))) } } @@ -159,7 +250,7 @@ pub mod extract { S: Send + Sync, T: DeserializeOwned, { - type Rejection = JsonRejection; + type Rejection = JsonExtractionRejection; async fn from_request(req: Request, state: &S) -> Result { match as OptionalFromRequest>::from_request(req, state).await { @@ -203,4 +294,66 @@ pub mod extract { Ok(super::Proto(T::decode(bytes)?)) } } + + #[cfg(test)] + mod tests { + use super::*; + use crate::v1::stream::AppendInput; + + fn classify_json_error( + json: &[u8], + ) -> Result { + axum::Json::::from_bytes(json) + .map(|axum::Json(v)| v) + .map_err(JsonExtractionRejection::from) + } + + /// Verify that our rejection wrapper preserves axum's status code + /// classification for a variety of invalid JSON payloads. This same + /// table will be reused when switching to sonic-rs in PR 2. + #[test] + fn json_error_classification() { + let cases: &[(&[u8], http::StatusCode)] = &[ + // Syntax errors → 400 + (b"not json", http::StatusCode::BAD_REQUEST), + // `{}` is valid JSON but missing `records` — axum reports data error + // before checking trailing chars. + (b"{} trailing", http::StatusCode::UNPROCESSABLE_ENTITY), + (b"", http::StatusCode::BAD_REQUEST), + (b"{truncated", http::StatusCode::BAD_REQUEST), + // Data errors → 422 + (b"{}", http::StatusCode::UNPROCESSABLE_ENTITY), + ( + br#"{"records": "nope"}"#, + http::StatusCode::UNPROCESSABLE_ENTITY, + ), + ( + br#"{"records": [{"body": 123}]}"#, + http::StatusCode::UNPROCESSABLE_ENTITY, + ), + ]; + + for (input, expected_status) in cases { + let err = classify_json_error::(input).expect_err(&format!( + "expected error for {:?}", + String::from_utf8_lossy(input) + )); + assert_eq!( + err.status(), + *expected_status, + "wrong status for {:?}: got {}, body: {}", + String::from_utf8_lossy(input), + err.status(), + err.body_text(), + ); + } + } + + #[test] + fn valid_json_parses_successfully() { + let input = br#"{"records": [], "match_seq_num": null}"#; + let result = classify_json_error::(input); + assert!(result.is_ok()); + } + } } diff --git a/api/src/v1/stream/extract.rs b/api/src/v1/stream/extract.rs index 7b9ce684..d531c463 100644 --- a/api/src/v1/stream/extract.rs +++ b/api/src/v1/stream/extract.rs @@ -1,6 +1,5 @@ use axum::{ - Json, - extract::{FromRequest, FromRequestParts, Request, rejection::JsonRejection}, + extract::{FromRequest, FromRequestParts, Request}, response::{IntoResponse, Response}, }; use futures::StreamExt as _; @@ -13,7 +12,10 @@ use tokio_util::{codec::FramedRead, io::StreamReader}; use super::{AppendInput, AppendInputStreamError, AppendRequest, ReadRequest, proto, s2s}; use crate::{ - data::{Format, Proto, extract::ProtoRejection}, + data::{ + Format, Json, Proto, + extract::{JsonExtractionRejection, ProtoRejection}, + }, mime::JsonOrProto, v1::stream::sse::LastEventId, }; @@ -23,7 +25,7 @@ pub enum AppendRequestRejection { #[error(transparent)] HeaderRejection(#[from] HeaderRejection), #[error(transparent)] - JsonRejection(#[from] JsonRejection), + JsonRejection(#[from] JsonExtractionRejection), #[error(transparent)] ProtoRejection(#[from] ProtoRejection), #[error(transparent)] diff --git a/lite/src/handlers/v1/error.rs b/lite/src/handlers/v1/error.rs index d983d546..c845b053 100644 --- a/lite/src/handlers/v1/error.rs +++ b/lite/src/handlers/v1/error.rs @@ -1,9 +1,9 @@ use axum::{ - extract::rejection::{JsonRejection, PathRejection, QueryRejection}, + extract::rejection::{PathRejection, QueryRejection}, response::{IntoResponse, Response}, }; use s2_api::{ - data::extract::ProtoRejection, + data::extract::{JsonExtractionRejection, ProtoRejection}, v1::{ self as v1t, error::{ErrorCode, ErrorInfo, ErrorResponse, StandardError}, @@ -27,7 +27,7 @@ pub enum ServiceError { #[error(transparent)] QueryRejection(#[from] QueryRejection), #[error(transparent)] - JsonRejection(#[from] JsonRejection), + JsonRejection(#[from] JsonExtractionRejection), #[error(transparent)] ProtoRejection(#[from] ProtoRejection), #[error(transparent)] diff --git a/sdk/src/api.rs b/sdk/src/api.rs index 59b054b4..4d409424 100644 --- a/sdk/src/api.rs +++ b/sdk/src/api.rs @@ -1247,6 +1247,7 @@ mod tests { #[tokio::test] async fn dns_error_message_is_clear() { + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let config = crate::types::S2Config::new("test-token".to_owned()) .with_endpoints( crate::types::S2Endpoints::new(