@@ -379,8 +379,9 @@ def _generic_walker_step(
379379 return torch .utils ._pytree .tree_unflatten (res , spec )
380380
381381 class ChangeDimensionProcessor :
382- def __init__ (self , desired_values ):
382+ def __init__ (self , desired_values , only_desired ):
383383 self .mapping = desired_values or {}
384+ self .only_desired = only_desired
384385
385386 def _build_new_shape (
386387 self , shape : Tuple [int , ...], ds : Dict [int , Any ]
@@ -397,14 +398,16 @@ def _build_new_shape(
397398 torch .export .dynamic_shapes ._Dim ,
398399 ),
399400 ):
400- d = str ( ds [i ])
401+ d = ds [i ]. __name__
401402 elif not isinstance (ds [i ], int ):
402403 raise NotImplementedError (f"Unable to handle type { ds [i ]} in { ds } " )
403404 if d in self .mapping :
404405 new_dim = self .mapping [d ]
405- else :
406+ elif not self . only_desired :
406407 new_dim = shape [i ] + 1
407408 self .mapping [d ] = new_dim
409+ else :
410+ new_dim = shape [i ]
408411 new_shape [i ] = new_dim
409412 return tuple (new_shape )
410413
@@ -447,7 +450,10 @@ def __call__(self, inputs, ds):
447450 return self ._build_new_tensor (inputs , new_shape )
448451
449452 def change_dynamic_dimensions (
450- self , desired_values : Optional [Dict [str , int ]] = None , args_kwargs : bool = False
453+ self ,
454+ desired_values : Optional [Dict [str , int ]] = None ,
455+ args_kwargs : bool = False ,
456+ only_desired : bool = False ,
451457 ):
452458 """
453459 A model exported with dynamic shapes is not necessarily dynamic
@@ -460,6 +466,8 @@ def change_dynamic_dimensions(
460466
461467 :param desired_values: to fixed named dimension to have the desired value
462468 :param args_kwargs: return both args, kwargs even if empty
469+ :param only_desired: if True, only change the dimension specified in
470+ ``desired_values``
463471 :return: new inputs
464472
465473 Example:
@@ -483,7 +491,8 @@ def change_dynamic_dimensions(
483491 print("-after:", string_type(new_kwargs, with_shape=True))
484492 """
485493 return self ._generic_walker (
486- self .ChangeDimensionProcessor (desired_values ), args_kwargs = args_kwargs
494+ self .ChangeDimensionProcessor (desired_values , only_desired = only_desired ),
495+ args_kwargs = args_kwargs ,
487496 )
488497
489498
0 commit comments