@@ -723,13 +723,14 @@ def module_name_type(self):
723723 return f"type({ self .name } )={ self .true_model_name } .{ self .method_name } "
724724
725725 def guess_dynamic_dimensions (
726- self , * tensors , auto : bool = False
726+ self , * tensors , auto : Union [ bool , str ] = False
727727 ) -> Optional [Dict [int , Any ]]:
728728 """
729729 Infers the dynamic dimension from multiple shapes.
730730 If auto is True, it returns ``torch.export.Dim.AUTO`` for every dimension
731731 which cannot be guessed. Two tensors with the same value for one dimension
732- can be guessed, but if there is only 1, it cannot.
732+ can be guessed, but if there is only 1, it cannot. ``auto``` can be a string
733+ to produce strings.
733734 """
734735 if len (tensors ) == 1 :
735736 if isinstance (tensors [0 ], (int , float )):
@@ -740,7 +741,7 @@ def guess_dynamic_dimensions(
740741 )
741742 return (
742743 {i : torch .export .Dim .AUTO for i in range (len (tensors [0 ].shape ))} # noqa: C420
743- if auto
744+ if auto and not isinstance ( auto , str )
744745 else {}
745746 )
746747 shapes = [t .shape for t in tensors ]
@@ -750,22 +751,26 @@ def guess_dynamic_dimensions(
750751 f"shapes={ shapes } for module { self .name !r} , "
751752 f"class={ self .true_model_name !r} "
752753 )
753- dynamic : Any = torch .export .Dim .DYNAMIC # type: ignore
754+ dynamic : Any = (
755+ auto
756+ if isinstance (auto , str )
757+ else (torch .export .Dim .AUTO if auto else torch .export .Dim .DYNAMIC )
758+ )
754759 rk = set_length .pop ()
755760 res = {}
756761 for i in range (rk ):
757762 set_dim = set (s [i ] for s in shapes )
758763 if len (set_dim ) > 1 :
759- res [i ] = dynamic
764+ res [i ] = dynamic if not isinstance ( dynamic , str ) else f" { dynamic } { i } "
760765 continue
761766 if set_dim == {0 }:
762767 # It is unexpected to find a null dimension. Let's replace it by a dynamic one.
763- res [i ] = dynamic
768+ res [i ] = dynamic if not isinstance ( dynamic , str ) else f" { dynamic } { i } "
764769 continue
765770 return res
766771
767772 def guess_dynamic_shape_object (
768- self , * objs : Any , auto : bool = False , msg : Optional [Callable ] = None
773+ self , * objs : Any , auto : Union [ bool , str ] = False , msg : Optional [Callable ] = None
769774 ) -> Any :
770775 """Guesses the dynamic shapes for one argument."""
771776 if len (objs ) == 0 :
@@ -790,7 +795,11 @@ def guess_dynamic_shape_object(
790795 shapes : Any = []
791796 for i in range (kl .pop ()):
792797 shapes .append (
793- self .guess_dynamic_shape_object (* [o [i ] for o in objs ], auto = auto , msg = msg )
798+ self .guess_dynamic_shape_object (
799+ * [o [i ] for o in objs ],
800+ auto = auto if isinstance (auto , bool ) else f"{ auto } _{ i } t" ,
801+ msg = msg ,
802+ )
794803 )
795804 return tuple (shapes )
796805
@@ -802,7 +811,11 @@ def guess_dynamic_shape_object(
802811 shapes = []
803812 for i in range (kl .pop ()):
804813 shapes .append (
805- self .guess_dynamic_shape_object (* [o [i ] for o in objs ], auto = auto , msg = msg )
814+ self .guess_dynamic_shape_object (
815+ * [o [i ] for o in objs ],
816+ auto = auto if isinstance (auto , bool ) else f"{ auto } _{ i } l" ,
817+ msg = msg ,
818+ )
806819 )
807820 return shapes
808821
@@ -814,7 +827,9 @@ def guess_dynamic_shape_object(
814827 shapes = {}
815828 for i in obj :
816829 shapes [i ] = self .guess_dynamic_shape_object (
817- * [o [i ] for o in objs ], auto = auto , msg = msg
830+ * [o [i ] for o in objs ],
831+ auto = auto if isinstance (auto , bool ) else f"{ auto } _{ i } d" ,
832+ msg = msg ,
818833 )
819834 return shapes
820835
@@ -834,7 +849,9 @@ def guess_dynamic_shape_object(
834849 for i in range (kc .pop ()):
835850 values .append (
836851 self .guess_dynamic_shape_object (
837- * [ca [i ] for ca in col_args ], auto = auto , msg = msg
852+ * [ca [i ] for ca in col_args ],
853+ auto = auto if isinstance (auto , bool ) else f"{ auto } _{ i } o" ,
854+ msg = msg ,
838855 )
839856 )
840857 return values
@@ -852,12 +869,18 @@ def guess_dynamic_shape_object(
852869 key_cache = []
853870 for i in range (kc .pop ()):
854871 key_cache .append (
855- self .guess_dynamic_dimensions (* [o .key_cache [i ] for o in objs ], auto = auto )
872+ self .guess_dynamic_dimensions (
873+ * [o .key_cache [i ] for o in objs ],
874+ auto = auto if isinstance (auto , bool ) else f"{ auto } _{ i } kdc" ,
875+ )
856876 )
857877 value_cache = []
858878 for i in range (vc .pop ()):
859879 value_cache .append (
860- self .guess_dynamic_dimensions (* [o .value_cache [i ] for o in objs ], auto = auto )
880+ self .guess_dynamic_dimensions (
881+ * [o .value_cache [i ] for o in objs ],
882+ auto = auto if isinstance (auto , bool ) else f"{ auto } _{ i } vdc" ,
883+ )
861884 )
862885 return [key_cache , value_cache ]
863886
@@ -867,13 +890,17 @@ def guess_dynamic_shape_object(
867890 f"this object needs serialization function to be registered."
868891 )
869892
870- def guess_dynamic_shapes (self , auto : bool = False ) -> DYNAMIC_SHAPES :
893+ def guess_dynamic_shapes (self , auto : Union [ bool , str ] = False ) -> DYNAMIC_SHAPES :
871894 """
872895 Guesses the dynamic shapes for that module from two execution.
873896 If there is only one execution, then that would be static dimensions.
874897
875898 :param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
876- dimension if the number of inputs is one
899+ dimension if the number of inputs is one,
900+ if ``auto`` is a string, it uses strings
901+ :return: guessed dynamic shapes
902+
903+ See example :ref:`l-guess-dynamic-shapes-example`.
877904 """
878905 if len (self .inputs ) == 0 :
879906 # No inputs, unable to guess.
@@ -900,7 +927,9 @@ def guess_dynamic_shapes(self, auto: bool = False) -> DYNAMIC_SHAPES:
900927 objs = [_ [0 ][i ] for _ in self .inputs ]
901928 args .append (
902929 self .guess_dynamic_shape_object (
903- * objs , auto = auto , msg = lambda i = i : f" failing input { i } "
930+ * objs ,
931+ auto = auto if isinstance (auto , bool ) else f"{ auto } _{ i } I" ,
932+ msg = lambda i = i : f" failing input { i } " ,
904933 )
905934 )
906935 names = s2 .pop ()
@@ -913,7 +942,9 @@ def guess_dynamic_shapes(self, auto: bool = False) -> DYNAMIC_SHAPES:
913942
914943 objs = [_ [1 ][name ] for _ in self .inputs ]
915944 kwargs [name ] = self .guess_dynamic_shape_object (
916- * objs , auto = auto , msg = lambda name = name : f" failing input { name !r} "
945+ * objs ,
946+ auto = auto if isinstance (auto , bool ) else f"{ auto } _{ i } I" ,
947+ msg = lambda name = name : f" failing input { name !r} " ,
917948 )
918949 return tuple (args ), kwargs
919950
0 commit comments