Skip to content

Commit d07aac3

Browse files
committed
Bugfixes and testing
1 parent 4f3a9d9 commit d07aac3

File tree

2 files changed

+383
-6
lines changed

2 files changed

+383
-6
lines changed

ninja/signature/details.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,21 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
187187
flatten_map = {}
188188
arg_names: Any = {}
189189
for arg in args:
190+
# Check if this is an optional union type with None default
191+
if get_origin(arg.annotation) in UNION_TYPES:
192+
union_args = get_args(arg.annotation)
193+
has_none = type(None) in union_args
194+
# If it's a union with None and the source default is None (like Query(None)), don't flatten it
195+
if has_none and hasattr(arg.source, 'default') and arg.source.default is None:
196+
name = arg.alias
197+
if name in flatten_map:
198+
raise ConfigError(
199+
f"Duplicated name: '{name}' also in '{arg_names[name]}'"
200+
)
201+
flatten_map[name] = (name,)
202+
arg_names[name] = name
203+
continue
204+
190205
if is_pydantic_model(arg.annotation):
191206
for name, path in self._model_flatten_map(arg.annotation, arg.alias):
192207
if name in flatten_map:
@@ -218,7 +233,30 @@ def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
218233
for attr, field in model.model_fields.items():
219234
field_name = field.alias or attr
220235
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
221-
if is_pydantic_model(field.annotation):
236+
237+
# Check if this is a union type field
238+
if get_origin(field.annotation) in UNION_TYPES:
239+
union_args = get_args(field.annotation)
240+
has_none = type(None) in union_args
241+
non_none_args = [arg for arg in union_args if arg is not type(None)]
242+
243+
# If it's an optional field (Union with None) and has a default value,
244+
# don't flatten it - treat it as a single optional field
245+
if has_none and field.default is not PydanticUndefined:
246+
yield field_name, name
247+
continue
248+
249+
# For non-optional unions or unions without defaults,
250+
# check if any of the union args are pydantic models
251+
pydantic_args = [arg for arg in non_none_args if is_pydantic_model(arg)]
252+
if pydantic_args:
253+
# Process only the pydantic model types
254+
for arg in pydantic_args:
255+
yield from self._model_flatten_map(arg, name)
256+
else:
257+
# No pydantic models in union, treat as simple field
258+
yield field_name, name
259+
elif is_pydantic_model(field.annotation):
222260
yield from self._model_flatten_map(field.annotation, name) # type: ignore
223261
else:
224262
yield field_name, name
@@ -368,20 +406,27 @@ def detect_collection_fields(
368406

369407
# check union types
370408
if get_origin(annotation_or_field) in UNION_TYPES:
409+
found = False
371410
for arg in get_args(annotation_or_field):
372411
if arg is type(None):
373412
continue # Skip NoneType
374413
if hasattr(arg, "model_fields"):
375-
annotation_or_field = next(
414+
found_field = next(
376415
(
377416
a
378417
for a in arg.model_fields.values()
379418
if a.alias == attr
380419
),
381420
arg.model_fields.get(attr),
382-
) # pragma: no cover
383-
else:
384-
continue
421+
)
422+
if found_field is not None:
423+
annotation_or_field = found_field
424+
found = True
425+
break
426+
if not found:
427+
# No suitable field found in any union member, skip this path
428+
annotation_or_field = None
429+
break # Break out of the attr loop
385430
else:
386431
annotation_or_field = next(
387432
(
@@ -396,8 +441,13 @@ def detect_collection_fields(
396441
annotation_or_field, "outer_type_", annotation_or_field
397442
)
398443

444+
# Skip if annotation_or_field is None (e.g., from failed union processing)
445+
if annotation_or_field is None:
446+
continue
447+
399448
# if hasattr(annotation_or_field, "annotation"):
400-
annotation_or_field = annotation_or_field.annotation
449+
if hasattr(annotation_or_field, "annotation"):
450+
annotation_or_field = annotation_or_field.annotation
401451

402452
if is_collection_type(annotation_or_field):
403453
result.append(path[-1])

0 commit comments

Comments
 (0)