11import inspect
2- from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple , Union
2+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
33import numpy as np
44import torch
55from ..helpers import string_type
@@ -488,7 +488,7 @@ def __str__(self) -> str:
488488 ]
489489 )
490490
491- def invalid_paths (self ) -> Any :
491+ def invalid_paths (self ):
492492 """
493493 Tells the inputs are valid based on the dynamic shapes definition.
494494 The method assumes that all custom classes can be serialized.
@@ -501,7 +501,7 @@ def invalid_paths(self) -> Any:
501501 return self ._generic_walker (self ._valid_shapes_tensor )
502502
503503 @classmethod
504- def _valid_shapes_tensor (cls , inputs : Any , ds : Any ) -> Iterable :
504+ def _valid_shapes_tensor (cls , inputs , ds ) :
505505 assert isinstance (inputs , torch .Tensor ), f"unexpected type for inputs { type (inputs )} "
506506 assert isinstance (ds , dict ) and all (isinstance (s , int ) for s in ds ), (
507507 f"Unexpected types, inputs is a Tensor but ds is { ds } , "
@@ -516,7 +516,7 @@ def _valid_shapes_tensor(cls, inputs: Any, ds: Any) -> Iterable:
516516 issues [i ] = f"d=[{ d } ]"
517517 return issues if issues else None
518518
519- def _generic_walker (self , method_to_call : Callable ) -> Any :
519+ def _generic_walker (self , processor : Callable ):
520520 """
521521 Generic deserializator walking through inputs and dynamic_shapes all along.
522522 The function returns a result with the same structure as the dynamic shapes.
@@ -526,14 +526,14 @@ def _generic_walker(self, method_to_call: Callable) -> Any:
526526 f"Type mismatch, args={ string_type (self .args )} and "
527527 f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
528528 )
529- return self ._generic_walker_step (method_to_call , self .kwargs , self .dynamic_shapes )
529+ return self ._generic_walker_step (processor , self .kwargs , self .dynamic_shapes )
530530
531531 if not self .kwargs :
532532 assert isinstance (self .args , tuple ) and isinstance (self .dynamic_shapes , tuple ), (
533533 f"Type mismatch, args={ string_type (self .args )} and "
534534 f"dynamic_shapes={ self .dynamic_shapes } should have the same type."
535535 )
536- return self ._generic_walker_step (method_to_call , self .args , self .dynamic_shapes )
536+ return self ._generic_walker_step (processor , self .args , self .dynamic_shapes )
537537
538538 assert isinstance (self .dynamic_shapes , dict ), (
539539 f"Both positional and named arguments (args and kwargs) are filled. "
@@ -543,14 +543,12 @@ def _generic_walker(self, method_to_call: Callable) -> Any:
543543 self .dynamic_shapes
544544 ):
545545 # No dynamic shapes for the positional arguments.
546- return self ._generic_walker_step (method_to_call , self .kwargs , self .dynamic_shapes )
546+ return self ._generic_walker_step (processor , self .kwargs , self .dynamic_shapes )
547547
548548 if isinstance (self .args_names , list ):
549549 if not set (self .args_names ) & set (self .dynamic_shapes ):
550550 # No dynamic shapes for the positional arguments.
551- return self ._generic_walker_step (
552- method_to_call , self .kwargs , self .dynamic_shapes
553- )
551+ return self ._generic_walker_step (processor , self .kwargs , self .dynamic_shapes )
554552
555553 assert self .args_names , (
556554 "args and kwargs are filled, then args_names must be specified in "
@@ -563,17 +561,17 @@ def _generic_walker(self, method_to_call: Callable) -> Any:
563561 )
564562 kwargs = dict (zip (self .args_names , self .args ))
565563 kwargs .update (self .kwargs )
566- return self ._generic_walker_step (method_to_call , kwargs , self .dynamic_shapes )
564+ return self ._generic_walker_step (processor , kwargs , self .dynamic_shapes )
567565
568566 raise NotImplementedError (
569567 f"Not yet implemented when args is filled, "
570568 f"kwargs as well but args_names is { type (self .args_names )} "
571569 )
572570
573571 @classmethod
574- def _generic_walker_step (cls , method_to_call : Callable , inputs : Any , ds : Any ) -> Iterable :
572+ def _generic_walker_step (cls , processor : Callable , inputs , ds ) :
575573 if isinstance (inputs , torch .Tensor ):
576- return method_to_call (inputs , ds )
574+ return processor (inputs , ds )
577575 if isinstance (inputs , (int , float , str )):
578576 return None
579577 if isinstance (inputs , (tuple , list , dict )):
@@ -588,7 +586,7 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) ->
588586 if isinstance (inputs , (tuple , list )):
589587 value = []
590588 for i , d in zip (inputs , ds ):
591- value .append (cls ._generic_walker_step (method_to_call , i , d ))
589+ value .append (cls ._generic_walker_step (processor , i , d ))
592590 return (
593591 (value if isinstance (ds , list ) else tuple (value ))
594592 if any (v is not None for v in value )
@@ -599,7 +597,7 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) ->
599597 ), f"Keys mismatch between inputs { set (inputs )} and ds={ set (ds )} "
600598 dvalue = {}
601599 for k , v in inputs .items ():
602- t = cls ._generic_walker_step (method_to_call , v , ds [k ])
600+ t = cls ._generic_walker_step (processor , v , ds [k ])
603601 if t is not None :
604602 dvalue [k ] = t
605603 return dvalue if dvalue else None
@@ -611,4 +609,4 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) ->
611609 f"map this class with the given dynamic shapes."
612610 )
613611 flat , _spec = torch .utils ._pytree .tree_flatten (inputs )
614- return cls ._generic_walker_step (method_to_call , flat , ds )
612+ return cls ._generic_walker_step (processor , flat , ds )
0 commit comments