11import inspect
2- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
2+ from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple , Union
33import numpy as np
44import torch
55from ..helpers import string_type
66
7+ DYNAMIC_SHAPES = Tuple [Tuple [Any , ...], Dict [str , Any ]]
8+
79
810class ModelInputs :
911 """
@@ -218,7 +220,7 @@ def process_inputs(
218220 return new_inputs
219221
220222 @property
221- def true_model_name (self ):
223+ def true_model_name (self ) -> str :
222224 "Returns class name or module name."
223225 return (
224226 self .model .__class__ .__name__
@@ -227,7 +229,7 @@ def true_model_name(self):
227229 )
228230
229231 @property
230- def full_name (self ):
232+ def full_name (self ) -> str :
231233 "Returns a name and class name."
232234 if self .method_name == "forward" :
233235 return f"{ self .name } :{ self .true_model_name } "
@@ -337,9 +339,7 @@ def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None)
337339 f"{ string_type (objs )} { msg () if msg else '' } in { self .module_name_type } "
338340 )
339341
340- def guess_dynamic_shapes (
341- self ,
342- ) -> Tuple [Tuple [Any , ...], Dict [str , Any ]]:
342+ def guess_dynamic_shapes (self ) -> DYNAMIC_SHAPES :
343343 """
344344 Guesses the dynamic shapes for that module from two execution.
345345 If there is only one execution, then that would be static dimensions.
@@ -386,7 +386,7 @@ def move_to_kwargs(
386386 args : Tuple [Any , ...],
387387 kwargs : Dict [str , Any ],
388388 dynamic_shapes : Tuple [Tuple [Any , ...], Dict [str , Any ]],
389- ) -> Tuple [Tuple [Any , ...], Dict [str , Any ], Tuple [ Tuple [ Any , ...], Dict [ str , Any ]] ]:
389+ ) -> Tuple [Tuple [Any , ...], Dict [str , Any ], DYNAMIC_SHAPES ]:
390390 """
391391 Uses the signatures to move positional arguments (args) to named arguments (kwargs)
392392 with the corresponding dynamic shapes.
@@ -434,3 +434,115 @@ def move_to_kwargs(
434434 f"forward_ordered_parameter_names={ self .forward_ordered_parameter_names } "
435435 )
436436 return args , kwargs , (tuple (), kw_dyn )
437+
438+ def validate_inputs_for_export (
439+ self , dynamic_shapes : Optional [DYNAMIC_SHAPES ] = None
440+ ) -> List [List [str ]]:
441+ """
442+ Validates the inputs the class contains for the given dynamic shapes.
443+ If not specified, the dynamic_shapes are guessed.
444+
445+ :param dynamic_shapes: dynamic shapes to validate
446+ :return: a list of lists, every list contains the path the invalid dimension
447+ """
448+ if dynamic_shapes is None :
449+ if len (self .inputs ) == 1 :
450+ return True
451+ dyn_shapes = self .guess_dynamic_shapes ()
452+ return [CoupleInputsDynamicShapes (* i , dyn_shapes ).invalid_paths () for i in self .inputs ]
453+
454+
455+ class CoupleInputsDynamicShapes :
456+ """
457+ Pair inputs / dynamic shapes.
458+ """
459+
460+ def __init__ (
461+ self , args : Tuple [Any , ...], kwargs : Dict [str , Any ], dynamic_shapes : DYNAMIC_SHAPES
462+ ):
463+ self .args = args
464+ self .kwargs = kwargs
465+ self .dynamic_shapes = dynamic_shapes
466+
467+ def __str__ (self ) -> str :
468+ return "\n " .join (
469+ [
470+ f"{ self .__class__ .__name__ } (" ,
471+ f" args={ string_type (self .args , with_shape = True )} ,"
472+ f" kwargs={ string_type (self .kwargs , with_shape = True )} ,"
473+ f" dynamic_shapes={ string_type (self .dynamic_shapes , with_shape = True )} ,"
474+ f")" ,
475+ ]
476+ )
477+
478+ def invalid_paths (self ) -> List [Union [str , int ]]:
479+ """
480+ Tells the inputs are valid based on the dynamic shapes definition.
481+ The method assumes that all custom classes can be serialized.
482+ If some patches were applied to export, they should enabled while
483+ calling this method if the inputs contains such classes.
484+
485+ The function checks that a dynamic dimension does not receive a value
486+ of 0 or 1. It returns a list of invalid path.
487+ """
488+ if not self .args :
489+ assert isinstance (self .kwargs , dict ) and isinstance (self .dynamic_shapes , dict ), (
490+ f"Type mismatch, args={ string_type (self .args )} and "
491+ f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
492+ )
493+ return list (self ._valid_shapes (self .kwargs , self .dynamic_shapes ))
494+ if not self .kwargs :
495+ assert isinstance (self .args , tuple ) and isinstance (self .dynamic_shapes , tuple ), (
496+ f"Type mismatch, args={ string_type (self .args )} and "
497+ f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
498+ )
499+ return list (self ._valid_shapes (self .args , self .dynamic_shapes ))
500+ raise NotImplementedError ("args and kwargs are filled, it is not implemented yet." )
501+
502+ @classmethod
503+ def _valid_shapes (
504+ cls , inputs : Any , ds : Any , prefix : Tuple [Union [int , str ], ...] = ()
505+ ) -> Iterable :
506+ assert all (isinstance (i , (int , str )) for i in prefix ), f"Unexpected prefix { prefix } "
507+ if isinstance (inputs , torch .Tensor ):
508+ assert isinstance (ds , dict ) and all (
509+ isinstance (s , int ) for s in ds
510+ ), f"Unexpected types, inputs is a Tensor but ds={ ds } , prefix={ prefix } "
511+ for i , d in enumerate (inputs .shape ):
512+ if i in ds and not isinstance (ds [i ], int ):
513+ # dynamic then
514+ if d in {0 , 1 }:
515+ # export issues for sure
516+ yield (* prefix , f"[{ i } ]" )
517+ else :
518+ if isinstance (inputs , (int , float , str )):
519+ pass
520+ elif isinstance (inputs , (tuple , list , dict )):
521+ assert type (ds ) is type (inputs ), (
522+ f"Type mismatch between inputs { type (inputs )} "
523+ f"and ds={ type (ds )} , prefix={ prefix !r} "
524+ )
525+ assert len (ds ) == len (inputs ), (
526+ f"Length mismatch between inputs { len (inputs )} "
527+ f"and ds={ len (ds )} , prefix={ prefix !r} \n "
528+ f"inputs={ string_type (inputs , with_shape = True )} , ds={ ds } "
529+ )
530+ if isinstance (inputs , (tuple , list )):
531+ for ind , (i , d ) in enumerate (zip (inputs , ds )):
532+ for path in cls ._valid_shapes (i , d , prefix = (* prefix , ind )):
533+ yield path
534+ else :
535+ assert set (inputs ) == set (ds ), (
536+ f"Keys mismatch between inputs { set (inputs )} "
537+ f"and ds={ set (ds )} , prefix={ prefix !r} "
538+ )
539+ for k , v in inputs .items ():
540+ for path in cls ._valid_shapes (v , ds [k ], prefix = (* prefix , k )):
541+ yield path
542+ else :
543+ # A custom class.
544+ flat , _spec = torch .utils ._pytree .tree_flatten (inputs )
545+ for path in cls ._valid_shapes (
546+ flat , ds , prefix = (* prefix , inputs .__class__ .__name__ )
547+ ):
548+ yield path
0 commit comments