Skip to content

Commit 0ca4357

Browse files
committed
support multi value headers and queries
1 parent 34beb4d commit 0ca4357

File tree

4 files changed

+204
-123
lines changed

4 files changed

+204
-123
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -839,10 +839,6 @@ def _openapi_operation_parameters(
839839
# Create individual parameter for each model field
840840
param_name = field_def.alias or field_name
841841

842-
# Convert snake_case to kebab-case for headers (HTTP convention)
843-
if isinstance(field_info, Header):
844-
param_name = param_name.replace("_", "-")
845-
846842
individual_param = {
847843
"name": param_name,
848844
"in": field_info.in_.value,

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 35 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_normalize_errors,
1616
_regenerate_error_with_loc,
1717
get_missing_field_error,
18+
lenient_issubclass,
1819
)
1920
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
2021
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
@@ -64,7 +65,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
6465
)
6566

6667
# Normalize query values before validate this
67-
query_string = _normalize_multi_query_string_with_param(
68+
query_string = _normalize_multi_params(
6869
app.current_event.resolved_query_string_parameters,
6970
route.dependant.query_params,
7071
)
@@ -76,7 +77,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7677
)
7778

7879
# Normalize header values before validate this
79-
headers = _normalize_multi_header_values_with_param(
80+
headers = _normalize_multi_params(
8081
app.current_event.resolved_headers_field,
8182
route.dependant.header_params,
8283
)
@@ -439,116 +440,43 @@ def _get_embed_body(
439440
return received_body, field_alias_omitted
440441

441442

442-
def _normalize_multi_query_string_with_param(
443-
query_string: dict[str, list[str]],
443+
def _normalize_multi_params(
444+
input_dict: MutableMapping[str, Any],
444445
params: Sequence[ModelField],
445-
) -> dict[str, Any]:
446+
) -> MutableMapping[str, Any]:
446447
"""
447-
Extract and normalize resolved_query_string_parameters with Pydantic model support
448-
449-
Parameters
450-
----------
451-
query_string: dict
452-
A dictionary containing the initial query string parameters.
453-
params: Sequence[ModelField]
454-
A sequence of ModelField objects representing parameters.
455-
456-
Returns
457-
-------
458-
A dictionary containing the processed multi_query_string_parameters.
448+
Generic normalization for query string or header parameters with Pydantic model support.
449+
No key transformation is performed.
459450
"""
460-
resolved_query_string: dict[str, Any] = query_string
461-
462451
for param in params:
463-
# Handle scalar fields (existing logic)
464452
if is_scalar_field(param):
465453
try:
466-
resolved_query_string[param.alias] = query_string[param.alias][0]
454+
val = input_dict[param.alias]
455+
if isinstance(val, list) and len(val) == 1:
456+
input_dict[param.alias] = val[0]
457+
elif isinstance(val, list):
458+
pass # leave as list for multi-value
459+
# If it's a string, leave as is
467460
except KeyError:
468461
pass
469-
# Handle Pydantic models
470-
elif isinstance(param.field_info, Query) and hasattr(param.field_info, "annotation"):
471-
from pydantic import BaseModel
472-
473-
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
474-
475-
if lenient_issubclass(param.field_info.annotation, BaseModel):
476-
model_class = param.field_info.annotation
477-
model_data = {}
478-
479-
# Collect all fields for the Pydantic model
480-
for field_name, field_def in model_class.model_fields.items():
481-
field_alias = field_def.alias or field_name
482-
try:
483-
model_data[field_alias] = query_string[field_alias][0]
484-
except KeyError:
485-
if model_class.model_config.get("validate_by_name") or model_class.model_config.get(
486-
"populate_by_name",
487-
):
488-
try:
489-
model_data[field_alias] = query_string[field_name][0]
490-
except KeyError:
491-
pass
492-
493-
# Store the collected data under the param alias
494-
resolved_query_string[param.alias] = model_data
495-
496-
return resolved_query_string
497-
498-
499-
def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
500-
"""
501-
Extract and normalize resolved_headers_field with Pydantic model support
502-
503-
Parameters
504-
----------
505-
headers: MutableMapping[str, Any]
506-
A dictionary containing the initial header parameters.
507-
params: Sequence[ModelField]
508-
A sequence of ModelField objects representing parameters.
509-
510-
Returns
511-
-------
512-
A dictionary containing the processed headers.
513-
"""
514-
if headers:
515-
for param in params:
516-
# Handle scalar fields (existing logic)
517-
if is_scalar_field(param):
518-
try:
519-
if len(headers[param.alias]) == 1:
520-
headers[param.alias] = headers[param.alias][0]
521-
except KeyError:
522-
pass
523-
# Handle Pydantic models
524-
elif isinstance(param.field_info, Header) and hasattr(param.field_info, "annotation"):
525-
from pydantic import BaseModel
526-
527-
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
528-
529-
if lenient_issubclass(param.field_info.annotation, BaseModel):
530-
model_class = param.field_info.annotation
531-
model_data = {}
532-
533-
# Collect all fields for the Pydantic model
534-
for field_name, field_def in model_class.model_fields.items():
535-
field_alias = field_def.alias or field_name
536-
537-
# Convert snake_case to kebab-case for headers (HTTP convention)
538-
header_key = field_alias.replace("_", "-")
539-
540-
try:
541-
header_value = headers[header_key]
542-
if isinstance(header_value, list):
543-
if len(header_value) == 1:
544-
model_data[field_alias] = header_value[0]
545-
else:
546-
model_data[field_alias] = header_value
547-
else:
548-
model_data[field_alias] = header_value
549-
except KeyError:
550-
pass
551-
552-
# Store the collected data under the param alias
553-
headers[param.alias] = model_data
554-
return headers
462+
elif lenient_issubclass(param.field_info.annotation, BaseModel):
463+
model_class = param.field_info.annotation
464+
model_data = {}
465+
from typing import get_origin
466+
467+
for field_name, field_def in model_class.model_fields.items():
468+
field_alias = field_def.alias or field_name
469+
value = input_dict.get(field_alias)
470+
if value is None and (
471+
model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name")
472+
):
473+
value = input_dict.get(field_name)
474+
if value is not None:
475+
if get_origin(field_def.annotation) is list:
476+
model_data[field_alias] = value
477+
elif isinstance(value, list):
478+
model_data[field_alias] = value[0]
479+
else:
480+
model_data[field_alias] = value
481+
input_dict[param.alias] = model_data
482+
return input_dict

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing import TYPE_CHECKING, Any, Literal
66

7-
from pydantic import BaseConfig
7+
from pydantic import BaseConfig, BaseModel, create_model
88
from pydantic.fields import FieldInfo
99
from typing_extensions import Annotated, get_args, get_origin
1010

@@ -17,6 +17,7 @@
1717
copy_field_info,
1818
field_annotation_is_scalar,
1919
get_annotation_from_field_info,
20+
lenient_issubclass,
2021
)
2122

2223
if TYPE_CHECKING:
@@ -1094,6 +1095,42 @@ def create_response_field(
10941095
return ModelField(**kwargs) # type: ignore[arg-type]
10951096

10961097

1098+
def _apply_header_underscore_conversion(
1099+
field_info: FieldInfo,
1100+
type_annotation: Any,
1101+
param_name: str,
1102+
) -> tuple[FieldInfo, Any]:
1103+
"""
1104+
Apply underscore-to-dash conversion for Header parameters.
1105+
1106+
For BaseModel: Creates new model with underscore-to-dash alias generator.
1107+
Note: If the BaseModel already has an alias generator, it will be replaced
1108+
with dash-case conversion since HTTP headers should use dash-case.
1109+
For all Header fields: Sets the parameter alias if convert_underscores is True
1110+
"""
1111+
if not isinstance(field_info, Header) or not field_info.convert_underscores:
1112+
return field_info, type_annotation
1113+
1114+
# Always set the parameter alias for Header fields (if not already set)
1115+
if not field_info.alias:
1116+
field_info.alias = param_name.replace("_", "-")
1117+
1118+
# Handle BaseModel case - create new model with dash-case alias generator
1119+
if lenient_issubclass(type_annotation, BaseModel):
1120+
# For HTTP headers, we should use dash-case regardless of existing alias generator
1121+
# This ensures consistent header naming conventions
1122+
header_aliased_model = create_model(
1123+
f"{type_annotation.__name__}WithHeaderAliases",
1124+
__base__=type_annotation,
1125+
__config__={"alias_generator": lambda name: name.replace("_", "-")},
1126+
)
1127+
1128+
type_annotation = header_aliased_model
1129+
field_info.annotation = type_annotation
1130+
1131+
return field_info, type_annotation
1132+
1133+
10971134
def _create_model_field(
10981135
field_info: FieldInfo | None,
10991136
type_annotation: Any,
@@ -1112,21 +1149,17 @@ def _create_model_field(
11121149
elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None:
11131150
field_info.in_ = ParamTypes.query
11141151

1152+
# Apply header underscore conversion
1153+
field_info, type_annotation = _apply_header_underscore_conversion(field_info, type_annotation, param_name)
1154+
11151155
# If the field_info is a Param, we use the `in_` attribute to determine the type annotation
11161156
use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name)
11171157

1118-
# If the field doesn't have a defined alias, we use the param name
1119-
if not field_info.alias and getattr(field_info, "convert_underscores", None):
1120-
alias = param_name.replace("_", "-")
1121-
else:
1122-
alias = field_info.alias or param_name
1123-
field_info.alias = alias
1124-
11251158
return create_response_field(
11261159
name=param_name,
11271160
type_=use_annotation,
11281161
default=field_info.default,
1129-
alias=alias,
1162+
alias=field_info.alias,
11301163
required=field_info.default in (Required, Undefined),
11311164
field_info=field_info,
11321165
)

0 commit comments

Comments
 (0)