11import pprint
2- from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
2+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple
33import packaging .version as pv
44import optree
55import torch
@@ -133,7 +133,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
133133 f"registered first"
134134 )
135135 unregister_class_serialization (cls , verbose = verbose )
136- registration_functions [cls ](verbose = verbose )
136+ registration_functions [cls ](verbose = verbose ) # type: ignore[arg-type]
137137 if verbose :
138138 print (f"[_fix_registration] { cls .__name__ } done." )
139139 # To avoid doing it multiple times.
@@ -142,11 +142,11 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
142142 # classes with no registration at all.
143143 done = {}
144144 for k , v in registration_functions .items ():
145- done [k ] = v (verbose = verbose )
145+ done [k ] = v (verbose = verbose ) # type: ignore[arg-type]
146146 return done
147147
148148
149- def serialization_functions (verbose : int = 0 ) -> Dict [type , Union [ Callable [[int ], bool ], int ]]:
149+ def serialization_functions (verbose : int = 0 ) -> Dict [type , Callable [[int ], bool ]]:
150150 """Returns the list of serialization functions."""
151151 transformers_classes = {
152152 DynamicCache : lambda verbose = verbose : register_class_serialization (
0 commit comments