Skip to content

Commit ca91445

Browse files
Add support for callable in torchax.interop.JittableModule.functional_call in the first parameter (#9451)
Co-authored-by: zmelumian <[email protected]>
1 parent a021bf0 commit ca91445

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

torchax/test/test_jittable_module.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,25 @@ def test_isinstance_does_not_mix(self):
3434
assert isinstance(JittableMoreAwesomeModel, EvenMoreAwesomeModel)
3535
assert not isinstance(JittableMoreAwesomeModel, MyAwesomeModel)
3636

37+
def test_functional_call_callable(self):
38+
39+
def outer_function(model, x):
40+
return x + 1
41+
42+
model = MyAwesomeModel()
43+
jittable_module = interop.JittableModule(model)
44+
45+
# Check if the jittable module can be called like a function
46+
input_tensor = torch.randn(1, 3, 224, 224)
47+
expected_output = input_tensor + 1
48+
49+
output = jittable_module.functional_call(outer_function,
50+
jittable_module.params,
51+
jittable_module.buffers,
52+
input_tensor)
53+
54+
assert torch.equal(output, expected_output)
55+
3756

3857
if __name__ == '__main__':
3958
unittest.main()

torchax/torchax/interop.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,26 @@ def __class__(self):
9090
def __call__(self, *args, **kwargs):
9191
return self.forward(*args, **kwargs)
9292

93-
def functional_call(self, method_name, params, buffers, *args, **kwargs):
93+
def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
9494
kwargs = kwargs or {}
9595
params_copy = copy.copy(params)
9696
params_copy.update(buffers)
9797
# reinflate the state dict so there are not any missing keys
9898
for k, v in self._extra_dumped_weights.items():
9999
for new_key in v:
100100
params_copy[new_key] = params_copy[k]
101+
102+
if isinstance(method_or_name, str):
103+
method = getattr(self._model, method_or_name)
104+
else:
105+
if not callable(method_or_name):
106+
raise TypeError(
107+
f"method_or_name should be a callable or a string, got {type(method_or_name)}"
108+
)
109+
method = method_or_name
110+
args = (self._model,) + args
101111
with torch_stateless._reparametrize_module(self._model, params_copy):
102-
res = getattr(self._model, method_name)(*args, **kwargs)
112+
res = method(*args, **kwargs)
103113
return res
104114

105115
def jittable_call(self, method_name: str, *args, **kwargs):

0 commit comments

Comments
 (0)