Skip to content

Commit 199a9bd

Browse files
authored
Add w8a8 kernel blocks for Qwen 2.5 7B (#9517)
1 parent 7aa466e commit 199a9bd

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,12 @@ def quantized_matmul_int8(
328328
(6, 1024, 13824, 5120, 'bfloat16', True): (1024, 768, 5120),
329329
(6, 1024, 1792, 5120, 'bfloat16', True): (1024, 256, 5120),
330330
(6, 1024, 28672, 4096, 'bfloat16', True): (1024, 2048, 4096),
331+
(6, 1024, 3584, 18944, 'bfloat16', True): (1024, 3584, 512),
332+
(6, 1024, 3584, 3584, 'bfloat16', True): (1024, 512, 3584),
333+
(6, 1024, 37888, 3584, 'bfloat16', True): (1024, 1024, 3584),
331334
(6, 1024, 4096, 14336, 'bfloat16', True): (1024, 256, 14336),
332335
(6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096),
336+
(6, 1024, 4608, 3584, 'bfloat16', True): (1024, 768, 3584),
333337
(6, 1024, 5120, 1280, 'bfloat16', True): (1024, 1280, 1280),
334338
(6, 1024, 5120, 3456, 'bfloat16', True): (1024, 1024, 3456),
335339
(6, 1024, 5120, 640, 'bfloat16', True): (256, 5120, 640),
@@ -344,8 +348,12 @@ def quantized_matmul_int8(
344348
(6, 128, 13824, 5120, 'bfloat16', True): (128, 512, 5120),
345349
(6, 128, 1792, 5120, 'bfloat16', True): (128, 1792, 1280),
346350
(6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
351+
(6, 128, 3584, 18944, 'bfloat16', True): (128, 256, 18944),
352+
(6, 128, 3584, 3584, 'bfloat16', True): (128, 3584, 896),
353+
(6, 128, 37888, 3584, 'bfloat16', True): (128, 1024, 3584),
347354
(6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
348355
(6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
356+
(6, 128, 4608, 3584, 'bfloat16', True): (128, 768, 3584),
349357
(6, 128, 5120, 1280, 'bfloat16', True): (128, 1280, 1280),
350358
(6, 128, 5120, 3456, 'bfloat16', True): (128, 640, 3456),
351359
(6, 128, 5120, 640, 'bfloat16', True): (128, 2560, 640),
@@ -360,8 +368,12 @@ def quantized_matmul_int8(
360368
(6, 16, 13824, 5120, 'bfloat16', True): (128, 512, 5120),
361369
(6, 16, 1792, 5120, 'bfloat16', True): (128, 896, 2560),
362370
(6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
371+
(6, 16, 3584, 18944, 'bfloat16', True): (128, 256, 18944),
372+
(6, 16, 3584, 3584, 'bfloat16', True): (128, 896, 3584),
373+
(6, 16, 37888, 3584, 'bfloat16', True): (128, 1024, 3584),
363374
(6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
364375
(6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
376+
(6, 16, 4608, 3584, 'bfloat16', True): (128, 768, 3584),
365377
(6, 16, 5120, 1280, 'bfloat16', True): (128, 1280, 1280),
366378
(6, 16, 5120, 3456, 'bfloat16', True): (128, 640, 3456),
367379
(6, 16, 5120, 640, 'bfloat16', True): (128, 2560, 640),
@@ -374,6 +386,10 @@ def quantized_matmul_int8(
374386
(6, 16, 896, 5120, 'bfloat16', True): (128, 896, 2560),
375387
(6, 16384, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120),
376388
(6, 16384, 1792, 5120, 'bfloat16', True): (1024, 1792, 5120),
389+
(6, 16384, 3584, 18944, 'bfloat16', True): (256, 3584, 18944),
390+
(6, 16384, 3584, 3584, 'bfloat16', True): (512, 3584, 3584),
391+
(6, 16384, 37888, 3584, 'bfloat16', True): (4096, 512, 3584),
392+
(6, 16384, 4608, 3584, 'bfloat16', True): (512, 4608, 3584),
377393
(6, 16384, 5120, 1280, 'bfloat16', True): (512, 5120, 1280),
378394
(6, 16384, 5120, 3456, 'bfloat16', True): (512, 5120, 3456),
379395
(6, 16384, 5120, 640, 'bfloat16', True): (512, 5120, 640),
@@ -384,8 +400,12 @@ def quantized_matmul_int8(
384400
(6, 2048, 13824, 5120, 'bfloat16', True): (2048, 768, 5120),
385401
(6, 2048, 1792, 5120, 'bfloat16', True): (2048, 256, 5120),
386402
(6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096),
403+
(6, 2048, 3584, 18944, 'bfloat16', True): (2048, 3584, 512),
404+
(6, 2048, 3584, 3584, 'bfloat16', True): (2048, 512, 3584),
405+
(6, 2048, 37888, 3584, 'bfloat16', True): (2048, 1024, 3584),
387406
(6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512),
388407
(6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096),
408+
(6, 2048, 4608, 3584, 'bfloat16', True): (2048, 512, 3584),
389409
(6, 2048, 5120, 1280, 'bfloat16', True): (256, 5120, 1280),
390410
(6, 2048, 5120, 3456, 'bfloat16', True): (2048, 512, 3456),
391411
(6, 2048, 5120, 640, 'bfloat16', True): (256, 5120, 640),
@@ -400,8 +420,12 @@ def quantized_matmul_int8(
400420
(6, 256, 13824, 5120, 'bfloat16', True): (256, 512, 5120),
401421
(6, 256, 1792, 5120, 'bfloat16', True): (256, 1792, 1280),
402422
(6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096),
423+
(6, 256, 3584, 18944, 'bfloat16', True): (256, 256, 18944),
424+
(6, 256, 3584, 3584, 'bfloat16', True): (256, 896, 3584),
425+
(6, 256, 37888, 3584, 'bfloat16', True): (256, 4736, 896),
403426
(6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512),
404427
(6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096),
428+
(6, 256, 4608, 3584, 'bfloat16', True): (256, 768, 3584),
405429
(6, 256, 5120, 1280, 'bfloat16', True): (256, 2560, 1280),
406430
(6, 256, 5120, 3456, 'bfloat16', True): (256, 1024, 3456),
407431
(6, 256, 5120, 640, 'bfloat16', True): (256, 2560, 640),
@@ -416,8 +440,12 @@ def quantized_matmul_int8(
416440
(6, 32, 13824, 5120, 'bfloat16', True): (128, 512, 5120),
417441
(6, 32, 1792, 5120, 'bfloat16', True): (128, 896, 2560),
418442
(6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
443+
(6, 32, 3584, 18944, 'bfloat16', True): (128, 128, 18944),
444+
(6, 32, 3584, 3584, 'bfloat16', True): (128, 896, 3584),
445+
(6, 32, 37888, 3584, 'bfloat16', True): (128, 1024, 3584),
419446
(6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
420447
(6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
448+
(6, 32, 4608, 3584, 'bfloat16', True): (128, 768, 3584),
421449
(6, 32, 5120, 1280, 'bfloat16', True): (128, 1280, 1280),
422450
(6, 32, 5120, 3456, 'bfloat16', True): (128, 640, 3456),
423451
(6, 32, 5120, 640, 'bfloat16', True): (128, 2560, 640),
@@ -430,6 +458,10 @@ def quantized_matmul_int8(
430458
(6, 32, 896, 5120, 'bfloat16', True): (128, 896, 2560),
431459
(6, 4096, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120),
432460
(6, 4096, 1792, 5120, 'bfloat16', True): (512, 1792, 5120),
461+
(6, 4096, 3584, 18944, 'bfloat16', True): (2048, 3584, 512),
462+
(6, 4096, 3584, 3584, 'bfloat16', True): (4096, 256, 3584),
463+
(6, 4096, 37888, 3584, 'bfloat16', True): (4096, 512, 3584),
464+
(6, 4096, 4608, 3584, 'bfloat16', True): (4096, 512, 3584),
433465
(6, 4096, 5120, 1280, 'bfloat16', True): (256, 5120, 1280),
434466
(6, 4096, 5120, 3456, 'bfloat16', True): (4096, 512, 3456),
435467
(6, 4096, 5120, 640, 'bfloat16', True): (256, 5120, 640),
@@ -440,8 +472,12 @@ def quantized_matmul_int8(
440472
(6, 512, 13824, 5120, 'bfloat16', True): (512, 13824, 512),
441473
(6, 512, 1792, 5120, 'bfloat16', True): (512, 1792, 1280),
442474
(6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096),
475+
(6, 512, 3584, 18944, 'bfloat16', True): (512, 256, 18944),
476+
(6, 512, 3584, 3584, 'bfloat16', True): (512, 1792, 3584),
477+
(6, 512, 37888, 3584, 'bfloat16', True): (512, 18944, 512),
443478
(6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336),
444479
(6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
480+
(6, 512, 4608, 3584, 'bfloat16', True): (512, 768, 3584),
445481
(6, 512, 5120, 1280, 'bfloat16', True): (512, 2560, 1280),
446482
(6, 512, 5120, 3456, 'bfloat16', True): (512, 1280, 3456),
447483
(6, 512, 5120, 640, 'bfloat16', True): (512, 2560, 640),
@@ -456,8 +492,12 @@ def quantized_matmul_int8(
456492
(6, 64, 13824, 5120, 'bfloat16', True): (128, 512, 5120),
457493
(6, 64, 1792, 5120, 'bfloat16', True): (128, 896, 2560),
458494
(6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
495+
(6, 64, 3584, 18944, 'bfloat16', True): (128, 256, 18944),
496+
(6, 64, 3584, 3584, 'bfloat16', True): (128, 896, 3584),
497+
(6, 64, 37888, 3584, 'bfloat16', True): (128, 1024, 3584),
459498
(6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
460499
(6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
500+
(6, 64, 4608, 3584, 'bfloat16', True): (128, 768, 3584),
461501
(6, 64, 5120, 1280, 'bfloat16', True): (128, 1280, 1280),
462502
(6, 64, 5120, 3456, 'bfloat16', True): (128, 1024, 3456),
463503
(6, 64, 5120, 640, 'bfloat16', True): (128, 2560, 640),
@@ -470,6 +510,10 @@ def quantized_matmul_int8(
470510
(6, 64, 896, 5120, 'bfloat16', True): (128, 896, 2560),
471511
(6, 8192, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120),
472512
(6, 8192, 1792, 5120, 'bfloat16', True): (512, 1792, 5120),
513+
(6, 8192, 3584, 18944, 'bfloat16', True): (2048, 3584, 512),
514+
(6, 8192, 3584, 3584, 'bfloat16', True): (4096, 512, 3584),
515+
(6, 8192, 37888, 3584, 'bfloat16', True): (4096, 1024, 3584),
516+
(6, 8192, 4608, 3584, 'bfloat16', True): (4096, 512, 3584),
473517
(6, 8192, 5120, 1280, 'bfloat16', True): (256, 5120, 1280),
474518
(6, 8192, 5120, 3456, 'bfloat16', True): (512, 5120, 3456),
475519
(6, 8192, 5120, 640, 'bfloat16', True): (512, 5120, 640),

0 commit comments

Comments (0)