@@ -102,18 +102,21 @@ def functional_call(self, method_name, params, buffers, *args, **kwargs):
102
102
res = getattr (self ._model , method_name )(* args , ** kwargs )
103
103
return res
104
104
105
- def forward (self , * args , ** kwargs ):
106
- if 'forward' not in self ._jitted :
105
+ def jittable_call (self , method_name : str , * args , ** kwargs ):
106
+ if method_name not in self ._jitted :
107
107
jitted = jax_jit (
108
- functools .partial (self .functional_call , 'forward' ),
108
+ functools .partial (self .functional_call , method_name ),
109
109
kwargs_for_jax_jit = self ._extra_jit_args ,
110
110
)
111
111
112
112
def jitted_forward (* args , ** kwargs ):
113
113
return jitted (self .params , self .buffers , * args , ** kwargs )
114
114
115
- self ._jitted ['forward' ] = jitted_forward
116
- return self ._jitted ['forward' ](* args , ** kwargs )
115
+ self ._jitted [method_name ] = jitted_forward
116
+ return self ._jitted [method_name ](* args , ** kwargs )
117
+
118
+ def forward (self , * args , ** kwargs ):
119
+ return self .jittable_call ('forward' , * args , ** kwargs )
117
120
118
121
def __getattr__ (self , key ):
119
122
if key == '_model' :
0 commit comments