
Commit 2a9138a

mul: remove opmath cast sequence (#9663)

Remove the explicit opmath-driven cast chain (bf16→f32→bf16, etc.) from `mul`. The op now executes in the dtype chosen by standard dtype promotion, with no unconditional upcast/downcast steps inserted. The opmath machinery itself is left in place for future use.

1 parent 1ab6787 commit 2a9138a
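The promotion behavior the commit message refers to can be checked with plain eager PyTorch (a minimal sketch; no XLA device is needed, since the standard promotion rules are the same):

```python
import torch

# Standard dtype promotion, which `mul` now follows without extra casts:
# bf16 * bf16 stays bf16; bf16 * f32 promotes to f32.
a = torch.rand(5, 5, dtype=torch.bfloat16)
b = torch.rand(5, 5, dtype=torch.bfloat16)
c = torch.rand(5, 5, dtype=torch.float32)

print((a * b).dtype)  # torch.bfloat16
print((a * c).dtype)  # torch.float32
```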

File tree

2 files changed: +16 −1 lines changed

test/test_operations_hlo.py

Lines changed: 16 additions & 0 deletions

@@ -67,6 +67,22 @@ def test_dropout_by_u8_mask(self):
     hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b])
     assert 'u8' in hlo_text
 
+  def test_bfloat16_mul_not_upcast(self):
+    a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
+    b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
+    c = a * b
+    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
+    # Check that the output is not upcast to float32.
+    assert 'f32' not in hlo_text
+
+  def test_bfloat16_float32_mul_upcast(self):
+    a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
+    b = torch.rand(5, 5, dtype=torch.float32).to('xla')
+    c = a * b
+    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
+    # Check that the output is upcast to float32.
+    assert 'f32' in hlo_text
 
 
 if __name__ == '__main__':
   torch.set_default_dtype(torch.float32)

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 1 deletion

@@ -2535,7 +2535,6 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
       .add_input(self)
       .add_input(other)
       .cast_inputs_to_common_dtype()
-      .use_opmathtype_for_compute()
       .run();
 }
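For reference, the removed `use_opmathtype_for_compute()` step corresponds roughly to the pattern below (a hedged Python sketch of the opmath idea, not the actual C++ implementation; `mul_with_opmath` is a hypothetical name):

```python
import torch

def mul_with_opmath(a, b):
    # Sketch of the removed opmath path: reduced-precision floating-point
    # inputs are upcast to float32 for the compute, and the result is cast
    # back down to the promoted output dtype.
    out_dtype = torch.result_type(a, b)
    if out_dtype in (torch.bfloat16, torch.float16):
        return (a.to(torch.float32) * b.to(torch.float32)).to(out_dtype)
    return a * b

x = torch.rand(3, dtype=torch.bfloat16)
y = torch.rand(3, dtype=torch.bfloat16)
print(mul_with_opmath(x, y).dtype)  # torch.bfloat16 (compute ran in f32)
```

After this commit, `mul` skips the intermediate float32 round trip and computes directly in the promoted dtype, which is what the new `test_bfloat16_mul_not_upcast` test asserts at the HLO level.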
