Skip to content

Commit 16b1202

Browse files
authored
Clean up quantized matmul condition code (#9506)
1 parent ca47198 commit 16b1202

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,16 @@ def matmul_kernel(
7070
assert quantize_activation
7171
assert q_x_scratch is not None
7272
assert x_scale_scratch is not None
73-
quant = out_idx == 0
73+
quant = (out_idx == 0)
7474
else:
7575
assert q_x_scratch is None
7676
assert x_scale_scratch is None
7777
quant = quantize_activation
7878

7979
if save_acc:
8080
assert acc_scratch is not None
81-
is_first_step = in_idx == 0
82-
is_last_step = in_idx == n_in - 1
81+
is_first_step = (in_idx == 0)
82+
is_last_step = (in_idx == (n_in - 1))
8383
else:
8484
assert acc_scratch is None
8585
is_first_step = True

0 commit comments

Comments
 (0)