Skip to content

Commit 6873239

Browse files
committed
add union type and subtypes check in schema model signature
1 parent d009420 commit 6873239

File tree

1 file changed

+48
-19
lines changed

1 file changed

+48
-19
lines changed

ninja/signature/details.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,20 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
208208

209209
def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
210210
field: FieldInfo
211-
for attr, field in model.model_fields.items():
212-
field_name = field.alias or attr
213-
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
214-
if is_pydantic_model(field.annotation):
215-
yield from self._model_flatten_map(field.annotation, name) # type: ignore
216-
else:
217-
yield field_name, name
211+
if get_origin(model) in UNION_TYPES:
212+
# If the model is a union type, process each type in the union
213+
for arg in get_args(model):
214+
if arg is type(None):
215+
continue # Skip NoneType
216+
yield from self._model_flatten_map(arg, prefix)
217+
else:
218+
for attr, field in model.model_fields.items():
219+
field_name = field.alias or attr
220+
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
221+
if is_pydantic_model(field.annotation):
222+
yield from self._model_flatten_map(field.annotation, name) # type: ignore
223+
else:
224+
yield field_name, name
218225

219226
def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
220227
# _EMPTY = self.signature.empty
@@ -276,9 +283,9 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
276283

277284
# 2) if param name is a part of the path parameter
278285
elif name in self.path_params_names:
279-
assert (
280-
default == self.signature.empty
281-
), f"'{name}' is a path param, default not allowed"
286+
assert default == self.signature.empty, (
287+
f"'{name}' is a path param, default not allowed"
288+
)
282289
param_source = Path(...)
283290

284291
# 3) if param is a collection, or annotation is part of pydantic model:
@@ -311,7 +318,11 @@ def is_pydantic_model(cls: Any) -> bool:
311318

312319
# Handle Union types
313320
if origin in UNION_TYPES:
314-
return any(issubclass(arg, pydantic.BaseModel) for arg in get_args(cls))
321+
return any(
322+
issubclass(arg, pydantic.BaseModel)
323+
for arg in get_args(cls)
324+
if arg is not type(None)
325+
)
315326
return issubclass(cls, pydantic.BaseModel)
316327
except TypeError: # pragma: no cover
317328
return False
@@ -354,14 +365,32 @@ def detect_collection_fields(
354365
for attr in path[1:]:
355366
if hasattr(annotation_or_field, "annotation"):
356367
annotation_or_field = annotation_or_field.annotation
357-
annotation_or_field = next(
358-
(
359-
a
360-
for a in annotation_or_field.model_fields.values()
361-
if a.alias == attr
362-
),
363-
annotation_or_field.model_fields.get(attr),
364-
) # pragma: no cover
368+
369+
# check union types
370+
if get_origin(annotation_or_field) in UNION_TYPES:
371+
for arg in get_args(annotation_or_field):
372+
if arg is type(None):
373+
continue # Skip NoneType
374+
if hasattr(arg, "model_fields"):
375+
annotation_or_field = next(
376+
(
377+
a
378+
for a in arg.model_fields.values()
379+
if a.alias == attr
380+
),
381+
arg.model_fields.get(attr),
382+
) # pragma: no cover
383+
else:
384+
continue
385+
else:
386+
annotation_or_field = next(
387+
(
388+
a
389+
for a in annotation_or_field.model_fields.values()
390+
if a.alias == attr
391+
),
392+
annotation_or_field.model_fields.get(attr),
393+
) # pragma: no cover
365394

366395
annotation_or_field = getattr(
367396
annotation_or_field, "outer_type_", annotation_or_field

0 commit comments

Comments
 (0)