@@ -455,14 +455,27 @@ def validate_inputs_for_export(
455455class CoupleInputsDynamicShapes :
456456 """
457457 Pair inputs / dynamic shapes.
458+
459+ :param args: positional arguments
460+ :param kwargs: named arguments
461+ :param dynamic_shapes: dynamic shapes
462+ :param args_names: if both args and kwargs are not empty, then
463+ dynamic shapes must be a dictionary, and positional must be added
464+ to the named arguments. Arguments names or a module must be given
465+ in that case.
458466 """
459467
460468 def __init__ (
461- self , args : Tuple [Any , ...], kwargs : Dict [str , Any ], dynamic_shapes : DYNAMIC_SHAPES
469+ self ,
470+ args : Tuple [Any , ...],
471+ kwargs : Dict [str , Any ],
472+ dynamic_shapes : DYNAMIC_SHAPES ,
473+ args_names : Optional [Union [torch .nn .Module , List [str ]]] = None ,
462474 ):
463475 self .args = args
464476 self .kwargs = kwargs
465477 self .dynamic_shapes = dynamic_shapes
478+ self .args_names = args_names
466479
467480 def __str__ (self ) -> str :
468481 return "\n " .join (
@@ -497,7 +510,39 @@ def invalid_paths(self) -> List[Union[str, int]]:
497510 f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
498511 )
499512 return list (self ._valid_shapes (self .args , self .dynamic_shapes ))
500- raise NotImplementedError ("args and kwargs are filled, it is not implemented yet." )
513+
514+ assert isinstance (self .dynamic_shapes , dict ), (
515+ f"Both positional and named arguments (args and kwargs) are filled. "
516+ f"dynamic shapes must a dictionary not { type (self .dynamic_shapes )} "
517+ )
518+ if not self .args_names and set (self .dynamic_shapes ) & set (self .kwargs ) == set (
519+ self .dynamic_shapes
520+ ):
521+ # No dynamic shapes for the positional arguments.
522+ return list (self ._valid_shapes (self .kwargs , self .dynamic_shapes ))
523+
524+ if isinstance (self .args_names , list ):
525+ if not set (self .args_names ) & set (self .dynamic_shapes ):
526+ # No dynamic shapes for the positional arguments.
527+ return list (self ._valid_shapes (self .kwargs , self .dynamic_shapes ))
528+
529+ assert self .args_names , (
530+ "args and kwargs are filled, then args_names must be specified in "
531+ "the constructor to move positional arguments to named arguments."
532+ )
533+ assert len (self .args ) <= len (self .args_names ), (
534+ f"There are { len (self .args )} positional arguments "
535+ f"but only { len (self .args_names )} names. "
536+ f"args={ string_type (self .args , with_shape = True )} , args_name={ self .args_names } "
537+ )
538+ kwargs = dict (zip (self .args_names , self .args ))
539+ kwargs .update (self .kwargs )
540+ return list (self ._valid_shapes (kwargs , self .dynamic_shapes ))
541+
542+ raise NotImplementedError (
543+ f"Not yet implemented when args is filled, "
544+ f"kwargs as well but args_names is { type (self .args_names )} "
545+ )
501546
502547 @classmethod
503548 def _valid_shapes (
0 commit comments