@@ -133,6 +133,19 @@ def __init__(
133133 self .output_dir = "."
134134 self ._saved_pte_filename = None
135135
def __post_init__(self):
    """
    Post-init hook: keep the ``get_max_seq_len`` metadata entry in sync
    with the maximum of the token dimension used for torch.export's
    dynamic shapes.

    NOTE(review): ``__post_init__`` is only invoked automatically for
    dataclasses; this chunk also shows a hand-written ``__init__``, so
    confirm the enclosing class is a dataclass or that this hook is
    called explicitly — otherwise it never runs.
    """
    shapes = self._get_dynamic_shape()
    if shapes is None:
        # Static-shape export: nothing to reconcile.
        return
    # The first input's dim-1 spec is the token dimension
    # (a torch.export.Dim carrying a .max bound).
    token_dim = shapes[0][1]
    if self.verbose:
        logging.info(
            f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
        )
    self.metadata["get_max_seq_len"] = token_dim.max
136149 def set_output_dir (self , output_dir : str ) -> "LLMEdgeManager" :
137150 """
138151 Set the directory where the .pte file will be saved.
@@ -180,14 +193,19 @@ def _get_dynamic_shape(self) -> Any:
180193 if self .dynamic_shapes :
181194 return self .dynamic_shapes
182195
183- dim = torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )
184196 if self .enable_dynamic_shape :
185197 if not self .use_kv_cache :
186198 # Only one input argument: tokens
187- self .dynamic_shapes = ({1 : dim },)
199+ # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
200+ self .dynamic_shapes = (
201+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )},
202+ )
188203 else :
189204 # Two input arguments: tokens and input_pos but input_pos is static shape
190- self .dynamic_shapes = ({1 : dim }, {"input_pos" : {0 : 1 }})
205+ self .dynamic_shapes = (
206+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len )},
207+ {"input_pos" : {0 : 1 }},
208+ )
191209 else :
192210 # Two input arguments: tokens and input_pos but both are of static shape
193211 self .dynamic_shapes = None
0 commit comments