Commit 15aee8e

bradleyhd authored and pytorchmergebot committed
update aten bmm CK heuristic (pytorch#143294)
Summary: Updates the heuristic to use new instances based on CK profiling of LLM shapes.

Differential Revision: D67280269

Pull Request resolved: pytorch#143294
Approved by: https://github.com/mxz297, https://github.com/xw285cornell
1 parent c86383f · commit 15aee8e

1 file changed, 16 additions(+), 26 deletions(-):

aten/src/ATen/native/hip/ck_bgemm_bfloat16.hip
@@ -32,38 +32,28 @@ static const std::unordered_map<
 
 // This is the heursitic to choose a kernel based on inputs
 BGEMMKernel_BFloat16 dispatch_bfloat16_bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
-  // First check if there's a specific kernel for this shape.
+  // Optional/future use: directly lookup shape tuples to map to instances
+  /*
   auto it = lookup_dispatch.find({m, n, k});
   if (it != lookup_dispatch.end()) {
     return it->second;
   }
+  */
+
+  // B is A and A is B, so m<-->n
+  // std::cout << "dispatch_bfloat16_bgemm: m=" << m << " n=" << n << " k=" << k << " num_batches=" << num_batches << " transa=" << transa << " transb=" << transb << std::endl;
 
-  // Nout found, use heuristics.
-  // TN layout, so n, m, k
-  if (k == 8192) {
-    if (n <= 4) {
-      return bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v1;
-    }
-    else if (n <= 32) {
-      return bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v2;
-    }
-    else if (n <= 512) {
-      return bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3;
-    }
-    else {
-      return bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v4;
-    }
+  if (m <= 5120) {
+    if (n <= 4) return bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v1;
+    else if (n <= 32) return bgemm_kernel_bf16bf16bf16_128_16x64x64_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2;
+    else if (n <= 128) return bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3; // <512, <1024, <2048 missing
+    else if (n <= 4096) return bgemm_kernel_bf16bf16bf16_256_224x256x64_16x16_7x8_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3;
   }
-  else if (k == 5120) {
-    if (n <= 32) {
-      return bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v1;
-    }
-    else if (n <= 512) {
-      return bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3;
-    }
-    else {
-      return bgemm_kernel_bf16bf16bf16_256_256x224x64_16x16_8x7_8x32x1_8x32x1_1x32x1x8_4_Intrawave_v3;
-    }
+  else if (m <= 8192) {
+    if (n <= 8) return bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v1; // 3 options available, need to investigate
+    if (n <= 32) return bgemm_kernel_bf16bf16bf16_128_16x64x64_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2;
+    if (n <= 512) return bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3;
+    if (n <= 4096) return bgemm_kernel_bf16bf16bf16_256_256x224x64_16x16_8x7_8x32x1_8x32x1_1x32x1x8_4_Intrawave_v3;
  }
 
   // Default instance
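For readers outside the CK codebase, here is a minimal standalone sketch of the dispatch pattern the diff above follows. It is not the actual ATen/CK code: the KernelId values, the ShapeHash helper, and the example shapes are hypothetical placeholders, and the bucket thresholds are copied from the diff only to show the shape of the logic (an exact-shape lookup table like the commented-out lookup_dispatch, followed by a bucketed heuristic on m and then n, with a default fallback).

// Standalone sketch only; placeholders stand in for real CK bgemm instances.
#include <cstdint>
#include <functional>
#include <iostream>
#include <tuple>
#include <unordered_map>

using KernelId = int;  // stands in for a pointer to a CK bgemm kernel instance

// Hash for an (m, n, k) shape key; std::unordered_map has no default hash for std::tuple.
struct ShapeHash {
  std::size_t operator()(const std::tuple<int64_t, int64_t, int64_t>& s) const {
    std::hash<int64_t> h;
    std::size_t seed = h(std::get<0>(s));
    seed ^= h(std::get<1>(s)) + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
    seed ^= h(std::get<2>(s)) + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
    return seed;
  }
};

// Exact-shape overrides, analogous to the commented-out lookup_dispatch map.
static const std::unordered_map<std::tuple<int64_t, int64_t, int64_t>, KernelId, ShapeHash>
    lookup_dispatch = {
        {{8192, 8192, 8192}, 42},  // hypothetical shape -> instance entry
};

// Bucketed heuristic mirroring the structure of the new dispatch: branch on m,
// then on n, and fall through to a default instance when nothing matches.
KernelId dispatch(int64_t m, int64_t n, int64_t k) {
  auto it = lookup_dispatch.find({m, n, k});
  if (it != lookup_dispatch.end()) {
    return it->second;
  }
  if (m <= 5120) {
    if (n <= 4)    return 1;
    if (n <= 32)   return 2;
    if (n <= 128)  return 3;
    if (n <= 4096) return 4;
  } else if (m <= 8192) {
    if (n <= 8)    return 5;
    if (n <= 32)   return 6;
    if (n <= 512)  return 7;
    if (n <= 4096) return 8;
  }
  return 0;  // default instance
}

int main() {
  std::cout << dispatch(4096, 64, 8192) << "\n";    // falls into the m <= 5120, n <= 128 bucket
  std::cout << dispatch(8192, 8192, 8192) << "\n";  // hits the exact-shape override
}

The real kernel instances and thresholds come from the CK profiling of LLM shapes mentioned in the commit summary; the sketch only illustrates why a shape key needs a custom hash and how the heuristic falls through to the default instance.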
