@@ -14,8 +14,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
1414            FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
1515            buf_a[buf_idx    ] = aa.xy;
1616            buf_a[buf_idx + 1] = aa.zw;
17- #else // LOAD_VEC_A  == 2
18-             const uint idx = pos_a * 2  + col * p.stride_a + row * 2;
17+ #else // LOAD_VEC_BATCH_A  == 2
18+             const uint idx = pos_a + col * p.stride_a + row * 2;
1919            const uint buf_idx = col * SHMEM_STRIDE + row;
2020            if (idx_m < p.M && block + row * 2 + 1 < end_k) {
2121                buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
@@ -33,8 +33,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
3333            FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
3434            buf_a[buf_idx    ] = aa.xy;
3535            buf_a[buf_idx + 1] = aa.zw;
36- #else // LOAD_VEC_A  == 2
37-             const uint idx = pos_a * 2  + col * p.stride_a + row * 2;
36+ #else // LOAD_VEC_BATCH_A  == 2
37+             const uint idx = pos_a + col * p.stride_a + row * 2;
3838            const uint buf_idx = col * SHMEM_STRIDE + row;
3939            if (idx_m < p.M && block + row * 2 + 1 < end_k) {
4040                buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
@@ -500,8 +500,8 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
500500#endif
501501            buf_b[buf_idx + 0] = bb.xy;
502502            buf_b[buf_idx + 1] = bb.zw;
503- #else // LOAD_VEC_B  == 2
504-             const uint idx = pos_b * 2  + col * p.stride_b + row * 2;
503+ #else // LOAD_VEC_BATCH_B  == 2
504+             const uint idx = pos_b + col * p.stride_b + row * 2;
505505            const uint buf_idx = col * SHMEM_STRIDE + row;
506506            if (idx_n < p.N && block + row * 2 + 1 < end_k) {
507507                buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
@@ -536,17 +536,17 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
536536#endif
537537            buf_b[buf_idx + 0] = bb.xy;
538538            buf_b[buf_idx + 1] = bb.zw;
539- #else // LOAD_VEC_B  == 2
539+ #else // LOAD_VEC_BATCH_B  == 2
540540            const uint row_i = ic * BN + col;
541541            const uint buf_idx = col * SHMEM_STRIDE + row;
542542            if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
543543                const u16vec2 row_idx = row_ids[col];
544-                 const uint idx = pos_b * 2  + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
544+                 const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
545545                buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
546546                                                 TO_FLOAT_TYPE(data_b[idx + 1]));
547547            } else if (row_i < _ne1 && block + row * 2 < end_k) {
548548                const u16vec2 row_idx = row_ids[col];
549-                 const uint idx = pos_b * 2  + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
549+                 const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
550550                buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
551551            } else {
552552                buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
0 commit comments