@@ -92,7 +92,8 @@ def replace_string_by(self, value: Any = None):
9292 return self ._generic_walker (
9393 lambda inputs , ds , value = value : self ._replace_string_dim_tensor (
9494 inputs , ds , value = value
95- )
95+ ),
96+ flatten_unflatten = True ,
9697 )
9798
9899 @classmethod
@@ -135,7 +136,8 @@ def replace_by_string(self):
135136 return self ._generic_walker (
136137 lambda inputs , ds , unique = unique : self ._replace_dim_tensor_by_string (
137138 inputs , ds , unique = unique
138- )
139+ ),
140+ flatten_unflatten = True ,
139141 )
140142
141143 @classmethod
@@ -203,7 +205,7 @@ def invalid_dimensions_for_export(self):
203205 ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
204206 print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
205207 """
206- return self ._generic_walker (self ._valid_shapes_tensor )
208+ return self ._generic_walker (self ._valid_shapes_tensor , flatten_unflatten = True )
207209
208210 @classmethod
209211 def _valid_shapes_tensor (cls , inputs , ds ):
@@ -221,7 +223,9 @@ def _valid_shapes_tensor(cls, inputs, ds):
221223 issues [i ] = f"d=[{ d } ]"
222224 return issues if issues else None
223225
224- def _generic_walker (self , processor : Callable , args_kwargs : bool = False ):
226+ def _generic_walker (
227+ self , processor : Callable , args_kwargs : bool = False , flatten_unflatten : bool = False
228+ ):
225229 """
226230 Generic deserializator walking through inputs and dynamic_shapes all along.
227231 The function returns a result with the same structure as the dynamic shapes.
@@ -231,15 +235,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
231235 f"Type mismatch, args={ string_type (self .args )} and "
232236 f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
233237 )
234- res = self ._generic_walker_step (processor , self .kwargs , self .dynamic_shapes )
238+ res = self ._generic_walker_step (
239+ processor ,
240+ self .kwargs ,
241+ self .dynamic_shapes ,
242+ flatten_unflatten = flatten_unflatten ,
243+ )
235244 return (tuple (), res ) if args_kwargs else res
236245
237246 if not self .kwargs :
238247 assert isinstance (self .args , tuple ) and isinstance (self .dynamic_shapes , tuple ), (
239248 f"Type mismatch, args={ string_type (self .args )} and "
240249 f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
241250 )
242- res = self ._generic_walker_step (processor , self .args , self .dynamic_shapes )
251+ res = self ._generic_walker_step (
252+ processor , self .args , self .dynamic_shapes , flatten_unflatten = flatten_unflatten
253+ )
243254 return (res , {}) if args_kwargs else res
244255
245256 assert isinstance (self .dynamic_shapes , dict ), (
@@ -250,12 +261,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
250261 self .dynamic_shapes
251262 ):
252263 # No dynamic shapes for the positional arguments.
253- return self ._generic_walker_step (processor , self .kwargs , self .dynamic_shapes )
264+ return self ._generic_walker_step (
265+ processor ,
266+ self .kwargs ,
267+ self .dynamic_shapes ,
268+ flatten_unflatten = flatten_unflatten ,
269+ )
254270
255271 if isinstance (self .args_names , list ):
256272 if not set (self .args_names ) & set (self .dynamic_shapes ):
257273 # No dynamic shapes for the positional arguments.
258- return self ._generic_walker_step (processor , self .kwargs , self .dynamic_shapes )
274+ return self ._generic_walker_step (
275+ processor ,
276+ self .kwargs ,
277+ self .dynamic_shapes ,
278+ flatten_unflatten = flatten_unflatten ,
279+ )
259280
260281 assert self .args_names , (
261282 "args and kwargs are filled, then args_names must be specified in "
@@ -268,7 +289,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
268289 )
269290 kwargs = dict (zip (self .args_names , self .args ))
270291 kwargs .update (self .kwargs )
271- res = self ._generic_walker_step (processor , kwargs , self .dynamic_shapes )
292+ res = self ._generic_walker_step (
293+ processor , kwargs , self .dynamic_shapes , flatten_unflatten = flatten_unflatten
294+ )
272295 if args_kwargs :
273296 pgs = [None for _ in range (len (self .args ))]
274297 kws = {}
@@ -286,7 +309,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
286309 )
287310
288311 @classmethod
289- def _generic_walker_step (cls , processor : Callable , inputs , ds ):
312+ def _generic_walker_step (
313+ cls , processor : Callable , inputs , ds , flatten_unflatten : bool = False
314+ ):
290315 if isinstance (inputs , torch .Tensor ):
291316 return processor (inputs , ds )
292317 if isinstance (inputs , (int , float , str )):
@@ -303,7 +328,11 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
303328 if isinstance (inputs , (tuple , list )):
304329 value = []
305330 for i , d in zip (inputs , ds ):
306- value .append (cls ._generic_walker_step (processor , i , d ))
331+ value .append (
332+ cls ._generic_walker_step (
333+ processor , i , d , flatten_unflatten = flatten_unflatten
334+ )
335+ )
307336 return (
308337 (value if isinstance (ds , list ) else tuple (value ))
309338 if any (v is not None for v in value )
@@ -314,7 +343,9 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
314343 ), f"Keys mismatch between inputs { set (inputs )} and ds={ set (ds )} "
315344 dvalue = {}
316345 for k , v in inputs .items ():
317- t = cls ._generic_walker_step (processor , v , ds [k ])
346+ t = cls ._generic_walker_step (
347+ processor , v , ds [k ], flatten_unflatten = flatten_unflatten
348+ )
318349 if t is not None :
319350 dvalue [k ] = t
320351 return dvalue if dvalue else None
@@ -325,11 +356,18 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
325356 f"torch.utils._pytree.register_pytree_node, it is not possible to "
326357 f"map this class with the given dynamic shapes."
327358 )
359+ if flatten_unflatten :
360+ flatunflat = flatten_unflatten_for_dynamic_shapes (inputs )
361+ return cls ._generic_walker_step (
362+ processor , flatunflat , ds , flatten_unflatten = flatten_unflatten
363+ )
328364 flat , _spec = torch .utils ._pytree .tree_flatten (inputs )
329365 if all (isinstance (t , torch .Tensor ) for t in flat ):
330366 # We need to flatten dynamic shapes as well
331367 ds = flatten_dynamic_shapes (ds )
332- return cls ._generic_walker_step (processor , flat , ds )
368+ return cls ._generic_walker_step (
369+ processor , flat , ds , flatten_unflatten = flatten_unflatten
370+ )
333371
334372 class ChangeDimensionProcessor :
335373 def __init__ (self , desired_values ):
0 commit comments