Skip to content

Commit b3a0387

Browse files
leandrodamascenatonnico
authored andcommitted
Adding supoort for Pydantic models in Query and Header
1 parent 81b50ed commit b3a0387

File tree

5 files changed

+736
-42
lines changed

5 files changed

+736
-42
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,7 @@ def _openapi_operation_parameters(
815815
from aws_lambda_powertools.event_handler.openapi.compat import (
816816
get_schema_from_model_field,
817817
)
818-
from aws_lambda_powertools.event_handler.openapi.params import Param
818+
from aws_lambda_powertools.event_handler.openapi.params import Form, Header, Param, Query
819819

820820
parameters = []
821821
parameter: dict[str, Any] = {}
@@ -826,32 +826,74 @@ def _openapi_operation_parameters(
826826
if not field_info.include_in_schema:
827827
continue
828828

829-
param_schema = get_schema_from_model_field(
830-
field=param,
831-
model_name_map=model_name_map,
832-
field_mapping=field_mapping,
833-
)
829+
# Check if this is a Pydantic model that should be expanded
830+
from pydantic import BaseModel
834831

835-
parameter = {
836-
"name": param.alias,
837-
"in": field_info.in_.value,
838-
"required": param.required,
839-
"schema": param_schema,
840-
}
832+
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
841833

842-
if field_info.description:
843-
parameter["description"] = field_info.description
834+
if isinstance(field_info, (Query, Header, Form)) and lenient_issubclass(field_info.annotation, BaseModel):
835+
# Expand Pydantic model into individual parameters
836+
model_class = field_info.annotation
844837

845-
if field_info.openapi_examples:
846-
parameter["examples"] = field_info.openapi_examples
838+
for field_name, field_def in model_class.model_fields.items():
839+
# Create individual parameter for each model field
840+
individual_param = {
841+
"name": field_def.alias or field_name,
842+
"in": field_info.in_.value,
843+
"required": field_def.is_required()
844+
if hasattr(field_def, "is_required")
845+
else field_def.default is ...,
846+
"schema": Route._get_basic_type_schema(field_def.annotation),
847+
}
848+
849+
if field_def.description:
850+
individual_param["description"] = field_def.description
851+
852+
parameters.append(individual_param)
853+
else:
854+
# Regular parameter processing
855+
param_schema = get_schema_from_model_field(
856+
field=param,
857+
model_name_map=model_name_map,
858+
field_mapping=field_mapping,
859+
)
847860

848-
if field_info.deprecated:
849-
parameter["deprecated"] = field_info.deprecated
861+
parameter = {
862+
"name": param.alias,
863+
"in": field_info.in_.value,
864+
"required": param.required,
865+
"schema": param_schema,
866+
}
850867

851-
parameters.append(parameter)
868+
if field_info.description:
869+
parameter["description"] = field_info.description
870+
871+
if field_info.openapi_examples:
872+
parameter["examples"] = field_info.openapi_examples
873+
874+
if field_info.deprecated:
875+
parameter["deprecated"] = field_info.deprecated
876+
877+
parameters.append(parameter)
852878

853879
return parameters
854880

881+
@staticmethod
882+
def _get_basic_type_schema(param_type: type) -> dict[str, str]:
883+
"""
884+
Get basic OpenAPI schema for simple types
885+
"""
886+
if isinstance(int, param_type):
887+
return {"type": "integer"}
888+
elif isinstance(float, param_type):
889+
return {"type": "number"}
890+
elif isinstance(bool, param_type):
891+
return {"type": "boolean"}
892+
elif isinstance(str, param_type):
893+
return {"type": "string"}
894+
else:
895+
return {"type": "string"} # Default fallback
896+
855897
@staticmethod
856898
def _openapi_operation_return(
857899
*,

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 133 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
2020
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
2121
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
22-
from aws_lambda_powertools.event_handler.openapi.params import Param
22+
from aws_lambda_powertools.event_handler.openapi.params import Header, Param, Query
2323

2424
if TYPE_CHECKING:
2525
from aws_lambda_powertools.event_handler import Response
@@ -69,8 +69,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
6969
route.dependant.query_params,
7070
)
7171

72-
# Process query values
73-
query_values, query_errors = _request_params_to_args(
72+
# Process query values (with Pydantic model support)
73+
query_values, query_errors = _request_params_to_args_with_pydantic_support(
7474
route.dependant.query_params,
7575
query_string,
7676
)
@@ -81,8 +81,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
8181
route.dependant.header_params,
8282
)
8383

84-
# Process header values
85-
header_values, header_errors = _request_params_to_args(
84+
# Process header values (with Pydantic model support)
85+
header_values, header_errors = _request_params_to_args_with_pydantic_support(
8686
route.dependant.header_params,
8787
headers,
8888
)
@@ -311,6 +311,84 @@ def _prepare_response_content(
311311
return res # pragma: no cover
312312

313313

314+
def _request_params_to_args_with_pydantic_support(
315+
required_params: Sequence[ModelField],
316+
received_params: Mapping[str, Any],
317+
) -> tuple[dict[str, Any], list[Any]]:
318+
"""
319+
Convert request params to a dictionary of values with Pydantic model support.
320+
"""
321+
values = {}
322+
errors = []
323+
324+
for field in required_params:
325+
field_info = field.field_info
326+
327+
# Check if this is a Pydantic model in Query/Header
328+
from pydantic import BaseModel
329+
330+
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
331+
332+
if isinstance(field_info, (Query, Header)) and lenient_issubclass(field_info.annotation, BaseModel):
333+
# Handle Pydantic model
334+
model_class = field_info.annotation
335+
model_data = {}
336+
model_errors = []
337+
338+
# Extract individual fields from the request
339+
for model_field_name, model_field_def in model_class.model_fields.items():
340+
field_alias = model_field_def.alias or model_field_name
341+
field_value = received_params.get(field_alias)
342+
343+
if field_value is not None:
344+
model_data[model_field_name] = field_value
345+
elif (
346+
model_field_def.is_required()
347+
if hasattr(model_field_def, "is_required")
348+
else model_field_def.default is ...
349+
):
350+
# Required field missing
351+
loc = (field_info.in_.value, field_alias)
352+
model_errors.append(get_missing_field_error(loc=loc))
353+
354+
if model_errors:
355+
errors.extend(model_errors)
356+
else:
357+
# Try to create the Pydantic model
358+
try:
359+
model_instance = model_class(**model_data)
360+
values[field.name] = model_instance
361+
except Exception as e:
362+
# Validation error
363+
loc = (field_info.in_.value, field.alias)
364+
errors.append(
365+
{
366+
"type": "value_error",
367+
"loc": loc,
368+
"msg": str(e),
369+
"input": model_data,
370+
},
371+
)
372+
else:
373+
# Regular parameter processing (existing logic)
374+
if not isinstance(field_info, Param):
375+
raise AssertionError(f"Expected Param field_info, got {field_info}")
376+
377+
value = received_params.get(field.alias)
378+
loc = (field_info.in_.value, field.alias)
379+
380+
if value is None:
381+
if field.required:
382+
errors.append(get_missing_field_error(loc=loc))
383+
else:
384+
values[field.name] = deepcopy(field.default)
385+
continue
386+
387+
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
388+
389+
return values, errors
390+
391+
314392
def _request_params_to_args(
315393
required_params: Sequence[ModelField],
316394
received_params: Mapping[str, Any],
@@ -439,7 +517,7 @@ def _normalize_multi_query_string_with_param(
439517
params: Sequence[ModelField],
440518
) -> dict[str, Any]:
441519
"""
442-
Extract and normalize resolved_query_string_parameters
520+
Extract and normalize resolved_query_string_parameters with Pydantic model support
443521
444522
Parameters
445523
----------
@@ -453,19 +531,36 @@ def _normalize_multi_query_string_with_param(
453531
A dictionary containing the processed multi_query_string_parameters.
454532
"""
455533
resolved_query_string: dict[str, Any] = query_string
456-
for param in filter(is_scalar_field, params):
457-
try:
458-
# if the target parameter is a scalar, we keep the first value of the query string
459-
# regardless if there are more in the payload
460-
resolved_query_string[param.alias] = query_string[param.alias][0]
461-
except KeyError:
462-
pass
534+
535+
for param in params:
536+
# Handle scalar fields (existing logic)
537+
if is_scalar_field(param):
538+
try:
539+
resolved_query_string[param.alias] = query_string[param.alias][0]
540+
except KeyError:
541+
pass
542+
# Handle Pydantic models
543+
elif isinstance(param.field_info, Query) and hasattr(param.field_info, "annotation"):
544+
from pydantic import BaseModel
545+
546+
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
547+
548+
if lenient_issubclass(param.field_info.annotation, BaseModel):
549+
model_class = param.field_info.annotation
550+
# Normalize individual fields of the Pydantic model
551+
for field_name, field_def in model_class.model_fields.items():
552+
field_alias = field_def.alias or field_name
553+
try:
554+
resolved_query_string[field_alias] = query_string[field_alias][0]
555+
except KeyError:
556+
pass
557+
463558
return resolved_query_string
464559

465560

466561
def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
467562
"""
468-
Extract and normalize resolved_headers_field
563+
Extract and normalize resolved_headers_field with Pydantic model support
469564
470565
Parameters
471566
----------
@@ -479,12 +574,28 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
479574
A dictionary containing the processed headers.
480575
"""
481576
if headers:
482-
for param in filter(is_scalar_field, params):
483-
try:
484-
if len(headers[param.alias]) == 1:
485-
# if the target parameter is a scalar and the list contains only 1 element
486-
# we keep the first value of the headers regardless if there are more in the payload
487-
headers[param.alias] = headers[param.alias][0]
488-
except KeyError:
489-
pass
577+
for param in params:
578+
# Handle scalar fields (existing logic)
579+
if is_scalar_field(param):
580+
try:
581+
if len(headers[param.alias]) == 1:
582+
headers[param.alias] = headers[param.alias][0]
583+
except KeyError:
584+
pass
585+
# Handle Pydantic models
586+
elif isinstance(param.field_info, Header) and hasattr(param.field_info, "annotation"):
587+
from pydantic import BaseModel
588+
589+
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
590+
591+
if lenient_issubclass(param.field_info.annotation, BaseModel):
592+
model_class = param.field_info.annotation
593+
# Normalize individual fields of the Pydantic model
594+
for field_name, field_def in model_class.model_fields.items():
595+
field_alias = field_def.alias or field_name
596+
try:
597+
if len(headers[field_alias]) == 1:
598+
headers[field_alias] = headers[field_alias][0]
599+
except KeyError:
600+
pass
490601
return headers

0 commit comments

Comments
 (0)