Skip to content

Commit d4fa9a9

Browse files
committed
[Blackwell] Fix thrown-away load result due to wrong wait placement
The partial-reduction combine (minnum/maxnum tree) was emitted before the
tcgen05.wait <load>, so it consumed redval registers that were not yet valid.
The regression went unnoticed because the test used `CHECK-3:` — not a valid
FileCheck directive (should be `CHECK-COUNT-3:`) — so the lines were silently
ignored, masking the bad codegen. Move the combine after the wait and fix the
directives.
1 parent 58bc6d3 commit d4fa9a9

File tree

2 files changed

+36
-27
lines changed

2 files changed

+36
-27
lines changed

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
10941094
// CHECK-LABEL: @tensor_memory_ld_red_min_128x256_4_warps
10951095
// CHECK-COUNT-4: tcgen05.ld.red.sync.aligned.32x32b.x64.min.f32
10961096
// CHECK: tcgen05.wait <load>
1097-
// CHECK-3: llvm.intr.minnum
1097+
// CHECK-COUNT-3: llvm.intr.minnum
10981098
tt.func public @tensor_memory_ld_red_min_128x256_4_warps() {
10991099
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked_256N_4w>
11001100
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked_256N_4w>) -> !ttg.memdesc<128x256xf32, #tmem_256N, #ttng.tensor_memory, mutable>
@@ -1105,7 +1105,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
11051105
// CHECK-LABEL: @tensor_memory_ld_red_max_128x256_4_warps
11061106
// CHECK-COUNT-4: tcgen05.ld.red.sync.aligned.32x32b.x64.max.f32
11071107
// CHECK: tcgen05.wait <load>
1108-
// CHECK-3: llvm.intr.maxnum
1108+
// CHECK-COUNT-3: llvm.intr.maxnum
11091109
tt.func public @tensor_memory_ld_red_max_128x256_4_warps() {
11101110
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked_256N_4w>
11111111
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked_256N_4w>) -> !ttg.memdesc<128x256xf32, #tmem_256N, #ttng.tensor_memory, mutable>
@@ -1125,7 +1125,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
11251125
// CHECK-LABEL: @tensor_memory_ld_red_min_128x256_4_warps_nan
11261126
// CHECK-COUNT-4: tcgen05.ld.red.sync.aligned.32x32b.x64.min.NaN.f32
11271127
// CHECK: tcgen05.wait <load>
1128-
// CHECK-3: llvm.intr.minimum
1128+
// CHECK-COUNT-3: llvm.intr.minimum
11291129
tt.func public @tensor_memory_ld_red_min_128x256_4_warps_nan() {
11301130
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked_256N_4w_nan>
11311131
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked_256N_4w_nan>) -> !ttg.memdesc<128x256xf32, #tmem_256N_nan, #ttng.tensor_memory, mutable>
@@ -1136,7 +1136,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
11361136
// CHECK-LABEL: @tensor_memory_ld_red_max_128x256_4_warps_nan
11371137
// CHECK-COUNT-4: tcgen05.ld.red.sync.aligned.32x32b.x64.max.NaN.f32
11381138
// CHECK: tcgen05.wait <load>
1139-
// CHECK-3: llvm.intr.maximum
1139+
// CHECK-COUNT-3: llvm.intr.maximum
11401140
tt.func public @tensor_memory_ld_red_max_128x256_4_warps_nan() {
11411141
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked_256N_4w_nan>
11421142
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked_256N_4w_nan>) -> !ttg.memdesc<128x256xf32, #tmem_256N_nan, #ttng.tensor_memory, mutable>

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -412,29 +412,6 @@ std::pair<SmallVector<Value>, SmallVector<Value>> lowerTMemLdSt(
412412
}
413413
}
414414

415-
// Combine partial reductions into one value per thread
416-
if (redvalVals.size() > 1) {
417-
auto isMin = *redOp == TMEMLoadReduceModifier::MIN;
418-
auto applyMinMax = [&](Value lhs, Value rhs) {
419-
return useNaN ? (isMin ? LLVM::MinimumOp::create(rewriter, loc, lhs, rhs)
420-
: LLVM::MaximumOp::create(rewriter, loc, lhs, rhs))
421-
->getResult(0)
422-
: (isMin ? LLVM::MinNumOp::create(rewriter, loc, lhs, rhs)
423-
: LLVM::MaxNumOp::create(rewriter, loc, lhs, rhs))
424-
->getResult(0);
425-
};
426-
// Use tree reduction: pair up elements at each level
427-
while (redvalVals.size() > 1) {
428-
SmallVector<Value> reduced;
429-
assert(redvalVals.size() % 2 == 0 &&
430-
"redvalVals must be a multiple of 2");
431-
for (size_t i = 0; i < redvalVals.size(); i += 2) {
432-
reduced.push_back(applyMinMax(redvalVals[i], redvalVals[i + 1]));
433-
}
434-
redvalVals = std::move(reduced);
435-
}
436-
}
437-
438415
return {resultVals, redvalVals};
439416
}
440417

@@ -519,6 +496,34 @@ static std::pair<SmallVector<Value>, SmallVector<Value>> lowerTMemLdStFromTypes(
519496
vals, tmemBase, redOp, useAbs, useNaN);
520497
}
521498

499+
// Combine partial reductions into one value per thread via tree reduction.
500+
static void combinePartialReductions(Location loc,
501+
ConversionPatternRewriter &rewriter,
502+
SmallVector<Value> &redvalVals,
503+
TMEMLoadReduceModifier redOp,
504+
bool useNaN) {
505+
if (redvalVals.size() <= 1)
506+
return;
507+
auto isMin = redOp == TMEMLoadReduceModifier::MIN;
508+
auto applyMinMax = [&](Value lhs, Value rhs) {
509+
return useNaN ? (isMin ? LLVM::MinimumOp::create(rewriter, loc, lhs, rhs)
510+
: LLVM::MaximumOp::create(rewriter, loc, lhs, rhs))
511+
->getResult(0)
512+
: (isMin ? LLVM::MinNumOp::create(rewriter, loc, lhs, rhs)
513+
: LLVM::MaxNumOp::create(rewriter, loc, lhs, rhs))
514+
->getResult(0);
515+
};
516+
// Use tree reduction: pair up elements at each level
517+
while (redvalVals.size() > 1) {
518+
SmallVector<Value> reduced;
519+
assert(redvalVals.size() % 2 == 0 && "redvalVals must be a multiple of 2");
520+
for (size_t i = 0; i < redvalVals.size(); i += 2) {
521+
reduced.push_back(applyMinMax(redvalVals[i], redvalVals[i + 1]));
522+
}
523+
redvalVals = std::move(reduced);
524+
}
525+
}
526+
522527
struct TensorMemoryLoadOpConversion
523528
: public ConvertOpToLLVMPattern<triton::nvidia_gpu::TMEMLoadOp> {
524529
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -556,6 +561,10 @@ struct TensorMemoryLoadOpConversion
556561
// Wait insertion could be moved to the TTGIR level if needed.
557562
NVVM::Tcgen05WaitOp::create(rewriter, loc, NVVM::Tcgen05WaitKind::LOAD);
558563

564+
// tcgen05.ld.red is async, redval registers aren't valid until the wait
565+
if (redOp)
566+
combinePartialReductions(loc, rewriter, redvalVals, *redOp, useNaN);
567+
559568
// Handle reduction output if present
560569
SmallVector<Value> results = {resultStruct};
561570
if (redOp) {

0 commit comments

Comments
 (0)