@@ -14,69 +14,116 @@ def _quantize_array(
   n_bits = 8
   int_max = 2**(n_bits - 1) - 1
   scale = (x_abs_max_val / int_max).T  # [bs_block_size, 1]
-  # Need to explicitly cast to f32 because Mosaic can't directly jnp.round a
-  # bf16 array.
-  # It seems x/0 in Pallas generates inf/-inf instead of an exception.
-  x_int = jnp.round((x / scale).astype(jnp.float32)).astype(jnp.int8)
-  return x_int, scale.astype(x.dtype)
+  x_int = jnp.round(x / scale).astype(jnp.int8)
+  return x_int, scale.astype(jnp.float32)
+
+
+def unfold_args(args: tuple[jax.Array | bool, ...], fn_args: tuple[bool, ...],
+                fn):
+  if len(args) == 0:
+    fn(*fn_args)
+  else:
+    arg = args[0]
+    if isinstance(arg, bool):
+      unfold_args(args[1:], fn_args + (arg,), fn)
+    else:
+      assert arg.dtype == jnp.bool and arg.size == 1
+      lax.cond(
+          arg,
+          lambda: unfold_args(args[1:], fn_args + (True,), fn),
+          lambda: unfold_args(args[1:], fn_args + (False,), fn),
+      )
 
 
 def matmul_kernel(
-    x_ref,  # (batch_block_size, in_block_size)
-    w_ref,  # (out_block_size, in_block_size)
-    scalar_ref,  # (1, out_block_size)
-    x_abs_max_val,  # (1, batch_block_size)
-    out_ref,  # (batch_block_size, out_block_size)
-    acc_ref,  # (batch_block_size, out_block_size)
+    x_ref: jax.Array,  # (batch_block_size, in_block_size)
+    w_ref: jax.Array,  # (out_block_size, in_block_size)
+    scalar_ref: jax.Array,  # (1, out_block_size)
+    x_abs_max_ref: jax.Array,  # (1, batch_block_size)
+    out_ref: jax.Array,  # (batch_block_size, out_block_size)
+    acc_scratch: jax.Array,  # (batch_block_size, out_block_size)
+    q_x_scratch: jax.Array,  # (batch_block_size, in_block_size)
+    x_scale_scratch: jax.Array,  # (batch_block_size, 1)
     *,
-    quantize_activation,
-    batch_block_size,
-    out_block_size,
-    in_block_size,
+    quantize_activation: bool,
+    save_acc: bool,
+    save_q_x: bool,
+    batch_block_size: int,
+    out_block_size: int,
+    in_block_size: int,
 ):
   bs_idx, out_idx, in_idx = pl.program_id(0), pl.program_id(1), pl.program_id(2)
-  nsteps = pl.num_programs(2)
+  n_in = pl.num_programs(2)
   x_ref_dtype = x_ref.dtype
   assert x_ref.shape == (batch_block_size,
                          in_block_size), "x_ref shape is not correct"
   assert w_ref.shape == (out_block_size,
                          in_block_size), "w_ref shape is not correct"
   assert scalar_ref.shape == (1,
                               out_block_size), "scalar_ref shape is not correct"
-  assert x_abs_max_val.shape == (
+  assert x_abs_max_ref.shape == (
       1, batch_block_size), "x_max_val shape is not correct"
   assert out_ref.shape == (batch_block_size,
                            out_block_size), "out_ref shape is not correct"
-  assert acc_ref.shape == (batch_block_size,
-                           out_block_size), "acc_ref shape is not correct"
-
-  @pl.when(in_idx == 0)
-  def _():
-    acc_ref[...] = jnp.zeros_like(acc_ref)
-
-  if quantize_activation:
-    x, x_scale = _quantize_array(x_ref[...], x_abs_max_val[...])
-    acc_ref[...] += jax.lax.dot_general(
-        x,
-        w_ref[...],
-        (((1,), (1,)), ((), ())),
-        preferred_element_type=jnp.int32,
-    )
+
+  if save_q_x:
+    assert quantize_activation
+    assert q_x_scratch is not None
+    assert x_scale_scratch is not None
+    quant = out_idx == 0
   else:
-    acc_ref[...] += jax.lax.dot_general(
-        x_ref[...],
-        w_ref[...],
-        (((1,), (1,)), ((), ())),
-    )
-
-  @pl.when(in_idx == nsteps - 1)
-  def _():
-    acc = acc_ref[...]
-    scalar = scalar_ref[...]
-    acc *= scalar
+    assert q_x_scratch is None
+    assert x_scale_scratch is None
+    quant = quantize_activation
+
+  if save_acc:
+    assert acc_scratch is not None
+    is_first_step = in_idx == 0
+    is_last_step = in_idx == n_in - 1
+  else:
+    assert acc_scratch is None
+    is_first_step = True
+    is_last_step = True
+
+  def matmul_body(quant, is_first_step, is_last_step):
     if quantize_activation:
-      acc *= x_scale
-    out_ref[...] = acc.astype(x_ref_dtype)
+      if quant:
+        q_x_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...])
+        if save_q_x:
+          q_x_scratch[...] = q_x_tmp
+          x_scale_scratch[...] = x_scale_tmp
+      else:
+        assert save_q_x
+        q_x_tmp = q_x_scratch[...]
+        if is_last_step:
+          x_scale_tmp = x_scale_scratch[...]
+
+      acc = jax.lax.dot_general(
+          q_x_tmp,
+          w_ref[...],
+          (((1,), (1,)), ((), ())),
+          preferred_element_type=jnp.int32,
+      )
+    else:
+      acc = jax.lax.dot_general(
+          x_ref[...],
+          w_ref[...],
+          (((1,), (1,)), ((), ())),
+      )
+
+    if not is_first_step:
+      acc += acc_scratch[...]
+
+    if is_last_step:
+      acc *= scalar_ref[...]
+      if quantize_activation:
+        acc *= x_scale_tmp
+      out_ref[...] = acc.astype(x_ref_dtype)
+    else:
+      assert save_acc
+      acc_scratch[...] = acc
+
+  unfold_args((quant, is_first_step, is_last_step), (), matmul_body)
 
 
 def _next_multiple(x, multiple):
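
Two notes on the kernel changes above. First, `_quantize_array` now rounds directly (the bf16 `jnp.round` workaround is gone) and always returns a float32 scale. Below is a minimal sketch of the same per-row symmetric int8 scheme in plain jax.numpy; the name `quantize_rows` and the toy shapes are illustrative only, not part of the change:

    import jax.numpy as jnp

    def quantize_rows(x):  # x: [bs, d]
      x_abs_max = jnp.max(jnp.abs(x), axis=1, keepdims=True)  # [bs, 1]
      scale = x_abs_max / 127.0  # int8 max = 2**(8 - 1) - 1
      x_int = jnp.round(x / scale).astype(jnp.int8)
      return x_int, scale.astype(jnp.float32)

    x = jnp.array([[0.5, -2.0], [3.0, 1.5]], dtype=jnp.bfloat16)
    x_int, scale = quantize_rows(x)
    # x_int * scale reconstructs x to within half an int8 step per row.

Second, `unfold_args` turns a mix of static Python bools and traced scalar bools into concrete arguments: a static bool is bound directly, while each traced bool becomes one `lax.cond` whose branches recurse with `True` and `False`. The innermost call therefore always sees plain Python bools, so `matmul_body` can branch with ordinary `if` statements and each reachable combination is staged as straight-line code. A standalone sketch of the pattern, outside Pallas and with assumed names:

    import jax
    import jax.numpy as jnp
    from jax import lax

    def unfold_args(args, fn_args, fn):
      if len(args) == 0:
        return fn(*fn_args)
      arg = args[0]
      if isinstance(arg, bool):  # static: bind it directly
        return unfold_args(args[1:], fn_args + (arg,), fn)
      # traced: branch at runtime, recursing with a concrete bool either way
      return lax.cond(
          arg,
          lambda: unfold_args(args[1:], fn_args + (True,), fn),
          lambda: unfold_args(args[1:], fn_args + (False,), fn),
      )

    def body(flag_a, flag_b):
      # Both flags are Python bools here, even if flag_b started out traced.
      return jnp.float32(1.0) if (flag_a and flag_b) else jnp.float32(0.0)

    @jax.jit
    def run(traced_flag):
      # One static bool and one traced bool: only the traced one costs a cond.
      return unfold_args((True, traced_flag), (), body)

    run(jnp.array(True))   # -> 1.0
    run(jnp.array(False))  # -> 0.0
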
@@ -159,10 +206,22 @@ def quantized_matmul_int8(
   # Within the kernel, it will use some extra VMEM for computation or vreg spills.
   vmem_used = vmem_to_be_transferred * 2
   vmem_limit_bytes = min(vmem_used * 2, 96 * 1024 * 1024)
+
+  n_bs = padded_bs // batch_block_size
+  n_out = padded_out_features // out_block_size
+  n_in = padded_in_features // in_block_size
+
+  save_acc = n_in > 1
+  # Avoid redundant input quantization by caching the quantized input.
+  # For best performance, only enable this when a single input block is used per batch.
+  save_q_x = quantize_activation and n_in == 1 and n_out > 1
+
   kernel = pl.pallas_call(
       functools.partial(
           matmul_kernel,
           quantize_activation=quantize_activation,
+          save_acc=save_acc,
+          save_q_x=save_q_x,
           batch_block_size=batch_block_size,
           out_block_size=out_block_size,
           in_block_size=in_block_size),
@@ -181,15 +240,18 @@ def quantized_matmul_int8(
           out_specs=pl.BlockSpec((batch_block_size, out_block_size),
                                  lambda b, o, i: (b, o)),
           scratch_shapes=[
-              pltpu.VMEM((batch_block_size, out_block_size), acc_dtype)
+              pltpu.VMEM((batch_block_size,
+                          out_block_size), acc_dtype) if save_acc else None,
+              pltpu.VMEM((batch_block_size,
+                          in_block_size), jnp.int8) if save_q_x else None,
+              pltpu.VMEM(
+                  (batch_block_size, 1), jnp.float32) if save_q_x else None,
           ],
-          grid=(padded_bs // batch_block_size,
-                padded_out_features // out_block_size,
-                padded_in_features // in_block_size),
+          grid=(n_bs, n_out, n_in),
       ),
       out_shape=jax.ShapeDtypeStruct((padded_bs, padded_out_features), x.dtype),
       compiler_params=pltpu.TPUCompilerParams(
-          dimension_semantics=("parallel", "parallel", "arbitrary"),
+          dimension_semantics=("parallel", "arbitrary", "arbitrary"),
           vmem_limit_bytes=vmem_limit_bytes,
       ),
   )
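
Note the second grid axis changes from "parallel" to "arbitrary": with save_q_x, iterations over the out dimension share q_x_scratch (quantize at out_idx == 0, reuse afterwards), so they can no longer be declared independent. The new scratch buffers also enter the VMEM budget; a hypothetical helper, assuming acc_dtype is int32 in the quantized path, sketches the added per-block footprint:

    def scratch_vmem_bytes(batch_block_size, out_block_size, in_block_size,
                           save_acc, save_q_x):
      total = 0
      if save_acc:  # int32 accumulator tile, 4 bytes per element
        total += batch_block_size * out_block_size * 4
      if save_q_x:  # int8 quantized activations plus one float32 scale per row
        total += batch_block_size * in_block_size * 1
        total += batch_block_size * 4
      return total

    scratch_vmem_bytes(512, 2048, 4096, save_acc=False, save_q_x=True)
    # -> 2_099_200 bytes: ~2 MiB of cached q_x plus 2 KiB of scales
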
@@ -217,70 +279,70 @@ def quantized_matmul_int8(
 # - out_block_size
 # - in_block_size
 TUNED_BLOCK_SIZES = {
-    (6, 1024, 1280, 8192, 'bfloat16', True): (1024, 1280, 2048),
-    (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 3584, 4096),
-    (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 4096, 2048),
-    (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 1024, 4096),
-    (6, 1024, 6144, 4096, 'bfloat16', True): (1024, 1536, 4096),
-    (6, 1024, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
-    (6, 1024, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
-    (6, 1024, 8192, 3584, 'bfloat16', True): (1024, 2048, 3584),
-    (6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
-    (6, 128, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 128, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
-    (6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
+    (6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
+    (6, 2048, 6144, 4096, 'bfloat16', True): (2048, 512, 4096),
+    (6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096),
+    (6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512),
     (6, 128, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 128, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
-    (6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
-    (6, 128, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 16, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
-    (6, 16, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 16, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
-    (6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096),
     (6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 16, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
+    (6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
+    (6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
+    (6, 256, 6144, 4096, 'bfloat16', True): (256, 512, 4096),
+    (6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096),
+    (6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096),
+    (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512),
+    (6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
+    (6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096),
+    (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
+    (6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096),
+    (6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336),
+    (6, 1024, 6144, 4096, 'bfloat16', True): (1024, 768, 4096),
+    (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096),
+    (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 2048, 4096),
+    (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 256, 14336),
+    (6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
+    (6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
+    (6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
+    (6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
+    (6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
+    (6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
+    (6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
     (6, 16, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
-    (6, 16, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 2048, 1280, 8192, 'bfloat16', True): (512, 1280, 8192),
-    (6, 2048, 28672, 4096, 'bfloat16', True): (1024, 4096, 4096),
-    (6, 2048, 4096, 14336, 'bfloat16', True): (1024, 4096, 2048),
-    (6, 2048, 4096, 4096, 'bfloat16', True): (1024, 2048, 4096),
-    (6, 2048, 6144, 4096, 'bfloat16', True): (1024, 3072, 4096),
-    (6, 2048, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
+    (6, 64, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
+    (6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
+    (6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
+    (6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
+    (6, 128, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
+    (6, 128, 8192, 3584, 'bfloat16', True): (128, 8192, 512),
+    (6, 256, 1280, 8192, 'bfloat16', True): (256, 256, 8192),
+    (6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024),
+    (6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192),
+    (6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512),
+    (6, 16, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
+    (6, 512, 1280, 8192, 'bfloat16', True): (512, 256, 8192),
+    (6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
+    (6, 512, 7168, 8192, 'bfloat16', True): (512, 512, 8192),
+    (6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
+    (6, 1024, 1280, 8192, 'bfloat16', True): (1024, 256, 8192),
+    (6, 1024, 8192, 1024, 'bfloat16', True): (1024, 4096, 1024),
+    (6, 1024, 7168, 8192, 'bfloat16', True): (1024, 512, 8192),
+    (6, 1024, 8192, 3584, 'bfloat16', True): (1024, 1024, 3584),
+    (6, 2048, 1280, 8192, 'bfloat16', True): (2048, 256, 8192),
     (6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
-    (6, 2048, 8192, 3584, 'bfloat16', True): (1024, 2048, 3584),
-    (6, 256, 1280, 8192, 'bfloat16', True): (256, 1280, 2048),
-    (6, 256, 28672, 4096, 'bfloat16', True): (256, 1792, 4096),
-    (6, 256, 4096, 14336, 'bfloat16', True): (256, 1024, 3584),
-    (6, 256, 4096, 4096, 'bfloat16', True): (256, 1024, 4096),
-    (6, 256, 6144, 4096, 'bfloat16', True): (256, 1024, 4096),
-    (6, 256, 7168, 8192, 'bfloat16', True): (256, 1024, 4096),
-    (6, 256, 8192, 1024, 'bfloat16', True): (256, 4096, 1024),
-    (6, 256, 8192, 3584, 'bfloat16', True): (256, 1024, 3584),
-    (6, 32, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
-    (6, 32, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 32, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
-    (6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
-    (6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 32, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
+    (6, 16, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
+    (6, 2048, 7168, 8192, 'bfloat16', True): (2048, 256, 8192),
+    (6, 2048, 8192, 3584, 'bfloat16', True): (2048, 512, 3584),
+    (6, 32, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
     (6, 32, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
+    (6, 32, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
     (6, 32, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 512, 1280, 8192, 'bfloat16', True): (512, 1280, 2048),
-    (6, 512, 28672, 4096, 'bfloat16', True): (512, 3584, 4096),
-    (6, 512, 4096, 14336, 'bfloat16', True): (512, 4096, 1792),
-    (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
-    (6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096),
-    (6, 512, 7168, 8192, 'bfloat16', True): (512, 1024, 8192),
-    (6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
-    (6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
-    (6, 64, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
-    (6, 64, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 64, 4096, 14336, 'bfloat16', True): (128, 512, 7168),
-    (6, 64, 4096, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 64, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
+    (6, 64, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
     (6, 64, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
-    (6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
 }
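
For context, each table entry appears to be keyed by (tpu_version, bs, n_output_features, n_input_features, dtype, quantize_activation) and maps to a (batch_block_size, out_block_size, in_block_size) tuple. A hypothetical lookup helper (not part of this diff; the key layout and fallback blocks are assumptions for illustration):

    def get_tuned_block_sizes(tpu_version, bs, out_features, in_features,
                              dtype, quantize_activation):
      key = (tpu_version, bs, out_features, in_features, dtype,
             quantize_activation)
      # Fall back to conservative block sizes when the shape was never tuned.
      return TUNED_BLOCK_SIZES.get(key, (min(bs, 128), 256, 512))

    get_tuned_block_sizes(6, 512, 28672, 4096, 'bfloat16', True)
    # -> (512, 2048, 4096)
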