@@ -207,13 +207,20 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
207207
208208 def _model_flatten_map (self , model : TModel , prefix : str ) -> Generator :
209209 field : FieldInfo
210- for attr , field in model .model_fields .items ():
211- field_name = field .alias or attr
212- name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
213- if is_pydantic_model (field .annotation ):
214- yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
215- else :
216- yield field_name , name
210+ if get_origin (model ) in UNION_TYPES :
211+ # If the model is a union type, process each type in the union
212+ for arg in get_args (model ):
213+ if arg is type (None ):
214+ continue # Skip NoneType
215+ yield from self ._model_flatten_map (arg , prefix )
216+ else :
217+ for attr , field in model .model_fields .items ():
218+ field_name = field .alias or attr
219+ name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
220+ if is_pydantic_model (field .annotation ):
221+ yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
222+ else :
223+ yield field_name , name
217224
218225 def _get_param_type (self , name : str , arg : inspect .Parameter ) -> FuncParam :
219226 # _EMPTY = self.signature.empty
@@ -260,9 +267,9 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
260267
261268 # 2) if param name is a part of the path parameter
262269 elif name in self .path_params_names :
263- assert (
264- default == self . signature . empty
265- ), f"' { name } ' is a path param, default not allowed"
270+ assert default == self . signature . empty , (
271+ f"' { name } ' is a path param, default not allowed"
272+ )
266273 param_source = Path (...)
267274
268275 # 3) if param is a collection, or annotation is part of pydantic model:
@@ -295,7 +302,11 @@ def is_pydantic_model(cls: Any) -> bool:
295302
296303 # Handle Union types
297304 if origin in UNION_TYPES :
298- return any (issubclass (arg , pydantic .BaseModel ) for arg in get_args (cls ))
305+ return any (
306+ issubclass (arg , pydantic .BaseModel )
307+ for arg in get_args (cls )
308+ if arg is not type (None )
309+ )
299310 return issubclass (cls , pydantic .BaseModel )
300311 except TypeError : # pragma: no cover
301312 return False
@@ -338,14 +349,32 @@ def detect_collection_fields(
338349 for attr in path [1 :]:
339350 if hasattr (annotation_or_field , "annotation" ):
340351 annotation_or_field = annotation_or_field .annotation
341- annotation_or_field = next (
342- (
343- a
344- for a in annotation_or_field .model_fields .values ()
345- if a .alias == attr
346- ),
347- annotation_or_field .model_fields .get (attr ),
348- ) # pragma: no cover
352+
353+ # check union types
354+ if get_origin (annotation_or_field ) in UNION_TYPES :
355+ for arg in get_args (annotation_or_field ):
356+ if arg is type (None ):
357+ continue # Skip NoneType
358+ if hasattr (arg , "model_fields" ):
359+ annotation_or_field = next (
360+ (
361+ a
362+ for a in arg .model_fields .values ()
363+ if a .alias == attr
364+ ),
365+ arg .model_fields .get (attr ),
366+ ) # pragma: no cover
367+ else :
368+ continue
369+ else :
370+ annotation_or_field = next (
371+ (
372+ a
373+ for a in annotation_or_field .model_fields .values ()
374+ if a .alias == attr
375+ ),
376+ annotation_or_field .model_fields .get (attr ),
377+ ) # pragma: no cover
349378
350379 annotation_or_field = getattr (
351380 annotation_or_field , "outer_type_" , annotation_or_field
0 commit comments