diff --git a/crates/aide/src/axum/mod.rs b/crates/aide/src/axum/mod.rs index 2fbb61f4..6a6bdb12 100644 --- a/crates/aide/src/axum/mod.rs +++ b/crates/aide/src/axum/mod.rs @@ -443,6 +443,11 @@ where let _ = transform(TransformOpenApi::new(api)); let needs_reset = in_context(|ctx| { + // Strip null types from query parameters if enabled + if ctx.strip_query_null_types { + crate::transform::strip_null_from_query_params_impl(api); + } + if !ctx.extract_schemas { return false; } diff --git a/crates/aide/src/generate.rs b/crates/aide/src/generate.rs index eb3295fc..2701b17a 100644 --- a/crates/aide/src/generate.rs +++ b/crates/aide/src/generate.rs @@ -84,6 +84,19 @@ pub fn all_error_responses(infer: bool) { }); } +/// Automatically strip null types from query parameter schemas. +/// +/// Query strings cannot express null values - a parameter is either +/// present with a value or absent. When enabled, null types are +/// automatically removed from query parameter schemas during finalization. +/// +/// This is enabled by default. +pub fn strip_query_null_types(strip: bool) { + in_context(|ctx| { + ctx.strip_query_null_types = strip; + }); +} + /// Reset the state of the thread-local context. /// /// Currently clears: @@ -110,6 +123,9 @@ pub struct GenContext { /// Extract schemas. pub(crate) extract_schemas: bool, + /// Strip null types from query parameter schemas. + pub(crate) strip_query_null_types: bool, + /// Status code for no content. pub(crate) no_content_status: u16, @@ -135,6 +151,7 @@ impl GenContext { infer_responses: true, all_error_responses: false, extract_schemas: true, + strip_query_null_types: true, show_error: default_error_filter, error_handler: None, no_content_status, diff --git a/crates/aide/src/transform.rs b/crates/aide/src/transform.rs index 675e460a..5cf500b5 100644 --- a/crates/aide/src/transform.rs +++ b/crates/aide/src/transform.rs @@ -51,13 +51,15 @@ use std::{any::type_name, marker::PhantomData}; use crate::{ generate::GenContext, openapi::{ - Components, Contact, Info, License, OpenApi, Operation, Parameter, PathItem, ReferenceOr, - Response, SecurityScheme, Server, StatusCode, Tag, + Components, Contact, Info, License, OpenApi, Operation, Parameter, + ParameterSchemaOrContent, PathItem, ReferenceOr, Response, SecurityScheme, Server, + StatusCode, Tag, }, OperationInput, }; use indexmap::IndexMap; use serde::Serialize; +use serde_json::Value; use crate::{ error::Error, generate::in_context, operation::OperationOutput, util::iter_operations_mut, @@ -194,6 +196,23 @@ impl<'t> TransformOpenApi<'t> { self } + /// Strip null types from query parameter schemas. + /// + /// Query strings cannot express null values - a parameter is either + /// present with a value or absent. This method removes `null` from + /// type arrays (e.g., `["string", "null"]` becomes `"string"`) and + /// unwraps `anyOf` variants containing null types. + /// + /// Note: This is called automatically when the transform is finalized + /// (unless disabled via [`aide::generate::strip_query_null_types`]). + /// You only need to call this explicitly if you want to apply it + /// at a specific point in your transform chain. + #[tracing::instrument(skip_all)] + pub fn strip_null_from_query_params(self) -> Self { + strip_null_from_query_params_impl(self.api); + self + } + /// Add a security scheme. #[allow(clippy::missing_panics_doc)] pub fn security_scheme(mut self, name: &str, scheme: SecurityScheme) -> Self { @@ -1320,3 +1339,224 @@ impl<'t> TransformCallback<'t> { fn filter_no_duplicate_response(err: &Error) -> bool { !matches!(err, Error::DefaultResponseExists | Error::ResponseExists(_)) } + +pub(crate) fn strip_null_from_query_params_impl(api: &mut OpenApi) { + let Some(paths) = &mut api.paths else { return }; + + for (_, path_item) in &mut paths.paths { + let ReferenceOr::Item(path_item) = path_item else { + continue; + }; + + for (_, op) in iter_operations_mut(path_item) { + for param in &mut op.parameters { + let ReferenceOr::Item(Parameter::Query { parameter_data, .. }) = param else { + continue; + }; + + let ParameterSchemaOrContent::Schema(schema_obj) = &mut parameter_data.format + else { + continue; + }; + + strip_null_from_type(&mut schema_obj.json_schema); + } + } + } +} + +fn strip_null_from_type(schema: &mut schemars::Schema) { + // Handle type: ["string", "null"] -> type: "string" + if let Some(Value::Array(types)) = schema.get_mut("type") { + let null_count = types.iter().filter(|t| *t == "null").count(); + if null_count == 0 || null_count == types.len() { + return; // No nulls, or all nulls - don't modify + } + types.retain(|t| t != "null"); + if types.len() == 1 { + *schema.get_mut("type").unwrap() = types.remove(0); + } + return; + } + + // Handle anyOf: [..., {type: "null"}] -> remove null variants + let Some(Value::Array(items)) = schema.get_mut("anyOf") else { + return; + }; + + let is_null = |v: &Value| matches!(v.get("type"), Some(Value::String(s)) if s == "null"); + let null_count = items.iter().filter(|item| is_null(item)).count(); + if null_count == 0 || null_count == items.len() { + return; // No nulls, or all nulls - don't modify + } + + items.retain(|item| !is_null(item)); + + // Single item remains: unwrap the anyOf + if items.len() == 1 { + if let Some(Value::Object(obj)) = items.pop() { + if let Some(schema_obj) = schema.as_object_mut() { + schema_obj.remove("anyOf"); + for (key, value) in obj { + schema_obj.insert(key, value); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::openapi::{ + MediaType, OpenApi, Operation, Parameter, ParameterData, ParameterSchemaOrContent, Paths, + ReferenceOr, RequestBody, SchemaObject, + }; + use indexmap::IndexMap; + use schemars::JsonSchema; + use serde_json::json; + + use super::TransformOpenApi; + + fn inline_schema_for() -> schemars::Schema { + let settings = schemars::generate::SchemaSettings::draft07().with(|s| { + s.inline_subschemas = true; + }); + let mut gen = settings.into_generator(); + gen.subschema_for::() + } + + fn property_schema(struct_schema: &schemars::Schema, name: &str) -> schemars::Schema { + struct_schema + .get("properties") + .and_then(|p| p.get(name)) + .unwrap_or_else(|| panic!("property {name:?} not found")) + .clone() + .try_into() + .unwrap() + } + + fn build_api(params: Vec, body_schema: Option) -> OpenApi { + let mut op = Operation::default(); + op.parameters = params.into_iter().map(ReferenceOr::Item).collect(); + if let Some(schema) = body_schema { + op.request_body = Some(ReferenceOr::Item(RequestBody { + content: IndexMap::from_iter([( + "application/json".into(), + MediaType { + schema: Some(SchemaObject { + json_schema: schema, + external_docs: None, + example: None, + }), + ..Default::default() + }, + )]), + ..Default::default() + })); + } + + let mut path_item = crate::openapi::PathItem::default(); + path_item.get = Some(op); + + OpenApi { + paths: Some(Paths { + paths: IndexMap::from([("/test".to_string(), ReferenceOr::Item(path_item))]), + extensions: IndexMap::new(), + }), + ..OpenApi::default() + } + } + + fn query_param(name: &str, schema: schemars::Schema) -> Parameter { + Parameter::Query { + parameter_data: ParameterData { + name: name.to_string(), + description: None, + required: false, + deprecated: None, + format: ParameterSchemaOrContent::Schema(SchemaObject { + json_schema: schema, + external_docs: None, + example: None, + }), + example: None, + examples: IndexMap::new(), + explode: None, + extensions: IndexMap::new(), + }, + allow_reserved: false, + style: Default::default(), + allow_empty_value: None, + } + } + + fn get_param_schema(api: &OpenApi, param_index: usize) -> &schemars::Schema { + let paths = api.paths.as_ref().unwrap(); + let ReferenceOr::Item(path_item) = &paths.paths["/test"] else { + panic!("expected item"); + }; + let op = path_item.get.as_ref().unwrap(); + let ReferenceOr::Item(param) = &op.parameters[param_index] else { + panic!("expected parameter item"); + }; + let ParameterSchemaOrContent::Schema(schema_obj) = ¶m.parameter_data_ref().format + else { + panic!("expected schema"); + }; + &schema_obj.json_schema + } + + fn get_body_schema(api: &OpenApi) -> &schemars::Schema { + let paths = api.paths.as_ref().unwrap(); + let ReferenceOr::Item(path_item) = &paths.paths["/test"] else { + panic!("expected item"); + }; + let op = path_item.get.as_ref().unwrap(); + let Some(ReferenceOr::Item(body)) = &op.request_body else { + panic!("expected request body"); + }; + &body.content["application/json"] + .schema + .as_ref() + .unwrap() + .json_schema + } + + #[test] + fn strip_null_from_query_params() { + #[derive(JsonSchema)] + #[allow(dead_code)] + struct QueryParams { + optional: Option, + } + + #[derive(JsonSchema)] + #[allow(dead_code)] + struct Body { + optional: Option, + } + + let query_field = property_schema(&inline_schema_for::(), "optional"); + let body_field = property_schema(&inline_schema_for::(), "optional"); + + // Both should start out nullable + assert_eq!(query_field.get("type"), Some(&json!(["string", "null"]))); + assert_eq!(body_field.get("type"), Some(&json!(["string", "null"]))); + + let mut api = build_api(vec![query_param("optional", query_field)], Some(body_field)); + let _ = TransformOpenApi::new(&mut api).strip_null_from_query_params(); + + // Query param should have null stripped + assert_eq!( + get_param_schema(&api, 0).get("type"), + Some(&json!("string")) + ); + + // Request body schema should be unchanged + assert_eq!( + get_body_schema(&api).get("type"), + Some(&json!(["string", "null"])), + "request body schema should retain its nullable type" + ); + } +}