@@ -187,8 +187,13 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
187187 arg_names : Any = {}
188188 for arg in args :
189189 if is_pydantic_model (arg .annotation ):
190- for name , path in self ._model_flatten_map (arg .annotation , arg .alias ):
191- if name in flatten_map :
190+ for name , path , is_union_descendant in self ._model_flatten_map (
191+ arg .annotation , arg .alias
192+ ):
193+ model = arg .annotation
194+ if get_origin (model ) is Annotated :
195+ model = get_args (model )[0 ]
196+ if not is_union_descendant and name in flatten_map :
192197 raise ConfigError (
193198 f"Duplicated name: '{ name } ' in params: '{ arg_names [name ]} ' & '{ arg .name } '"
194199 )
@@ -205,15 +210,26 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
205210
206211 return flatten_map
207212
208- def _model_flatten_map (self , model : TModel , prefix : str ) -> Generator :
213+ def _model_flatten_map (
214+ self , model : TModel , prefix : str , is_union_descendant : bool = False
215+ ) -> Generator [Tuple [str , str , bool ], None , None ]:
209216 field : FieldInfo
210- for attr , field in model .model_fields .items ():
211- field_name = field .alias or attr
212- name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
213- if is_pydantic_model (field .annotation ):
214- yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
215- else :
216- yield field_name , name
217+ if get_origin (model ) is Annotated :
218+ model = get_args (model )[0 ]
219+ if get_origin (model ) in UNION_TYPES :
220+ # If the model is a union type, process each type in the union
221+ for arg in get_args (model ):
222+ yield from self ._model_flatten_map (arg , prefix , True )
223+ else :
224+ for attr , field in model .model_fields .items ():
225+ field_name = field .alias or attr
226+ name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
227+ if is_pydantic_model (field .annotation ):
228+ yield from self ._model_flatten_map (
229+ field .annotation , name , is_union_descendant
230+ ) # type: ignore
231+ else :
232+ yield field_name , name , is_union_descendant
217233
218234 def _get_param_type (self , name : str , arg : inspect .Parameter ) -> FuncParam :
219235 # _EMPTY = self.signature.empty
@@ -336,24 +352,40 @@ def detect_collection_fields(
336352 for path in (p for p in flatten_map .values () if len (p ) > 1 ):
337353 annotation_or_field : Any = args_d [path [0 ]].annotation
338354 for attr in path [1 :]:
339- if hasattr (annotation_or_field , "annotation" ):
340- annotation_or_field = annotation_or_field .annotation
341- annotation_or_field = next (
342- (
343- a
344- for a in annotation_or_field .model_fields .values ()
345- if a .alias == attr
346- ),
347- annotation_or_field .model_fields .get (attr ),
348- ) # pragma: no cover
355+ if get_origin (annotation_or_field ) is Annotated :
356+ annotation_or_field = get_args (annotation_or_field )[0 ]
357+
358+ # check union types
359+ if get_origin (annotation_or_field ) in UNION_TYPES :
360+ for arg in get_args (annotation_or_field ):
361+ annotation_or_field = _detect_collection_fields (
362+ arg , attr , path , result
363+ )
364+ else :
365+ annotation_or_field = _detect_collection_fields (
366+ annotation_or_field , attr , path , result
367+ )
368+ return result
349369
350- annotation_or_field = getattr (
351- annotation_or_field , "outer_type_" , annotation_or_field
352- )
353370
354- # if hasattr(annotation_or_field, "annotation"):
355- annotation_or_field = annotation_or_field .annotation
371+ def _detect_collection_fields (
372+ annotation_or_field : Any ,
373+ attr : str ,
374+ path : Tuple [str , ...],
375+ result : List [Any ],
376+ ) -> Any :
377+ annotation_or_field = next (
378+ (a for a in annotation_or_field .model_fields .values () if a .alias == attr ),
379+ annotation_or_field .model_fields .get (attr ),
380+ ) # pragma: no cover
356381
357- if is_collection_type (annotation_or_field ):
358- result .append (path [- 1 ])
359- return result
382+ annotation_or_field = getattr (
383+ annotation_or_field , "outer_type_" , annotation_or_field
384+ )
385+
386+ # if hasattr(annotation_or_field, "annotation"):
387+ annotation_or_field = annotation_or_field .annotation
388+
389+ if is_collection_type (annotation_or_field ):
390+ return result .append (path [- 1 ])
391+ return annotation_or_field
0 commit comments