@@ -208,13 +208,20 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
208208
209209 def _model_flatten_map (self , model : TModel , prefix : str ) -> Generator :
210210 field : FieldInfo
211- for attr , field in model .model_fields .items ():
212- field_name = field .alias or attr
213- name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
214- if is_pydantic_model (field .annotation ):
215- yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
216- else :
217- yield field_name , name
211+ if get_origin (model ) in UNION_TYPES :
212+ # If the model is a union type, process each type in the union
213+ for arg in get_args (model ):
214+ if arg is type (None ):
215+ continue # Skip NoneType
216+ yield from self ._model_flatten_map (arg , prefix )
217+ else :
218+ for attr , field in model .model_fields .items ():
219+ field_name = field .alias or attr
220+ name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
221+ if is_pydantic_model (field .annotation ):
222+ yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
223+ else :
224+ yield field_name , name
218225
219226 def _get_param_type (self , name : str , arg : inspect .Parameter ) -> FuncParam :
220227 # _EMPTY = self.signature.empty
@@ -276,9 +283,9 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
276283
277284 # 2) if param name is a part of the path parameter
278285 elif name in self .path_params_names :
279- assert (
280- default == self . signature . empty
281- ), f"' { name } ' is a path param, default not allowed"
286+ assert default == self . signature . empty , (
287+ f"' { name } ' is a path param, default not allowed"
288+ )
282289 param_source = Path (...)
283290
284291 # 3) if param is a collection, or annotation is part of pydantic model:
@@ -311,7 +318,11 @@ def is_pydantic_model(cls: Any) -> bool:
311318
312319 # Handle Union types
313320 if origin in UNION_TYPES :
314- return any (issubclass (arg , pydantic .BaseModel ) for arg in get_args (cls ))
321+ return any (
322+ issubclass (arg , pydantic .BaseModel )
323+ for arg in get_args (cls )
324+ if arg is not type (None )
325+ )
315326 return issubclass (cls , pydantic .BaseModel )
316327 except TypeError : # pragma: no cover
317328 return False
@@ -354,14 +365,32 @@ def detect_collection_fields(
354365 for attr in path [1 :]:
355366 if hasattr (annotation_or_field , "annotation" ):
356367 annotation_or_field = annotation_or_field .annotation
357- annotation_or_field = next (
358- (
359- a
360- for a in annotation_or_field .model_fields .values ()
361- if a .alias == attr
362- ),
363- annotation_or_field .model_fields .get (attr ),
364- ) # pragma: no cover
368+
369+ # check union types
370+ if get_origin (annotation_or_field ) in UNION_TYPES :
371+ for arg in get_args (annotation_or_field ):
372+ if arg is type (None ):
373+ continue # Skip NoneType
374+ if hasattr (arg , "model_fields" ):
375+ annotation_or_field = next (
376+ (
377+ a
378+ for a in arg .model_fields .values ()
379+ if a .alias == attr
380+ ),
381+ arg .model_fields .get (attr ),
382+ ) # pragma: no cover
383+ else :
384+ continue
385+ else :
386+ annotation_or_field = next (
387+ (
388+ a
389+ for a in annotation_or_field .model_fields .values ()
390+ if a .alias == attr
391+ ),
392+ annotation_or_field .model_fields .get (attr ),
393+ ) # pragma: no cover
365394
366395 annotation_or_field = getattr (
367396 annotation_or_field , "outer_type_" , annotation_or_field
0 commit comments