|
4 | 4 | import json |
5 | 5 | import logging |
6 | 6 | from copy import deepcopy |
7 | | -from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, get_origin |
| 7 | +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast, get_origin |
8 | 8 | from urllib.parse import parse_qs |
9 | 9 |
|
10 | 10 | from pydantic import BaseModel |
@@ -456,32 +456,59 @@ def _normalize_multi_params( |
456 | 456 | """ |
457 | 457 | for param in params: |
458 | 458 | if is_scalar_field(param): |
459 | | - try: |
460 | | - val = input_dict[param.alias] |
461 | | - if isinstance(val, list) and len(val) == 1: |
462 | | - input_dict[param.alias] = val[0] |
463 | | - elif isinstance(val, list): |
464 | | - pass # leave as list for multi-value |
465 | | - # If it's a string, leave as is |
466 | | - except KeyError: |
467 | | - pass |
| 459 | + _process_scalar_param(input_dict, param) |
468 | 460 | elif lenient_issubclass(param.field_info.annotation, BaseModel): |
469 | | - model_class = param.field_info.annotation |
470 | | - model_data = {} |
471 | | - |
472 | | - for field_name, field_def in model_class.model_fields.items(): |
473 | | - field_alias = field_def.alias or field_name |
474 | | - value = input_dict.get(field_alias) |
475 | | - if value is None and ( |
476 | | - model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name") |
477 | | - ): |
478 | | - value = input_dict.get(field_name) |
479 | | - if value is not None: |
480 | | - if get_origin(field_def.annotation) is list: |
481 | | - model_data[field_alias] = value |
482 | | - elif isinstance(value, list): |
483 | | - model_data[field_alias] = value[0] |
484 | | - else: |
485 | | - model_data[field_alias] = value |
486 | | - input_dict[param.alias] = model_data |
| 461 | + _process_model_param(input_dict, param) |
487 | 462 | return input_dict |
| 463 | + |
| 464 | + |
| 465 | +def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None: |
| 466 | + """Process a scalar parameter by normalizing single-item lists.""" |
| 467 | + try: |
| 468 | + val = input_dict[param.alias] |
| 469 | + if isinstance(val, list) and len(val) == 1: |
| 470 | + input_dict[param.alias] = val[0] |
| 471 | + except KeyError: |
| 472 | + pass |
| 473 | + |
| 474 | + |
| 475 | +def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None: |
| 476 | + """Process a Pydantic model parameter by extracting model fields.""" |
| 477 | + model_class = cast(type[BaseModel], param.field_info.annotation) |
| 478 | + |
| 479 | + model_data = {} |
| 480 | + for field_name, field_def in model_class.model_fields.items(): |
| 481 | + field_alias = field_def.alias or field_name |
| 482 | + value = _get_param_value(input_dict, field_alias, field_name, model_class) |
| 483 | + |
| 484 | + if value is not None: |
| 485 | + model_data[field_alias] = _normalize_field_value(value, field_def) |
| 486 | + |
| 487 | + input_dict[param.alias] = model_data |
| 488 | + |
| 489 | + |
| 490 | +def _get_param_value( |
| 491 | + input_dict: MutableMapping[str, Any], |
| 492 | + field_alias: str, |
| 493 | + field_name: str, |
| 494 | + model_class: type[BaseModel], |
| 495 | +) -> Any: |
| 496 | + """Get parameter value, checking both alias and field name if needed.""" |
| 497 | + value = input_dict.get(field_alias) |
| 498 | + if value is not None: |
| 499 | + return value |
| 500 | + |
| 501 | + if model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name"): |
| 502 | + value = input_dict.get(field_name) |
| 503 | + |
| 504 | + return value |
| 505 | + |
| 506 | + |
| 507 | +def _normalize_field_value(value: Any, field_def: Any) -> Any: |
| 508 | + """Normalize field value based on its type annotation.""" |
| 509 | + if get_origin(field_def.annotation) is list: |
| 510 | + return value |
| 511 | + elif isinstance(value, list) and value: |
| 512 | + return value[0] |
| 513 | + else: |
| 514 | + return value |
0 commit comments