@@ -32,38 +32,28 @@ static const std::unordered_map<
3232
3333// This is the heursitic to choose a kernel based on inputs
3434BGEMMKernel_BFloat16 dispatch_bfloat16_bgemm (CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
35- // First check if there's a specific kernel for this shape.
35+ // Optional/future use: directly lookup shape tuples to map to instances
36+ /*
3637 auto it = lookup_dispatch.find({m, n, k});
3738 if (it != lookup_dispatch.end()) {
3839 return it->second;
3940 }
41+ */
42+
43+ // B is A and A is B, so m<-->n
44+ // std::cout << "dispatch_bfloat16_bgemm: m=" << m << " n=" << n << " k=" << k << " num_batches=" << num_batches << " transa=" << transa << " transb=" << transb << std::endl;
4045
41- // Nout found, use heuristics.
42- // TN layout, so n, m, k
43- if (k == 8192 ) {
44- if (n <= 4 ) {
45- return bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v1;
46- }
47- else if (n <= 32 ) {
48- return bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v2;
49- }
50- else if (n <= 512 ) {
51- return bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3;
52- }
53- else {
54- return bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v4;
55- }
46+ if (m <= 5120 ) {
47+ if (n <= 4 ) return bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v1;
48+ else if (n <= 32 ) return bgemm_kernel_bf16bf16bf16_128_16x64x64_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2;
49+ else if (n <= 128 ) return bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3; // <512, <1024, <2048 missing
50+ else if (n <= 4096 ) return bgemm_kernel_bf16bf16bf16_256_224x256x64_16x16_7x8_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3;
5651 }
57- else if (k == 5120 ) {
58- if (n <= 32 ) {
59- return bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v1;
60- }
61- else if (n <= 512 ) {
62- return bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3;
63- }
64- else {
65- return bgemm_kernel_bf16bf16bf16_256_256x224x64_16x16_8x7_8x32x1_8x32x1_1x32x1x8_4_Intrawave_v3;
66- }
52+ else if (m <= 8192 ) {
53+ if (n <= 8 ) return bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v1; // 3 options available, need to investigate
54+ if (n <= 32 ) return bgemm_kernel_bf16bf16bf16_128_16x64x64_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2;
55+ if (n <= 512 ) return bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3;
56+ if (n <= 4096 ) return bgemm_kernel_bf16bf16bf16_256_256x224x64_16x16_8x7_8x32x1_8x32x1_1x32x1x8_4_Intrawave_v3;
6757 }
6858
6959 // Default instance
0 commit comments