
Commit 3a1ed62

Split _ops and _decomps. (#9323)
1 parent 15366e9 commit 3a1ed62

Showing 8 changed files with 277 additions and 180 deletions.


torchax/test/test_libraries.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 import torchax.export
 from torchax.ops import jaten
 from torchax.ops import jlibrary
-
 # Create a `mylib` library which has a basic SDPA op.
 m = Library("mylib", "DEF")
 m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")

torchax/test/test_ops.py

Lines changed: 20 additions & 21 deletions
@@ -143,26 +143,26 @@ def run_export_and_compare(testcase,
         (sample_input.input, sample_input.args, sample_input.kwargs))
     with testcase.env:
       res2 = func(input2, *args2, **kwargs2)
-    res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
-    with testcase.subTest("torchax_diff:" + str(atol)):
-      if ignore_indices and isinstance(res, tuple) and len(res) == 2:
-        diff_output(
-            testcase,
-            res[0],
-            res2[0],
-            atol=atol,
-            rtol=rtol,
-            equal_nan=equal_nan,
-            check_output=check_output)
-      else:
-        diff_output(
-            testcase,
-            res,
-            res2,
-            atol=atol,
-            rtol=rtol,
-            equal_nan=equal_nan,
-            check_output=check_output)
+      res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
+      with testcase.subTest("torchax_diff:" + str(atol)):
+        if ignore_indices and isinstance(res, tuple) and len(res) == 2:
+          diff_output(
+              testcase,
+              res[0],
+              res2[0],
+              atol=atol,
+              rtol=rtol,
+              equal_nan=equal_nan,
+              check_output=check_output)
+        else:
+          diff_output(
+              testcase,
+              res,
+              res2,
+              atol=atol,
+              rtol=rtol,
+              equal_nan=equal_nan,
+              check_output=check_output)
 
 
 ops_to_test = [

@@ -188,7 +188,6 @@ def setUp(self):
     torchax.enable_accuracy_mode()
     #self.env.config.debug_accuracy_for_each_op = True
     self.env.config.debug_print_each_op = True
-    self.env.config.debug_print_each_op_operands = True
     torch.manual_seed(0)
     self.old_var = self.env.config.use_torch_native_for_cpu_tensor
     self.env.config.use_torch_native_for_cpu_tensor = False
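The moved block converts any torchax tensors in the result back to plain torch tensors before diffing against the reference. A minimal sketch of that conversion pattern, assuming the torchax environment turns factory calls into torchax tensors and with import paths mirroring the test's usage (the example values are illustrative):

import torch
import torch.utils._pytree as pytree
import torchax
from torchax import tensor

env = torchax.default_env()
with env:
  # Assumption: inside the environment, torch operations produce torchax tensors.
  res2 = {"out": torch.ones(2, 2), "scale": 1.0}

# Map only over torchax tensors, converting each back to a plain torch.Tensor;
# non-tensor leaves such as the float pass through unchanged.
res2_torch = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)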

torchax/test_dist/test_distributed.py

Lines changed: 5 additions & 2 deletions
@@ -14,6 +14,8 @@
 # TODO(wcromar): do something useful with group name
 GROUP_NAME = "process_group"
 
+torchax.enable_globally()
+
 
 @pytest.fixture(scope="module")
 def multi_cpu():

@@ -80,8 +82,9 @@ def test_all_reduce(op, expected, multi_cpu, process_group):
   device_count = multi_cpu
 
   def f(index):
-    dist.all_reduce(index, op)
-    return index
+    with torchax.default_env():
+      dist.all_reduce(index, op)
+      return index
 
   res = torchax.distributed.spawn(f)
 
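The fix wraps the collective in torchax.default_env() so the all-reduce runs under the torchax environment inside each worker. A standalone sketch of that pattern, assuming a process group is already initialized and ignoring the spawn plumbing (the reduce op is illustrative):

import torch
import torch.distributed as dist
import torchax

def f(index):
  # Run the collective under the torchax environment, as the fixed test does.
  with torchax.default_env():
    dist.all_reduce(index, op=dist.ReduceOp.SUM)
    return index

# In the test, this closure runs on every worker via torchax.distributed.spawn(f).

Presumably the module-level enable_globally() added above does not cover the spawned workers, which is why the context manager is needed inside f.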

torchax/torchax/decompositions.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def channel_shuffle(self, groups):
 
 
 def bernoulli_float(self, p=0.5):
-  return self.bernoulli_(torch.tensor(p))
+  return self.bernoulli_(p)
 
 
 _try_register(aten.bernoulli_.float, bernoulli_float)
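The one-line fix relies on Tensor.bernoulli_ accepting the probability as a plain Python float, so wrapping it in torch.tensor is unnecessary. A minimal sketch of the registered decomposition's behavior (the example tensor is illustrative):

import torch

def bernoulli_float(self, p=0.5):
  # In-place Bernoulli sampling; bernoulli_ takes the probability directly as a float.
  return self.bernoulli_(p)

x = torch.empty(4, 4)
bernoulli_float(x, p=0.3)  # fills x with 0s and 1s, each entry 1 with probability 0.3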
