Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/aide/src/axum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,11 @@ where
let _ = transform(TransformOpenApi::new(api));

let needs_reset = in_context(|ctx| {
// Strip null types from query parameters if enabled
if ctx.strip_query_null_types {
crate::transform::strip_null_from_query_params_impl(api);
}

if !ctx.extract_schemas {
return false;
}
Expand Down
17 changes: 17 additions & 0 deletions crates/aide/src/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,19 @@ pub fn all_error_responses(infer: bool) {
});
}

/// Automatically strip null types from query parameter schemas.
///
/// Query strings cannot express null values - a parameter is either
/// present with a value or absent. When enabled, null types are
/// automatically removed from query parameter schemas during finalization.
///
/// This is enabled by default.
pub fn strip_query_null_types(strip: bool) {
in_context(|ctx| {
ctx.strip_query_null_types = strip;
});
}

/// Reset the state of the thread-local context.
///
/// Currently clears:
Expand All @@ -110,6 +123,9 @@ pub struct GenContext {
/// Extract schemas.
pub(crate) extract_schemas: bool,

/// Strip null types from query parameter schemas.
pub(crate) strip_query_null_types: bool,

/// Status code for no content.
pub(crate) no_content_status: u16,

Expand All @@ -135,6 +151,7 @@ impl GenContext {
infer_responses: true,
all_error_responses: false,
extract_schemas: true,
strip_query_null_types: true,
show_error: default_error_filter,
error_handler: None,
no_content_status,
Expand Down
244 changes: 242 additions & 2 deletions crates/aide/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ use std::{any::type_name, marker::PhantomData};
use crate::{
generate::GenContext,
openapi::{
Components, Contact, Info, License, OpenApi, Operation, Parameter, PathItem, ReferenceOr,
Response, SecurityScheme, Server, StatusCode, Tag,
Components, Contact, Info, License, OpenApi, Operation, Parameter,
ParameterSchemaOrContent, PathItem, ReferenceOr, Response, SecurityScheme, Server,
StatusCode, Tag,
},
OperationInput,
};
use indexmap::IndexMap;
use serde::Serialize;
use serde_json::Value;

use crate::{
error::Error, generate::in_context, operation::OperationOutput, util::iter_operations_mut,
Expand Down Expand Up @@ -194,6 +196,23 @@ impl<'t> TransformOpenApi<'t> {
self
}

/// Strip null types from query parameter schemas.
///
/// Query strings cannot express null values - a parameter is either
/// present with a value or absent. This method removes `null` from
/// type arrays (e.g., `["string", "null"]` becomes `"string"`) and
/// unwraps `anyOf` variants containing null types.
///
/// Note: This is called automatically when the transform is finalized
/// (unless disabled via [`aide::generate::strip_query_null_types`]).
/// You only need to call this explicitly if you want to apply it
/// at a specific point in your transform chain.
#[tracing::instrument(skip_all)]
pub fn strip_null_from_query_params(self) -> Self {
strip_null_from_query_params_impl(self.api);
self
}

/// Add a security scheme.
#[allow(clippy::missing_panics_doc)]
pub fn security_scheme(mut self, name: &str, scheme: SecurityScheme) -> Self {
Expand Down Expand Up @@ -1320,3 +1339,224 @@ impl<'t> TransformCallback<'t> {
fn filter_no_duplicate_response(err: &Error) -> bool {
!matches!(err, Error::DefaultResponseExists | Error::ResponseExists(_))
}

pub(crate) fn strip_null_from_query_params_impl(api: &mut OpenApi) {
let Some(paths) = &mut api.paths else { return };

for (_, path_item) in &mut paths.paths {
let ReferenceOr::Item(path_item) = path_item else {
continue;
};

for (_, op) in iter_operations_mut(path_item) {
for param in &mut op.parameters {
let ReferenceOr::Item(Parameter::Query { parameter_data, .. }) = param else {
continue;
};

let ParameterSchemaOrContent::Schema(schema_obj) = &mut parameter_data.format
else {
continue;
};

strip_null_from_type(&mut schema_obj.json_schema);
}
}
}
}

fn strip_null_from_type(schema: &mut schemars::Schema) {
// Handle type: ["string", "null"] -> type: "string"
if let Some(Value::Array(types)) = schema.get_mut("type") {
let null_count = types.iter().filter(|t| *t == "null").count();
if null_count == 0 || null_count == types.len() {
return; // No nulls, or all nulls - don't modify
}
types.retain(|t| t != "null");
if types.len() == 1 {
*schema.get_mut("type").unwrap() = types.remove(0);
}
return;
}

// Handle anyOf: [..., {type: "null"}] -> remove null variants
let Some(Value::Array(items)) = schema.get_mut("anyOf") else {
return;
};

let is_null = |v: &Value| matches!(v.get("type"), Some(Value::String(s)) if s == "null");
let null_count = items.iter().filter(|item| is_null(item)).count();
if null_count == 0 || null_count == items.len() {
return; // No nulls, or all nulls - don't modify
}

items.retain(|item| !is_null(item));

// Single item remains: unwrap the anyOf
if items.len() == 1 {
if let Some(Value::Object(obj)) = items.pop() {
if let Some(schema_obj) = schema.as_object_mut() {
schema_obj.remove("anyOf");
for (key, value) in obj {
schema_obj.insert(key, value);
}
}
}
}
}

#[cfg(test)]
mod tests {
use crate::openapi::{
MediaType, OpenApi, Operation, Parameter, ParameterData, ParameterSchemaOrContent, Paths,
ReferenceOr, RequestBody, SchemaObject,
};
use indexmap::IndexMap;
use schemars::JsonSchema;
use serde_json::json;

use super::TransformOpenApi;

fn inline_schema_for<T: JsonSchema>() -> schemars::Schema {
let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
s.inline_subschemas = true;
});
let mut gen = settings.into_generator();
gen.subschema_for::<T>()
}

fn property_schema(struct_schema: &schemars::Schema, name: &str) -> schemars::Schema {
struct_schema
.get("properties")
.and_then(|p| p.get(name))
.unwrap_or_else(|| panic!("property {name:?} not found"))
.clone()
.try_into()
.unwrap()
}

fn build_api(params: Vec<Parameter>, body_schema: Option<schemars::Schema>) -> OpenApi {
let mut op = Operation::default();
op.parameters = params.into_iter().map(ReferenceOr::Item).collect();
if let Some(schema) = body_schema {
op.request_body = Some(ReferenceOr::Item(RequestBody {
content: IndexMap::from_iter([(
"application/json".into(),
MediaType {
schema: Some(SchemaObject {
json_schema: schema,
external_docs: None,
example: None,
}),
..Default::default()
},
)]),
..Default::default()
}));
}

let mut path_item = crate::openapi::PathItem::default();
path_item.get = Some(op);

OpenApi {
paths: Some(Paths {
paths: IndexMap::from([("/test".to_string(), ReferenceOr::Item(path_item))]),
extensions: IndexMap::new(),
}),
..OpenApi::default()
}
}

fn query_param(name: &str, schema: schemars::Schema) -> Parameter {
Parameter::Query {
parameter_data: ParameterData {
name: name.to_string(),
description: None,
required: false,
deprecated: None,
format: ParameterSchemaOrContent::Schema(SchemaObject {
json_schema: schema,
external_docs: None,
example: None,
}),
example: None,
examples: IndexMap::new(),
explode: None,
extensions: IndexMap::new(),
},
allow_reserved: false,
style: Default::default(),
allow_empty_value: None,
}
}

fn get_param_schema(api: &OpenApi, param_index: usize) -> &schemars::Schema {
let paths = api.paths.as_ref().unwrap();
let ReferenceOr::Item(path_item) = &paths.paths["/test"] else {
panic!("expected item");
};
let op = path_item.get.as_ref().unwrap();
let ReferenceOr::Item(param) = &op.parameters[param_index] else {
panic!("expected parameter item");
};
let ParameterSchemaOrContent::Schema(schema_obj) = &param.parameter_data_ref().format
else {
panic!("expected schema");
};
&schema_obj.json_schema
}

fn get_body_schema(api: &OpenApi) -> &schemars::Schema {
let paths = api.paths.as_ref().unwrap();
let ReferenceOr::Item(path_item) = &paths.paths["/test"] else {
panic!("expected item");
};
let op = path_item.get.as_ref().unwrap();
let Some(ReferenceOr::Item(body)) = &op.request_body else {
panic!("expected request body");
};
&body.content["application/json"]
.schema
.as_ref()
.unwrap()
.json_schema
}

#[test]
fn strip_null_from_query_params() {
#[derive(JsonSchema)]
#[allow(dead_code)]
struct QueryParams {
optional: Option<String>,
}

#[derive(JsonSchema)]
#[allow(dead_code)]
struct Body {
optional: Option<String>,
}

let query_field = property_schema(&inline_schema_for::<QueryParams>(), "optional");
let body_field = property_schema(&inline_schema_for::<Body>(), "optional");

// Both should start out nullable
assert_eq!(query_field.get("type"), Some(&json!(["string", "null"])));
assert_eq!(body_field.get("type"), Some(&json!(["string", "null"])));

let mut api = build_api(vec![query_param("optional", query_field)], Some(body_field));
let _ = TransformOpenApi::new(&mut api).strip_null_from_query_params();

// Query param should have null stripped
assert_eq!(
get_param_schema(&api, 0).get("type"),
Some(&json!("string"))
);

// Request body schema should be unchanged
assert_eq!(
get_body_schema(&api).get("type"),
Some(&json!(["string", "null"])),
"request body schema should retain its nullable type"
);
}
}