Skip to content

Commit 2d9dddb

Browse files
0cc4myael-works
authored andcommitted
vulkan: vec dot matrix multiplication fix (ggml-org#16151)
* vulkan: fix matrix multiplication index calculation for odd m/n and odd k in combination with batching * add odd m/n + odd k test with batching
1 parent 0c4a4c0 commit 2d9dddb

File tree

4 files changed

+31
-18
lines changed

4 files changed

+31
-18
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,22 @@
3131
#include "types.comp"
3232

3333
#ifndef LOAD_VEC_A
34-
#define LOAD_VEC_A 2
34+
#define LOAD_VEC_A 1
3535
#endif
3636
#ifndef LOAD_VEC_B
37-
#define LOAD_VEC_B 2
37+
#define LOAD_VEC_B 1
38+
#endif
39+
40+
// Load 2 values at once without affecting index calculations through LOAD_VEC
41+
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
42+
#define LOAD_VEC_BATCH_A 2
43+
#else
44+
#define LOAD_VEC_BATCH_A 1
45+
#endif
46+
#if !defined(ALIGNED)
47+
#define LOAD_VEC_BATCH_B 2
48+
#else
49+
#define LOAD_VEC_BATCH_B 1
3850
#endif
3951

4052
#if !defined(TO_FLOAT_TYPE)
@@ -236,13 +248,13 @@ void main() {
236248
const uint warp_r = warp_i % (BM / WM);
237249
const uint warp_c = warp_i / (BM / WM);
238250

239-
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
240-
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
241-
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
242-
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
251+
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
252+
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
253+
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
254+
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
243255

244-
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
245-
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
256+
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
257+
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
246258

247259
#ifdef MUL_MAT_ID
248260
#ifdef MUL_MAT_ID_USE_SUBGROUPS

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
1414
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
1515
buf_a[buf_idx ] = aa.xy;
1616
buf_a[buf_idx + 1] = aa.zw;
17-
#else // LOAD_VEC_A == 2
18-
const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
17+
#else // LOAD_VEC_BATCH_A == 2
18+
const uint idx = pos_a + col * p.stride_a + row * 2;
1919
const uint buf_idx = col * SHMEM_STRIDE + row;
2020
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
2121
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
@@ -33,8 +33,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
3333
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
3434
buf_a[buf_idx ] = aa.xy;
3535
buf_a[buf_idx + 1] = aa.zw;
36-
#else // LOAD_VEC_A == 2
37-
const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
36+
#else // LOAD_VEC_BATCH_A == 2
37+
const uint idx = pos_a + col * p.stride_a + row * 2;
3838
const uint buf_idx = col * SHMEM_STRIDE + row;
3939
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
4040
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
@@ -500,8 +500,8 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
500500
#endif
501501
buf_b[buf_idx + 0] = bb.xy;
502502
buf_b[buf_idx + 1] = bb.zw;
503-
#else // LOAD_VEC_B == 2
504-
const uint idx = pos_b * 2 + col * p.stride_b + row * 2;
503+
#else // LOAD_VEC_BATCH_B == 2
504+
const uint idx = pos_b + col * p.stride_b + row * 2;
505505
const uint buf_idx = col * SHMEM_STRIDE + row;
506506
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
507507
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
@@ -536,17 +536,17 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
536536
#endif
537537
buf_b[buf_idx + 0] = bb.xy;
538538
buf_b[buf_idx + 1] = bb.zw;
539-
#else // LOAD_VEC_B == 2
539+
#else // LOAD_VEC_BATCH_B == 2
540540
const uint row_i = ic * BN + col;
541541
const uint buf_idx = col * SHMEM_STRIDE + row;
542542
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
543543
const u16vec2 row_idx = row_ids[col];
544-
const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
544+
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
545545
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
546546
TO_FLOAT_TYPE(data_b[idx + 1]));
547547
} else if (row_i < _ne1 && block + row * 2 < end_k) {
548548
const u16vec2 row_idx = row_ids[col];
549-
const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
549+
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
550550
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
551551
} else {
552552
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
454454

455455
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
456456
// For unaligned, load one at a time for f32/f16, or two at a time for quants
457-
std::string load_vec_a_unaligned = coopmat2 ? "1" : (tname == "f32" || tname == "f16" || tname == "bf16") ? "2" : load_vec_quant;
457+
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
458458
// For aligned matmul loads
459459
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
460460

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6231,6 +6231,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
62316231
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
62326232
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
62336233
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
6234+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
62346235

62356236
for (auto bs2 : {1,3}) {
62366237
for (auto bs : {1,2,4,8}) {

0 commit comments

Comments
 (0)