@@ -412,29 +412,6 @@ std::pair<SmallVector<Value>, SmallVector<Value>> lowerTMemLdSt(
412412 }
413413 }
414414
415- // Combine partial reductions into one value per thread
416- if (redvalVals.size () > 1 ) {
417- auto isMin = *redOp == TMEMLoadReduceModifier::MIN;
418- auto applyMinMax = [&](Value lhs, Value rhs) {
419- return useNaN ? (isMin ? LLVM::MinimumOp::create (rewriter, loc, lhs, rhs)
420- : LLVM::MaximumOp::create (rewriter, loc, lhs, rhs))
421- ->getResult (0 )
422- : (isMin ? LLVM::MinNumOp::create (rewriter, loc, lhs, rhs)
423- : LLVM::MaxNumOp::create (rewriter, loc, lhs, rhs))
424- ->getResult (0 );
425- };
426- // Use tree reduction: pair up elements at each level
427- while (redvalVals.size () > 1 ) {
428- SmallVector<Value> reduced;
429- assert (redvalVals.size () % 2 == 0 &&
430- " redvalVals must be a multiple of 2" );
431- for (size_t i = 0 ; i < redvalVals.size (); i += 2 ) {
432- reduced.push_back (applyMinMax (redvalVals[i], redvalVals[i + 1 ]));
433- }
434- redvalVals = std::move (reduced);
435- }
436- }
437-
438415 return {resultVals, redvalVals};
439416}
440417
@@ -519,6 +496,34 @@ static std::pair<SmallVector<Value>, SmallVector<Value>> lowerTMemLdStFromTypes(
519496 vals, tmemBase, redOp, useAbs, useNaN);
520497}
521498
499+ // Combine partial reductions into one value per thread via tree reduction.
500+ static void combinePartialReductions (Location loc,
501+ ConversionPatternRewriter &rewriter,
502+ SmallVector<Value> &redvalVals,
503+ TMEMLoadReduceModifier redOp,
504+ bool useNaN) {
505+ if (redvalVals.size () <= 1 )
506+ return ;
507+ auto isMin = redOp == TMEMLoadReduceModifier::MIN;
508+ auto applyMinMax = [&](Value lhs, Value rhs) {
509+ return useNaN ? (isMin ? LLVM::MinimumOp::create (rewriter, loc, lhs, rhs)
510+ : LLVM::MaximumOp::create (rewriter, loc, lhs, rhs))
511+ ->getResult (0 )
512+ : (isMin ? LLVM::MinNumOp::create (rewriter, loc, lhs, rhs)
513+ : LLVM::MaxNumOp::create (rewriter, loc, lhs, rhs))
514+ ->getResult (0 );
515+ };
516+ // Use tree reduction: pair up elements at each level
517+ while (redvalVals.size () > 1 ) {
518+ SmallVector<Value> reduced;
519+ assert (redvalVals.size () % 2 == 0 && " redvalVals must be a multiple of 2" );
520+ for (size_t i = 0 ; i < redvalVals.size (); i += 2 ) {
521+ reduced.push_back (applyMinMax (redvalVals[i], redvalVals[i + 1 ]));
522+ }
523+ redvalVals = std::move (reduced);
524+ }
525+ }
526+
522527struct TensorMemoryLoadOpConversion
523528 : public ConvertOpToLLVMPattern<triton::nvidia_gpu::TMEMLoadOp> {
524529 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -556,6 +561,10 @@ struct TensorMemoryLoadOpConversion
556561 // Wait insertion could be moved to the TTGIR level if needed.
557562 NVVM::Tcgen05WaitOp::create (rewriter, loc, NVVM::Tcgen05WaitKind::LOAD);
558563
564+ // tcgen05.ld.red is async, redval registers aren't valid until the wait
565+ if (redOp)
566+ combinePartialReductions (loc, rewriter, redvalVals, *redOp, useNaN);
567+
559568 // Handle reduction output if present
560569 SmallVector<Value> results = {resultStruct};
561570 if (redOp) {
0 commit comments