@@ -96,7 +96,7 @@ def _infer_dynamic_dimensions(
9696class InputCandidate :
9797 """Represents a consistent set of inputs for the exported method."""
9898
99- def __init__ (self , args : list [Any ], kwargs : dict [str , Any ], cloned : bool ):
99+ def __init__ (self , args : tuple [Any , ... ], kwargs : dict [str , Any ], cloned : bool ):
100100 self .args = args
101101 self .kwargs = kwargs
102102 self .flat_list , self .spec = torch .utils ._pytree .tree_flatten ((args , kwargs ))
@@ -114,7 +114,7 @@ def __init__(self, args: list[Any], kwargs: dict[str, Any], cloned: bool):
114114 )
115115
116116 self .aligned_spec : torch .utils ._pytree .PyTreeSpec | None = None
117- self .aligned_flat_list : list [torch .Tensor | None ] = None
117+ self .aligned_flat_list : list [torch .Tensor | None ] | None = None
118118
119119 def __str__ (self ) -> str :
120120 return (
@@ -152,13 +152,17 @@ def position_to_args_kwargs(self) -> list[int | str]:
152152 """
153153 if self ._position_to_args_kwargs is None :
154154 self .build_mappings ()
155+ # the type checker cannot infer this
156+ assert self ._position_to_args_kwargs is not None
155157 return self ._position_to_args_kwargs
156158
157159 @property
158- def n_tensors_for_args_kwargs (self ) -> list [int | str ]:
160+ def n_tensors_for_args_kwargs (self ) -> dict [int | str , int ]:
159161 """Returns the number of flat tensors in every args or kwargs."""
160162 if self ._n_tensors_for_args_kwargs is None :
161163 self .build_mappings ()
164+ # the type checker cannot infer this
165+ assert self ._n_tensors_for_args_kwargs is not None
162166 return self ._n_tensors_for_args_kwargs
163167
164168 def _set_aligned_flat_list (
@@ -255,9 +259,7 @@ def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]):
255259 self .outputs_specs .append (spec )
256260 self .flat_outputs .append ([t .clone ().detach () for t in flat_res ])
257261
258- def align_inputs_none_values (
259- self ,
260- ) -> list [list [torch .Tensor ]]:
262+ def align_inputs_none_values (self ):
261263 """Once the best candidate is chosen, this method aligns every set of inputs
262264 on the best candidate, it inserts None at the right position when
263265 optional inputs are not specified. We consider a set of inputs is aligned
@@ -283,16 +285,23 @@ def align_inputs_none_values(
283285 candidate .align_with (self ._best_candidate , self ._captured_inputs )
284286
285287 def infer_dynamic_shapes (
286- self , set_batch_dimension_for : set [int | str ] | None = None
288+ self , set_batch_dimension_for : set [int | str ] | None = None , return_flat : bool = False
287289 ) -> tuple [dict [int , Any ], ...] | dict [str , dict [int , Any ]]:
288290 """Infers dynamic shapes based on the collected tensors.
289291 Most of the time, models do support a batch dimension
290292 but this batch dimension has the same value for every input sample.
291293 Instead of running inference on new samples, argument `set_batch_dimension_for`
292294 can be used to tell the first dimension is a dynamic dimension for a particular
293295 set of inputs referenced by their name (str) or their position (int).
296+
297+ `return_flat` tells the function to return a flat tuple instead of
298+ a nested structure.
294299 """
295300 self .align_inputs_none_values ()
301+ # type checking
302+ assert self ._best_candidate is not None
303+ assert self ._best_candidate .flat_list is not None
304+ assert self ._best_candidate .aligned_flat_list is not None
296305
297306 def _set_batch_dimension (name_or_position ):
298307 if not set_batch_dimension_for :
@@ -309,6 +318,8 @@ def _set_batch_dimension(name_or_position):
309318 return False
310319
311320 def _set_batch_dimension_for_flat_index (index ):
321+ # type checking
322+ assert self ._best_candidate is not None
312323 return _set_batch_dimension (self ._best_candidate .position_to_args_kwargs [index ])
313324
314325 if len (self ._best_candidate .flat_list ) != len (self ._best_candidate .aligned_flat_list ):
@@ -329,6 +340,7 @@ def _set_batch_dimension_for_flat_index(index):
329340 shape_lists = [
330341 [(None if t is None else t .shape ) for t in candidate .aligned_flat_list ]
331342 for candidate in self .inputs
343+ if candidate .aligned_flat_list is not None
332344 ]
333345 n_tensors = len (shape_lists [0 ])
334346 dynamic_shapes = [
@@ -340,6 +352,8 @@ def _set_batch_dimension_for_flat_index(index):
340352 ]
341353 cst = torch .export .Dim .DYNAMIC
342354 flat_dynamic_shapes = [dict .fromkeys (dims , cst ) for dims in dynamic_shapes ]
355+ if return_flat :
356+ return tuple (flat_dynamic_shapes )
343357 if len (flat_dynamic_shapes ) == len (self ._best_candidate .args ) + len (
344358 self ._best_candidate .kwargs
345359 ):
@@ -391,10 +405,9 @@ def infer_arguments(
391405 """Infers arguments based on the collected tensors."""
392406 # This is already checked by _build_inputs_completed_with_none_values
393407 # but this is not always well captured by tools checking types.
394- torch ._check (
395- self ._best_candidate .args is not None and self ._best_candidate .kwargs is not None ,
396- lambda : "No input was captured." ,
397- )
408+ torch ._check (self ._best_candidate is not None , lambda : "No input was captured." )
409+ # type checking
410+ assert self ._best_candidate is not None
398411 candidate = None
399412 if index is None :
400413 for cand in self .inputs :
@@ -412,16 +425,25 @@ def infer_arguments(
412425 candidate = self .inputs [index ]
413426
414427 torch ._check (candidate is not None , "No input was captured." )
428+ # type checking
429+ assert candidate is not None
430+ assert candidate .aligned_flat_list is not None
415431
416432 aligned_flat_list = candidate .aligned_flat_list
417433 if any (t is None for t in aligned_flat_list ):
418- dynamic_shapes = self .infer_dynamic_shapes ()
434+ dynamic_shapes = self .infer_dynamic_shapes (return_flat = True )
435+ # type checking
436+ assert isinstance (dynamic_shapes , tuple )
419437 aligned_flat_list = aligned_flat_list .copy ()
420438 for index in range (len (aligned_flat_list )):
421439 if aligned_flat_list [index ] is not None :
422440 continue
423441 shape = dynamic_shapes [index ]
424- all_non_empty_tensors = [c .aligned_flat_list [index ] for c in self .inputs ]
442+ all_non_empty_tensors = [
443+ c .aligned_flat_list [index ]
444+ for c in self .inputs
445+ if c .aligned_flat_list is not None
446+ ]
425447 all_non_empty_tensors = [t for t in all_non_empty_tensors if t is not None ]
426448 if not all_non_empty_tensors :
427449 raise RuntimeError (
@@ -444,6 +466,9 @@ def infer_arguments(
444466 aligned_flat_list [index ] = torch .empty (
445467 tuple (new_shape ), dtype = tensor .dtype , device = tensor .device
446468 )
469+ # type checking
470+ assert candidate is not None
471+ assert candidate .aligned_spec is not None
447472 args , kwargs = torch .utils ._pytree .tree_unflatten (
448473 aligned_flat_list , candidate .aligned_spec
449474 )
0 commit comments