
Commit cf156c6

Remove the clamp op when we do symmetric quantization on a tensor (#9465)
1 parent 52569ec · commit cf156c6

File tree: 3 files changed, 7 additions and 10 deletions


test/quantized_ops/test_quantized_matmul.py

Lines changed: 5 additions & 4 deletions

@@ -123,7 +123,7 @@ def test_q_linear_module_per_channel(self, quantize_activation):
     x = x.to(device)
     out_quant_xla = m(x)
     self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.01))
-    self.assertTrue(torch.allclose(out_quant_xla.cpu(), out_quant))
+    self.assertTrue(torch.allclose(out_quant_xla.cpu(), out_quant, atol=2e-3))

   @parameterized.parameters([False, True])
   def test_q_linear_module_dynamo(self, quantize_activation):
@@ -139,7 +139,8 @@ def test_q_linear_module_dynamo(self, quantize_activation):
     m_dynamo = torch.compile(m, backend="openxla")
     out_quant_dynamo = m_dynamo(x.to(device))
     self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.02))
-    self.assertTrue(torch.allclose(out_quant_dynamo.cpu(), out_quant))
+    self.assertTrue(
+        torch.allclose(out_quant_dynamo.cpu(), out_quant, atol=4e-3))

   @parameterized.parameters([False, True])
   def test_q_linear_hlo(self, quantize_activation):
@@ -240,7 +241,7 @@ def test_blockwise_linear_module(self):
     x = x.to(device)
     out_quant_xla = m(x)
     self.assertGreater(
-        self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)
+        self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.99999)

   @parameterized.parameters([False, True])
   def test_asymmetric_per_channel(self, quantize_activation):
@@ -263,7 +264,7 @@ def test_asymmetric_per_channel(self, quantize_activation):
     x = x.to(device)
     out_quant_xla = m(x)
     self.assertGreater(
-        self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)
+        self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.99999)

   def test_asymmetric_blockwise(self):
     for n_bit in [8]:
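The two cosine-similarity thresholds above are loosened from 0.999999 to 0.99999, and the `torch.allclose` checks gain explicit absolute tolerances. For context, here is a minimal sketch of such a similarity check, assuming `_calc_cosine_dist` computes the cosine similarity of the two flattened outputs (the helper's definition is not part of this diff, so this is only an illustration):

```python
import torch

def calc_cosine_dist(a: torch.Tensor, b: torch.Tensor) -> float:
  # Flatten both tensors and compute their cosine similarity; values close
  # to 1.0 mean the two outputs are nearly parallel.
  a = a.flatten().float()
  b = b.flatten().float()
  return torch.nn.functional.cosine_similarity(a, b, dim=0).item()

# Mirrors the assertion pattern in the tests above (hypothetical tensors).
out_ref = torch.randn(4, 8)
out_xla = out_ref + 1e-4 * torch.randn(4, 8)  # small numeric deviation
assert calc_cosine_dist(out_xla, out_ref) > 0.99999
```

A threshold of 0.99999 still tolerates only a tiny angular deviation between the XLA output and the reference output.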

torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py

Lines changed: 1 addition & 4 deletions

@@ -12,15 +12,12 @@ def _quantize_array(
     x_abs_max_val: jax.Array,  # [1, bs_block_size]
 ):
   n_bits = 8
-  int_min = -2**(n_bits - 1)
   int_max = 2**(n_bits - 1) - 1
   scale = (x_abs_max_val / int_max).T  # [bs_block_size, 1]
   # Need to explicitly cast to f32 because Mosaic can't directly jnp.round a
   # bf16 array.
   # It seems x/0 in Pallas generates inf/-inf instead of an exception.
-  x_int = jnp.clip(
-      jnp.round((x / scale).astype(jnp.float32)), int_min,
-      int_max).astype(jnp.int8)
+  x_int = jnp.round((x / scale).astype(jnp.float32)).astype(jnp.int8)
   return x_int, scale.astype(x.dtype)
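The kernel drops the `jnp.clip` because, with a symmetric scale of abs-max / int_max, the rounded values cannot leave the signed 8-bit range; `int_min` was only used by the clip, which is why that constant is deleted as well. Below is a minimal standalone JAX sketch of that argument (not the Pallas kernel itself, which receives the per-block abs-max as an input; it assumes the abs-max is nonzero):

```python
import jax
import jax.numpy as jnp

def symmetric_quantize(x: jax.Array, n_bits: int = 8):
  int_max = 2**(n_bits - 1) - 1  # 127 for int8
  abs_max = jnp.max(jnp.abs(x))  # assumed nonzero
  scale = abs_max / int_max
  # |x| <= abs_max implies |x / scale| <= int_max, so the rounded values are
  # already inside [-int_max, int_max] and no jnp.clip is required.
  x_int = jnp.round(x / scale).astype(jnp.int8)
  return x_int, scale

x = jax.random.normal(jax.random.PRNGKey(0), (4, 128), dtype=jnp.float32)
x_int, scale = symmetric_quantize(x)
assert int(jnp.max(jnp.abs(x_int.astype(jnp.int32)))) <= 127
```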

torch_xla/experimental/xla_quantized_matmul.py

Lines changed: 1 addition & 2 deletions

@@ -67,10 +67,9 @@ def _quantize_tensor(x: torch.Tensor, n_bits: int = 8, dim: int = -1):
     torch.Tensor: The scaling factor used for quantization. (Same dtype as x)
   """
   max_val = torch.amax(torch.abs(x), dim=dim, keepdim=True)
-  int_min = -2**(n_bits - 1)
   int_max = 2**(n_bits - 1) - 1
   scale = max_val / int_max
-  x_int = torch.clamp(torch.round(x / scale), int_min, int_max).to(torch.int8)
+  x_int = torch.round(x / scale).to(torch.int8)
   return x_int, scale.to(x.dtype)
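The same reasoning applies to the PyTorch path: `max_val / int_max` already bounds the rounded values per channel, so `torch.clamp` and the now-unused `int_min` can go. A hedged sanity-check sketch; the `quantize_tensor` helper and the random weight below are illustrative, not part of the commit:

```python
import torch

def quantize_tensor(x: torch.Tensor, n_bits: int = 8, dim: int = -1):
  # Symmetric per-channel quantization along `dim`, mirroring the updated
  # _quantize_tensor above (sketch only; argument handling simplified).
  int_max = 2**(n_bits - 1) - 1
  max_val = torch.amax(torch.abs(x), dim=dim, keepdim=True)
  scale = max_val / int_max
  # |x / scale| <= int_max by construction, so rounding cannot leave the
  # signed n-bit range and the previous torch.clamp was a no-op.
  x_int = torch.round(x / scale).to(torch.int8)
  return x_int, scale.to(x.dtype)

w = torch.randn(16, 32)
w_int, w_scale = quantize_tensor(w)
assert w_int.abs().max().item() <= 127
# Dequantized values stay within one scale step of the original weights.
assert torch.allclose(w_int.float() * w_scale, w, atol=w_scale.max().item())
```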
