@@ -488,7 +488,7 @@ def __str__(self) -> str:
488488 ]
489489 )
490490
491- def invalid_paths (self ) -> List [ Union [ str , int ]] :
491+ def invalid_paths (self ) -> Any :
492492 """
493493 Tells the inputs are valid based on the dynamic shapes definition.
494494 The method assumes that all custom classes can be serialized.
@@ -498,18 +498,42 @@ def invalid_paths(self) -> List[Union[str, int]]:
498498 The function checks that a dynamic dimension does not receive a value
499499 of 0 or 1. It returns a list of invalid path.
500500 """
501+ return self ._generic_walker (self ._valid_shapes_tensor )
502+
503+ @classmethod
504+ def _valid_shapes_tensor (cls , inputs : Any , ds : Any ) -> Iterable :
505+ assert isinstance (inputs , torch .Tensor ), f"unexpected type for inputs { type (inputs )} "
506+ assert isinstance (ds , dict ) and all (isinstance (s , int ) for s in ds ), (
507+ f"Unexpected types, inputs is a Tensor but ds is { ds } , "
508+ f"a dictionary is expected to specify a dimension dimension"
509+ )
510+ issues = {}
511+ for i , d in enumerate (inputs .shape ):
512+ if i in ds and not isinstance (ds [i ], int ):
513+ # dynamic then
514+ if d in {0 , 1 }:
515+ # export issues for sure
516+ issues [i ] = f"d=[{ d } ]"
517+ return issues if issues else None
518+
519+ def _generic_walker (self , method_to_call : Callable ) -> Any :
520+ """
521+ Generic deserializator walking through inputs and dynamic_shapes all along.
522+ The function returns a result with the same structure as the dynamic shapes.
523+ """
501524 if not self .args :
502525 assert isinstance (self .kwargs , dict ) and isinstance (self .dynamic_shapes , dict ), (
503526 f"Type mismatch, args={ string_type (self .args )} and "
504527 f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
505528 )
506- return list (self ._valid_shapes (self .kwargs , self .dynamic_shapes ))
529+ return self ._generic_walker_step (method_to_call , self .kwargs , self .dynamic_shapes )
530+
507531 if not self .kwargs :
508532 assert isinstance (self .args , tuple ) and isinstance (self .dynamic_shapes , tuple ), (
509533 f"Type mismatch, args={ string_type (self .args )} and "
510534 f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
511535 )
512- return list ( self ._valid_shapes ( self .args , self .dynamic_shapes ) )
536+ return self ._generic_walker_step ( method_to_call , self .args , self .dynamic_shapes )
513537
514538 assert isinstance (self .dynamic_shapes , dict ), (
515539 f"Both positional and named arguments (args and kwargs) are filled. "
@@ -519,12 +543,14 @@ def invalid_paths(self) -> List[Union[str, int]]:
519543 self .dynamic_shapes
520544 ):
521545 # No dynamic shapes for the positional arguments.
522- return list ( self ._valid_shapes ( self .kwargs , self .dynamic_shapes ) )
546+ return self ._generic_walker_step ( method_to_call , self .kwargs , self .dynamic_shapes )
523547
524548 if isinstance (self .args_names , list ):
525549 if not set (self .args_names ) & set (self .dynamic_shapes ):
526550 # No dynamic shapes for the positional arguments.
527- return list (self ._valid_shapes (self .kwargs , self .dynamic_shapes ))
551+ return self ._generic_walker_step (
552+ method_to_call , self .kwargs , self .dynamic_shapes
553+ )
528554
529555 assert self .args_names , (
530556 "args and kwargs are filled, then args_names must be specified in "
@@ -537,62 +563,52 @@ def invalid_paths(self) -> List[Union[str, int]]:
537563 )
538564 kwargs = dict (zip (self .args_names , self .args ))
539565 kwargs .update (self .kwargs )
540- return list ( self ._valid_shapes ( kwargs , self .dynamic_shapes ) )
566+ return self ._generic_walker_step ( method_to_call , kwargs , self .dynamic_shapes )
541567
542568 raise NotImplementedError (
543569 f"Not yet implemented when args is filled, "
544570 f"kwargs as well but args_names is { type (self .args_names )} "
545571 )
546572
547573 @classmethod
548- def _valid_shapes (
549- cls , inputs : Any , ds : Any , prefix : Tuple [Union [int , str ], ...] = ()
550- ) -> Iterable :
551- assert all (isinstance (i , (int , str )) for i in prefix ), f"Unexpected prefix { prefix } "
574+ def _generic_walker_step (cls , method_to_call : Callable , inputs : Any , ds : Any ) -> Iterable :
552575 if isinstance (inputs , torch .Tensor ):
553- assert isinstance (ds , dict ) and all (
554- isinstance (s , int ) for s in ds
555- ), f"Unexpected types, inputs is a Tensor but ds={ ds } , prefix={ prefix } "
556- for i , d in enumerate (inputs .shape ):
557- if i in ds and not isinstance (ds [i ], int ):
558- # dynamic then
559- if d in {0 , 1 }:
560- # export issues for sure
561- yield (* prefix , f"[{ i } ]" )
562- else :
563- if isinstance (inputs , (int , float , str )):
564- pass
565- elif isinstance (inputs , (tuple , list , dict )):
566- assert type (ds ) is type (inputs ), (
567- f"Type mismatch between inputs { type (inputs )} "
568- f"and ds={ type (ds )} , prefix={ prefix !r} "
569- )
570- assert len (ds ) == len (inputs ), (
571- f"Length mismatch between inputs { len (inputs )} "
572- f"and ds={ len (ds )} , prefix={ prefix !r} \n "
573- f"inputs={ string_type (inputs , with_shape = True )} , ds={ ds } "
574- )
575- if isinstance (inputs , (tuple , list )):
576- for ind , (i , d ) in enumerate (zip (inputs , ds )):
577- for path in cls ._valid_shapes (i , d , prefix = (* prefix , ind )):
578- yield path
579- else :
580- assert set (inputs ) == set (ds ), (
581- f"Keys mismatch between inputs { set (inputs )} "
582- f"and ds={ set (ds )} , prefix={ prefix !r} "
583- )
584- for k , v in inputs .items ():
585- for path in cls ._valid_shapes (v , ds [k ], prefix = (* prefix , k )):
586- yield path
587- else :
588- # A custom class.
589- assert inputs .__class__ in torch .utils ._pytree .SUPPORTED_NODES , (
590- f"Class { inputs .__class__ .__name__ !r} was not registered using "
591- f"torch.utils._pytree.register_pytree_node, it is not possible to "
592- f"map this class with the given dynamic shapes."
576+ return method_to_call (inputs , ds )
577+ if isinstance (inputs , (int , float , str )):
578+ return None
579+ if isinstance (inputs , (tuple , list , dict )):
580+ assert type (ds ) is type (
581+ inputs
582+ ), f"Type mismatch between inputs { type (inputs )} and ds={ type (ds )} "
583+ assert len (ds ) == len (inputs ), (
584+ f"Length mismatch between inputs { len (inputs )} "
585+ f"and ds={ len (ds )} \n "
586+ f"inputs={ string_type (inputs , with_shape = True )} , ds={ ds } "
587+ )
588+ if isinstance (inputs , (tuple , list )):
589+ value = []
590+ for i , d in zip (inputs , ds ):
591+ value .append (cls ._generic_walker_step (method_to_call , i , d ))
592+ return (
593+ (value if isinstance (ds , list ) else tuple (value ))
594+ if any (v is not None for v in value )
595+ else None
593596 )
594- flat , _spec = torch .utils ._pytree .tree_flatten (inputs )
595- for path in cls ._valid_shapes (
596- flat , ds , prefix = (* prefix , inputs .__class__ .__name__ )
597- ):
598- yield path
597+ assert set (inputs ) == set (
598+ ds
599+ ), f"Keys mismatch between inputs { set (inputs )} and ds={ set (ds )} "
600+ dvalue = {}
601+ for k , v in inputs .items ():
602+ t = cls ._generic_walker_step (method_to_call , v , ds [k ])
603+ if t is not None :
604+ dvalue [k ] = t
605+ return dvalue if dvalue else None
606+
607+ # A custom class.
608+ assert inputs .__class__ in torch .utils ._pytree .SUPPORTED_NODES , (
609+ f"Class { inputs .__class__ .__name__ !r} was not registered using "
610+ f"torch.utils._pytree.register_pytree_node, it is not possible to "
611+ f"map this class with the given dynamic shapes."
612+ )
613+ flat , _spec = torch .utils ._pytree .tree_flatten (inputs )
614+ return cls ._generic_walker_step (method_to_call , flat , ds )
0 commit comments