Skip to content

Commit 3aaf6df

Browse files
leandrodamascenatonnico
authored andcommitted
Improving field validation method + tests
1 parent 91cbeb5 commit 3aaf6df

File tree

2 files changed

+195
-80
lines changed

2 files changed

+195
-80
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 37 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -330,63 +330,21 @@ def _request_params_to_args_with_pydantic_support(
330330
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
331331

332332
if isinstance(field_info, (Query, Header)) and lenient_issubclass(field_info.annotation, BaseModel):
333-
# Handle Pydantic model
334-
model_class = field_info.annotation
335-
model_data = {}
336-
model_errors = []
337-
338-
# Extract individual fields from the request
339-
for model_field_name, model_field_def in model_class.model_fields.items():
340-
field_alias = model_field_def.alias or model_field_name
341-
342-
# Convert snake_case to kebab-case for headers (HTTP convention)
343-
if isinstance(field_info, Header):
344-
field_alias = field_alias.replace("_", "-")
345-
346-
field_value = received_params.get(field_alias)
347-
348-
if field_value is not None:
349-
model_data[model_field_name] = field_value
350-
elif (
351-
model_field_def.is_required()
352-
if hasattr(model_field_def, "is_required")
353-
else model_field_def.default is ...
354-
):
355-
# Required field missing
356-
loc = (field_info.in_.value, field_alias)
357-
model_errors.append(get_missing_field_error(loc=loc))
358-
359-
if model_errors:
360-
errors.extend(model_errors)
361-
else:
362-
# Try to create the Pydantic model
363-
try:
364-
model_instance = model_class(**model_data)
365-
values[field.name] = model_instance
366-
except ValidationError as e:
367-
# Extract detailed validation errors from Pydantic
368-
for error in e.errors():
369-
# Update the location to include the parameter source (query/header) and field path
370-
error_loc = [field_info.in_.value] + list(error["loc"])
371-
errors.append(
372-
{
373-
"type": error["type"],
374-
"loc": error_loc,
375-
"msg": error["msg"],
376-
"input": error.get("input"),
377-
},
378-
)
379-
except Exception as e:
380-
# Fallback for non-Pydantic validation errors
381-
loc = (field_info.in_.value, field.alias)
382-
errors.append(
383-
{
384-
"type": "value_error",
385-
"loc": loc,
386-
"msg": str(e),
387-
"input": model_data,
388-
},
389-
)
333+
# Handle Pydantic model - use the same approach as _request_body_to_args
334+
loc = (field_info.in_.value, field.alias)
335+
336+
# Get the raw data for the Pydantic model
337+
value = received_params.get(field.alias)
338+
339+
if value is None:
340+
if field.required:
341+
errors.append(get_missing_field_error(loc))
342+
else:
343+
values[field.name] = deepcopy(field.default)
344+
continue
345+
346+
# Use _validate_field like _request_body_to_args does
347+
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
390348
else:
391349
# Regular parameter processing (existing logic)
392350
if not isinstance(field_info, Param):
@@ -565,14 +523,19 @@ def _normalize_multi_query_string_with_param(
565523

566524
if lenient_issubclass(param.field_info.annotation, BaseModel):
567525
model_class = param.field_info.annotation
568-
# Normalize individual fields of the Pydantic model
526+
model_data = {}
527+
528+
# Collect all fields for the Pydantic model
569529
for field_name, field_def in model_class.model_fields.items():
570530
field_alias = field_def.alias or field_name
571531
try:
572-
resolved_query_string[field_alias] = query_string[field_alias][0]
532+
model_data[field_alias] = query_string[field_alias][0]
573533
except KeyError:
574534
pass
575535

536+
# Store the collected data under the param alias
537+
resolved_query_string[param.alias] = model_data
538+
576539
return resolved_query_string
577540

578541

@@ -608,16 +571,27 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
608571

609572
if lenient_issubclass(param.field_info.annotation, BaseModel):
610573
model_class = param.field_info.annotation
611-
# Normalize individual fields of the Pydantic model
574+
model_data = {}
575+
576+
# Collect all fields for the Pydantic model
612577
for field_name, field_def in model_class.model_fields.items():
613578
field_alias = field_def.alias or field_name
614579

615580
# Convert snake_case to kebab-case for headers (HTTP convention)
616-
field_alias = field_alias.replace("_", "-")
581+
header_key = field_alias.replace("_", "-")
617582

618583
try:
619-
if len(headers[field_alias]) == 1:
620-
headers[field_alias] = headers[field_alias][0]
584+
header_value = headers[header_key]
585+
if isinstance(header_value, list):
586+
if len(header_value) == 1:
587+
model_data[field_alias] = header_value[0]
588+
else:
589+
model_data[field_alias] = header_value
590+
else:
591+
model_data[field_alias] = header_value
621592
except KeyError:
622593
pass
594+
595+
# Store the collected data under the param alias
596+
headers[param.alias] = model_data
623597
return headers

0 commit comments

Comments
 (0)