@@ -90,16 +90,26 @@ def __class__(self):
90
90
def __call__ (self , * args , ** kwargs ):
91
91
return self .forward (* args , ** kwargs )
92
92
93
- def functional_call (self , method_name , params , buffers , * args , ** kwargs ):
93
+ def functional_call (self , method_or_name , params , buffers , * args , ** kwargs ):
94
94
kwargs = kwargs or {}
95
95
params_copy = copy .copy (params )
96
96
params_copy .update (buffers )
97
97
# reinflate the state dict so there are not any missing keys
98
98
for k , v in self ._extra_dumped_weights .items ():
99
99
for new_key in v :
100
100
params_copy [new_key ] = params_copy [k ]
101
+
102
+ if isinstance (method_or_name , str ):
103
+ method = getattr (self ._model , method_or_name )
104
+ else :
105
+ if not callable (method_or_name ):
106
+ raise TypeError (
107
+ f"method_or_name should be a callable or a string, got { type (method_or_name )} "
108
+ )
109
+ method = method_or_name
110
+ args = (self ._model ,) + args
101
111
with torch_stateless ._reparametrize_module (self ._model , params_copy ):
102
- res = getattr ( self . _model , method_name ) (* args , ** kwargs )
112
+ res = method (* args , ** kwargs )
103
113
return res
104
114
105
115
def jittable_call (self , method_name : str , * args , ** kwargs ):
0 commit comments