Skip to content

Commit 477618c

Browse files
authored
Fix a sync issue in SM100 MQA logits (deepseek-ai#285)
1 parent 0f5f266 commit 477618c

File tree

2 files changed: 9 additions and 1 deletion

2 files changed: 9 additions and 1 deletion

deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -252,6 +252,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
252252
#pragma unroll
253253
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
254254
empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
255+
tcgen05_after_thread_sync();
255256
#pragma unroll
256257
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
257258
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
@@ -310,6 +311,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
310311

311312
// Wait UMMA arrival
312313
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
314+
tcgen05_after_thread_sync();
313315

314316
// Release KV empty
315317
empty_kv_barriers[kv_stage_idx]->arrive();
@@ -334,6 +336,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
334336
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
335337
cutlass::arch::fence_view_async_tmem_load();
336338

339+
tcgen05_before_thread_sync();
337340
empty_umma_barriers[warpgroup_idx]->arrive();
338341

339342
#pragma unroll

deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -236,8 +236,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
236236
uint32_t umma_phase = 1;
237237

238238
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
239-
if (q_idx != next_q_idx)
239+
if (q_idx != next_q_idx) {
240240
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
241+
full_q_barriers[q_stage_idx]->wait(q_phase);
242+
}
241243

242244
q_idx = next_q_idx;
243245
kv_idx = next_kv_idx;
@@ -249,6 +251,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
249251
#pragma unroll
250252
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
251253
empty_umma_barriers[i]->wait(umma_phase);
254+
tcgen05_after_thread_sync();
252255
#pragma unroll
253256
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
254257
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
@@ -316,6 +319,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
316319

317320
// Wait UMMA arrival
318321
full_umma_barriers[warpgroup_idx]->wait(umma_phase);
322+
tcgen05_after_thread_sync();
319323
umma_phase ^= 1;
320324

321325
// Release KV empty
@@ -338,6 +342,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
338342
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
339343
cutlass::arch::fence_view_async_tmem_load();
340344

345+
tcgen05_before_thread_sync();
341346
empty_umma_barriers[warpgroup_idx]->arrive();
342347

343348
#pragma unroll

0 commit comments

Comments
 (0)