
Commit 36ff641

Optimize w8a8 pallas kernel (#9473)
1 parent ebf9a8b commit 36ff641

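For orientation, here is a pure-JAX sketch (an editor's illustration, not part of the commit) of the w8a8 math this kernel implements: int8 weights with a per-output-channel scale, activations dynamically quantized to int8 with a per-row scale, and both scales applied to the int32 accumulator at the end.

import jax
import jax.numpy as jnp

def w8a8_matmul_reference(x, q_w, w_scalar):
  # x: (bs, in) float; q_w: (out, in) int8; w_scalar: (1, out) float.
  x_abs_max = jnp.max(jnp.abs(x), axis=1, keepdims=True)  # (bs, 1)
  x_scale = (x_abs_max / 127.0).astype(jnp.float32)  # 127 = int8 max
  q_x = jnp.round(x / x_scale).astype(jnp.int8)  # dynamic per-row quantization
  acc = jax.lax.dot_general(
      q_x,
      q_w,
      (((1,), (1,)), ((), ())),  # contract over the shared `in` dimension
      preferred_element_type=jnp.int32,
  )
  return (acc * w_scalar * x_scale).astype(x.dtype)  # (bs, out)

The diff below computes the same quantity, tiled over a (batch, out, in) grid of blocks.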
File tree

1 file changed: 169 additions, 107 deletions

torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py

Lines changed: 169 additions & 107 deletions
@@ -14,69 +14,116 @@ def _quantize_array(
   n_bits = 8
   int_max = 2**(n_bits - 1) - 1
   scale = (x_abs_max_val / int_max).T  # [bs_block_size, 1]
-  # Need to explicitly cast to f32 because Mosaic can't directly jnp.round a
-  # bf16 array.
-  # It seems x/0 in Pallas generates inf/-inf instead of an exception.
-  x_int = jnp.round((x / scale).astype(jnp.float32)).astype(jnp.int8)
-  return x_int, scale.astype(x.dtype)
+  x_int = jnp.round(x / scale).astype(jnp.int8)
+  return x_int, scale.astype(jnp.float32)
+
+
+def unfold_args(args: tuple[jax.Array | bool, ...], fn_args: tuple[bool, ...],
+                fn):
+  if len(args) == 0:
+    fn(*fn_args)
+  else:
+    arg = args[0]
+    if isinstance(arg, bool):
+      unfold_args(args[1:], fn_args + (arg,), fn)
+    else:
+      assert arg.dtype == jnp.bool and arg.size == 1
+      lax.cond(
+          arg,
+          lambda: unfold_args(args[1:], fn_args + (True,), fn),
+          lambda: unfold_args(args[1:], fn_args + (False,), fn),
+      )
 
 
 def matmul_kernel(
-    x_ref,  # (batch_block_size, in_block_size)
-    w_ref,  # (out_block_size, in_block_size)
-    scalar_ref,  # (1, out_block_size)
-    x_abs_max_val,  # (1, batch_block_size)
-    out_ref,  # (batch_block_size, out_block_size)
-    acc_ref,  # (batch_block_size, out_block_size)
+    x_ref: jax.Array,  # (batch_block_size, in_block_size)
+    w_ref: jax.Array,  # (out_block_size, in_block_size)
+    scalar_ref: jax.Array,  # (1, out_block_size)
+    x_abs_max_ref: jax.Array,  # (1, batch_block_size)
+    out_ref: jax.Array,  # (batch_block_size, out_block_size)
+    acc_scratch: jax.Array,  # (batch_block_size, out_block_size)
+    q_x_scratch: jax.Array,  # (batch_block_size, in_block_size)
+    x_scale_scratch: jax.Array,  # (batch_block_size, 1)
     *,
-    quantize_activation,
-    batch_block_size,
-    out_block_size,
-    in_block_size,
+    quantize_activation: bool,
+    save_acc: bool,
+    save_q_x: bool,
+    batch_block_size: int,
+    out_block_size: int,
+    in_block_size: int,
 ):
   bs_idx, out_idx, in_idx = pl.program_id(0), pl.program_id(1), pl.program_id(2)
-  nsteps = pl.num_programs(2)
+  n_in = pl.num_programs(2)
   x_ref_dtype = x_ref.dtype
   assert x_ref.shape == (batch_block_size,
                          in_block_size), "x_ref shape is not correct"
   assert w_ref.shape == (out_block_size,
                          in_block_size), "w_ref shape is not correct"
   assert scalar_ref.shape == (1,
                               out_block_size), "scalar_ref shape is not correct"
-  assert x_abs_max_val.shape == (
+  assert x_abs_max_ref.shape == (
       1, batch_block_size), "x_max_val shape is not correct"
   assert out_ref.shape == (batch_block_size,
                            out_block_size), "out_ref shape is not correct"
-  assert acc_ref.shape == (batch_block_size,
-                           out_block_size), "acc_ref shape is not correct"
-
-  @pl.when(in_idx == 0)
-  def _():
-    acc_ref[...] = jnp.zeros_like(acc_ref)
-
-  if quantize_activation:
-    x, x_scale = _quantize_array(x_ref[...], x_abs_max_val[...])
-    acc_ref[...] += jax.lax.dot_general(
-        x,
-        w_ref[...],
-        (((1,), (1,)), ((), ())),
-        preferred_element_type=jnp.int32,
-    )
+
+  if save_q_x:
+    assert quantize_activation
+    assert q_x_scratch is not None
+    assert x_scale_scratch is not None
+    quant = out_idx == 0
   else:
-    acc_ref[...] += jax.lax.dot_general(
-        x_ref[...],
-        w_ref[...],
-        (((1,), (1,)), ((), ())),
-    )
-
-  @pl.when(in_idx == nsteps - 1)
-  def _():
-    acc = acc_ref[...]
-    scalar = scalar_ref[...]
-    acc *= scalar
+    assert q_x_scratch is None
+    assert x_scale_scratch is None
+    quant = quantize_activation
+
+  if save_acc:
+    assert acc_scratch is not None
+    is_first_step = in_idx == 0
+    is_last_step = in_idx == n_in - 1
+  else:
+    assert acc_scratch is None
+    is_first_step = True
+    is_last_step = True
+
+  def matmul_body(quant, is_first_step, is_last_step):
     if quantize_activation:
-      acc *= x_scale
-    out_ref[...] = acc.astype(x_ref_dtype)
+      if quant:
+        q_x_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...])
+        if save_q_x:
+          q_x_scratch[...] = q_x_tmp
+          x_scale_scratch[...] = x_scale_tmp
+      else:
+        assert save_q_x
+        q_x_tmp = q_x_scratch[...]
+        if is_last_step:
+          x_scale_tmp = x_scale_scratch[...]
+
+      acc = jax.lax.dot_general(
+          q_x_tmp,
+          w_ref[...],
+          (((1,), (1,)), ((), ())),
+          preferred_element_type=jnp.int32,
+      )
+    else:
+      acc = jax.lax.dot_general(
+          x_ref[...],
+          w_ref[...],
+          (((1,), (1,)), ((), ())),
+      )
+
+    if not is_first_step:
+      acc += acc_scratch[...]
+
+    if is_last_step:
+      acc *= scalar_ref[...]
+      if quantize_activation:
+        acc *= x_scale_tmp
+      out_ref[...] = acc.astype(x_ref_dtype)
+    else:
+      assert save_acc
+      acc_scratch[...] = acc
+
+  unfold_args((quant, is_first_step, is_last_step), (), matmul_body)
 
 
 def _next_multiple(x, multiple):
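The new unfold_args helper is what lets matmul_body in the hunk above be written with plain Python `if`s: flags that are compile-time bools are threaded through unchanged, while traced boolean scalars (such as quant = out_idx == 0) are expanded into a lax.cond whose branches re-enter the recursion with the flag pinned to True or False, so the body is specialized for every reachable combination. A minimal standalone sketch of the same pattern (body and run are illustrative names, not from the commit):

import jax
import jax.numpy as jnp
from jax import lax

def unfold_args(args, fn_args, fn):
  # Python bools are baked in at trace time; traced bools fork via lax.cond.
  if len(args) == 0:
    fn(*fn_args)
  else:
    arg = args[0]
    if isinstance(arg, bool):
      unfold_args(args[1:], fn_args + (arg,), fn)
    else:
      lax.cond(
          arg,
          lambda: unfold_args(args[1:], fn_args + (True,), fn),
          lambda: unfold_args(args[1:], fn_args + (False,), fn),
      )

def body(quant, is_last_step):
  # Both arguments are plain Python bools here, even though is_last_step
  # was a traced value at the call site.
  if quant and is_last_step:
    jax.debug.print("quantized, last step")

@jax.jit
def run(is_last_step):
  # True is static; is_last_step is traced, so exactly one lax.cond is emitted.
  unfold_args((True, is_last_step), (), body)

run(jnp.array(True))  # prints: quantized, last step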
@@ -159,10 +206,22 @@ def quantized_matmul_int8(
   # Within the kernel, it will use some extra VMEM for computation or vreg spills.
   vmem_used = vmem_to_be_transferred * 2
   vmem_limit_bytes = min(vmem_used * 2, 96 * 1024 * 1024)
+
+  n_bs = padded_bs // batch_block_size
+  n_out = padded_out_features // out_block_size
+  n_in = padded_in_features // in_block_size
+
+  save_acc = n_in > 1
+  # Remove redundant input quantization logic by caching the quantized input.
+  # For best performance, only enable this behavior when a single input block
+  # is used per batch.
+  save_q_x = quantize_activation and n_in == 1 and n_out > 1
+
   kernel = pl.pallas_call(
       functools.partial(
           matmul_kernel,
           quantize_activation=quantize_activation,
+          save_acc=save_acc,
+          save_q_x=save_q_x,
           batch_block_size=batch_block_size,
           out_block_size=out_block_size,
           in_block_size=in_block_size),
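Two worked examples of the save_acc / save_q_x selection above, using retuned entries from the TUNED_BLOCK_SIZES table at the end of this diff:

quantize_activation = True

# bs=512, out=4096, in=14336, tuned blocks (512, 256, 14336):
n_out, n_in = 4096 // 256, 14336 // 14336  # 16, 1
save_acc = n_in > 1  # False: each out block finishes in a single grid step
save_q_x = quantize_activation and n_in == 1 and n_out > 1  # True: quantize
# the input once and reuse it for all 16 out blocks.

# bs=128, out=28672, in=4096, tuned blocks (128, 28672, 256):
n_out, n_in = 28672 // 28672, 4096 // 256  # 1, 16
save_acc = n_in > 1  # True: partial sums live in acc_scratch across 16 steps
save_q_x = quantize_activation and n_in == 1 and n_out > 1  # False: quantize
# in-flight; there is only one out block, so there is nothing to reuse.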
@@ -181,15 +240,18 @@ def quantized_matmul_int8(
           out_specs=pl.BlockSpec((batch_block_size, out_block_size),
                                  lambda b, o, i: (b, o)),
           scratch_shapes=[
-              pltpu.VMEM((batch_block_size, out_block_size), acc_dtype)
+              pltpu.VMEM((batch_block_size,
+                          out_block_size), acc_dtype) if save_acc else None,
+              pltpu.VMEM((batch_block_size,
+                          in_block_size), jnp.int8) if save_q_x else None,
+              pltpu.VMEM(
+                  (batch_block_size, 1), jnp.float32) if save_q_x else None,
           ],
-          grid=(padded_bs // batch_block_size,
-                padded_out_features // out_block_size,
-                padded_in_features // in_block_size),
+          grid=(n_bs, n_out, n_in),
       ),
       out_shape=jax.ShapeDtypeStruct((padded_bs, padded_out_features), x.dtype),
       compiler_params=pltpu.TPUCompilerParams(
-          dimension_semantics=("parallel", "parallel", "arbitrary"),
+          dimension_semantics=("parallel", "arbitrary", "arbitrary"),
           vmem_limit_bytes=vmem_limit_bytes,
       ),
   )
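The scratch buffers above arrive as trailing kernel arguments, with None entries passed through as None (matching the asserts in matmul_kernel). Note also that the out axis drops from "parallel" to "arbitrary": with save_q_x, the grid step at out_idx == 0 fills the q_x cache that every later out block reads, so iterations along that axis are no longer independent. A plain-Python mock of the resulting loop nest, for intuition only:

n_bs, n_out, n_in = 1, 4, 1
cache = {}
for bs_idx in range(n_bs):  # "parallel": batch blocks are independent
  for out_idx in range(n_out):  # "arbitrary": out_idx == 0 must run first
    for in_idx in range(n_in):  # "arbitrary": sequential accumulation
      if out_idx == 0:
        cache["q_x"] = f"quantized x for batch block {bs_idx}"
      assert "q_x" in cache  # later out blocks reuse the cached activations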
@@ -217,70 +279,70 @@ def quantized_matmul_int8(
 # - out_block_size
 # - in_block_size
 TUNED_BLOCK_SIZES = {
-    (6, 1024, 1280, 8192, 'bfloat16', True): (1024, 1280, 2048),
-    (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 3584, 4096),
-    (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 4096, 2048),
-    (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 1024, 4096),
-    (6, 1024, 6144, 4096, 'bfloat16', True): (1024, 1536, 4096),
-    (6, 1024, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
-    (6, 1024, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
-    (6, 1024, 8192, 3584, 'bfloat16', True): (1024, 2048, 3584),
-    (6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
-    (6, 128, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 128, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
-    (6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
+    (6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
+    (6, 2048, 6144, 4096, 'bfloat16', True): (2048, 512, 4096),
+    (6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096),
+    (6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512),
     (6, 128, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 128, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
-    (6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
-    (6, 128, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 16, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
-    (6, 16, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 16, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
-    (6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096),
     (6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 16, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
+    (6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
+    (6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
+    (6, 256, 6144, 4096, 'bfloat16', True): (256, 512, 4096),
+    (6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096),
+    (6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096),
+    (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512),
+    (6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
+    (6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096),
+    (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
+    (6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096),
+    (6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336),
+    (6, 1024, 6144, 4096, 'bfloat16', True): (1024, 768, 4096),
+    (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096),
+    (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 2048, 4096),
+    (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 256, 14336),
+    (6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
+    (6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
+    (6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
+    (6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
+    (6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
+    (6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
     (6, 16, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
-    (6, 16, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 2048, 1280, 8192, 'bfloat16', True): (512, 1280, 8192),
-    (6, 2048, 28672, 4096, 'bfloat16', True): (1024, 4096, 4096),
-    (6, 2048, 4096, 14336, 'bfloat16', True): (1024, 4096, 2048),
-    (6, 2048, 4096, 4096, 'bfloat16', True): (1024, 2048, 4096),
-    (6, 2048, 6144, 4096, 'bfloat16', True): (1024, 3072, 4096),
-    (6, 2048, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
+    (6, 64, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
+    (6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
+    (6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
+    (6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
+    (6, 128, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
+    (6, 128, 8192, 3584, 'bfloat16', True): (128, 8192, 512),
+    (6, 256, 1280, 8192, 'bfloat16', True): (256, 256, 8192),
+    (6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024),
+    (6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192),
+    (6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512),
+    (6, 16, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
+    (6, 512, 1280, 8192, 'bfloat16', True): (512, 256, 8192),
+    (6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
+    (6, 512, 7168, 8192, 'bfloat16', True): (512, 512, 8192),
+    (6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
+    (6, 1024, 1280, 8192, 'bfloat16', True): (1024, 256, 8192),
+    (6, 1024, 8192, 1024, 'bfloat16', True): (1024, 4096, 1024),
+    (6, 1024, 7168, 8192, 'bfloat16', True): (1024, 512, 8192),
+    (6, 1024, 8192, 3584, 'bfloat16', True): (1024, 1024, 3584),
+    (6, 2048, 1280, 8192, 'bfloat16', True): (2048, 256, 8192),
     (6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
-    (6, 2048, 8192, 3584, 'bfloat16', True): (1024, 2048, 3584),
-    (6, 256, 1280, 8192, 'bfloat16', True): (256, 1280, 2048),
-    (6, 256, 28672, 4096, 'bfloat16', True): (256, 1792, 4096),
-    (6, 256, 4096, 14336, 'bfloat16', True): (256, 1024, 3584),
-    (6, 256, 4096, 4096, 'bfloat16', True): (256, 1024, 4096),
-    (6, 256, 6144, 4096, 'bfloat16', True): (256, 1024, 4096),
-    (6, 256, 7168, 8192, 'bfloat16', True): (256, 1024, 4096),
-    (6, 256, 8192, 1024, 'bfloat16', True): (256, 4096, 1024),
-    (6, 256, 8192, 3584, 'bfloat16', True): (256, 1024, 3584),
-    (6, 32, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
-    (6, 32, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 32, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
-    (6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
-    (6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 32, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
+    (6, 16, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
+    (6, 2048, 7168, 8192, 'bfloat16', True): (2048, 256, 8192),
+    (6, 2048, 8192, 3584, 'bfloat16', True): (2048, 512, 3584),
+    (6, 32, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
     (6, 32, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
+    (6, 32, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
     (6, 32, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 512, 1280, 8192, 'bfloat16', True): (512, 1280, 2048),
-    (6, 512, 28672, 4096, 'bfloat16', True): (512, 3584, 4096),
-    (6, 512, 4096, 14336, 'bfloat16', True): (512, 4096, 1792),
-    (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
-    (6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096),
-    (6, 512, 7168, 8192, 'bfloat16', True): (512, 1024, 8192),
-    (6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
-    (6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
-    (6, 64, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
-    (6, 64, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 64, 4096, 14336, 'bfloat16', True): (128, 512, 7168),
-    (6, 64, 4096, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 64, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
+    (6, 64, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
     (6, 64, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
-    (6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
 }
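A hypothetical lookup against the retuned table, assuming the key layout (tpu_version, bs, n_output_features, n_input_features, dtype, quantize_activation) suggested by the entries, and the value layout (batch_block_size, out_block_size, in_block_size) listed in the comment above it:

key = (6, 512, 4096, 14336, 'bfloat16', True)
batch_block_size, out_block_size, in_block_size = TUNED_BLOCK_SIZES[key]
# -> (512, 256, 14336): n_in == 1 and n_out == 16, so this shape runs with
# save_q_x enabled and caches the quantized input across all 16 out blocks.

Many retuned entries appear to pick in_block_size equal to the full input dimension (so n_in == 1) precisely so the quantized-input cache applies.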