Skip to content

Commit ebf9a8b

Browse files
Torchax: Allow reuse of the jittable procedure with different functio… (#9374)
Co-authored-by: zmelumian <[email protected]>
1 parent ca04583 commit ebf9a8b

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

torchax/torchax/interop.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,21 @@ def functional_call(self, method_name, params, buffers, *args, **kwargs):
102102
res = getattr(self._model, method_name)(*args, **kwargs)
103103
return res
104104

105-
def forward(self, *args, **kwargs):
106-
if 'forward' not in self._jitted:
105+
def jittable_call(self, method_name: str, *args, **kwargs):
106+
if method_name not in self._jitted:
107107
jitted = jax_jit(
108-
functools.partial(self.functional_call, 'forward'),
108+
functools.partial(self.functional_call, method_name),
109109
kwargs_for_jax_jit=self._extra_jit_args,
110110
)
111111

112112
def jitted_forward(*args, **kwargs):
113113
return jitted(self.params, self.buffers, *args, **kwargs)
114114

115-
self._jitted['forward'] = jitted_forward
116-
return self._jitted['forward'](*args, **kwargs)
115+
self._jitted[method_name] = jitted_forward
116+
return self._jitted[method_name](*args, **kwargs)
117+
118+
def forward(self, *args, **kwargs):
119+
return self.jittable_call('forward', *args, **kwargs)
117120

118121
def __getattr__(self, key):
119122
if key == '_model':

0 commit comments

Comments
 (0)