Skip to content

Commit 0cb0140

Browse files
authored
[Pipeliner] Enable automatic loop fusion (triton-lang#5726)
This PR turns on automatic loop fusion in the CUDA >= 8.0 pass pipelines. Automatic loop fusion is only enabled for simple loop nests (1 outer loop, 1 inner loop), when the user requests fusion with `tl.range(..., fuse=True)` in the frontend. This PR also rewrites the persistent matmul examples to use loop nests. This is cleaner, but will also enable more powerful and flexible optimizations of loop nests in the future. Primarily, it hides the brittleness of the pipeliner behind a single layer inside the compiler, so ideally the brittleness needs to be dealt with only once and hidden from users. To achieve this, several things have been added to loop fusion: 1. To avoid generating the inner loop inside a conditional, loop nest fusion will "speculate" the length of the inner loop, essentially generating a branch where the inner loop is missing and one where the inner loop is always known to execute at least once. 2. Codegen of the loop induction variables has been slightly altered to better match the expectations of the scheduler, pipeliner(s), and `optimize-accumulator-init`. 3. Codegen of loop iter args has been altered to generate fewer SSA dependencies between the prologue, inner loop, and epilogue, making it more likely for pipelining to be successful. E.g., inner loop iter args that can be initialized outside the loop and reset in the epilogue are done so, rather than in the prologue. Some other things in this PR: * Fixed a bug in the pipeline expander * Added AxisInfo implementation for `ub::PoisonOp` I verified the performance of the rewritten persistent matmul kernels on H100 and Blackwell. Performance of `09-persistent-matmul.py` on H100. 
Before (2 runs) ``` root@dev-0:~/code/triton$ python python/tutorials/09-persistent-matmul.py M=32, N=32, K=32 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 273.146 4025.362 ROOT ├─ nan 0.031 _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_22gpu_kernel_impl_nocastIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE8_clEvEUlN3c104HalfEE_EEvS4_RKT_EUliE_EEviT1_ ├─ nan 0.027 _ZN2at6native54_GLOBAL__N__a236ace4_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE1_clEvEUlfE_EEvRNS_18TensorIteratorBaseET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SO_SH_EEvSJ_SK_RKSL_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SK_ ├─ 283.506 2666.310 cublas [M=8192, N=8192, K=512] │ └─ nan 2666.310 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas ├─ 223.326 307.709 matmul_kernel [M=8192, N=8192, K=512] ├─ 259.293 265.027 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512] ├─ 238.500 288.133 matmul_kernel_persistent [M=8192, N=8192, K=512] ├─ 258.738 265.594 matmul_kernel_tma_persistent [M=8192, N=8192, K=512] └─ 295.529 232.531 torch [M=8192, N=8192, K=512] └─ nan 232.531 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas Legend (Metric: tflop16/s (inc) Min: 223.33 Max: 295.53) █ 288.31 - 295.53 █ 273.87 - 288.31 █ 259.43 - 273.87 █ 244.99 - 259.43 █ 230.55 - 244.99 █ 223.33 - 230.55 name User code ◀ Only in left graph ▶ Only in right graph root@dev-0:~/code/triton$ python python/tutorials/09-persistent-matmul.py M=32, 
N=32, K=32 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 273.367 4022.105 ROOT ├─ nan 0.031 _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_22gpu_kernel_impl_nocastIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE8_clEvEUlN3c104HalfEE_EEvS4_RKT_EUliE_EEviT1_ ├─ nan 0.027 _ZN2at6native54_GLOBAL__N__a236ace4_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE1_clEvEUlfE_EEvRNS_18TensorIteratorBaseET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SO_SH_EEvSJ_SK_RKSL_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SK_ ├─ 284.284 2659.011 cublas [M=8192, N=8192, K=512] │ └─ nan 2659.011 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas ├─ 221.823 309.795 matmul_kernel [M=8192, N=8192, K=512] ├─ 254.755 269.748 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512] ├─ 240.774 285.411 matmul_kernel_persistent [M=8192, N=8192, K=512] ├─ 259.109 265.214 matmul_kernel_tma_persistent [M=8192, N=8192, K=512] └─ 295.100 232.868 torch [M=8192, N=8192, K=512] └─ nan 232.868 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas Legend (Metric: tflop16/s (inc) Min: 221.82 Max: 295.10) █ 287.77 - 295.10 █ 273.12 - 287.77 █ 258.46 - 273.12 █ 243.81 - 258.46 █ 229.15 - 243.81 █ 221.82 - 229.15 name User code ◀ Only in left graph ▶ Only in right graph ``` After: ``` root@dev-0:~/code/triton$ python python/tutorials/09-persistent-matmul.py M=32, N=32, K=32 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ 
Tensor descriptor persistent: ✅ M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 274.040 4012.227 ROOT ├─ nan 0.031 _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_22gpu_kernel_impl_nocastIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE8_clEvEUlN3c104HalfEE_EEvS4_RKT_EUliE_EEviT1_ ├─ nan 0.027 _ZN2at6native54_GLOBAL__N__a236ace4_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE1_clEvEUlfE_EEvRNS_18TensorIteratorBaseET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SO_SH_EEvSJ_SK_RKSL_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SK_ ├─ 285.369 2648.904 cublas [M=8192, N=8192, K=512] │ └─ nan 2648.904 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas ├─ 217.548 315.881 matmul_kernel [M=8192, N=8192, K=512] ├─ 262.312 261.976 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512] ├─ 244.740 280.785 matmul_kernel_persistent [M=8192, N=8192, K=512] ├─ 255.113 269.368 matmul_kernel_tma_persistent [M=8192, N=8192, K=512] └─ 292.108 235.253 torch [M=8192, N=8192, K=512] └─ nan 235.253 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas Legend (Metric: tflop16/s (inc) Min: 217.55 Max: 292.11) █ 284.65 - 292.11 █ 269.74 - 284.65 █ 254.83 - 269.74 █ 239.92 - 254.83 █ 225.00 - 239.92 █ 217.55 - 225.00 name User code ◀ Only in left graph ▶ Only in right graph root@dev-0:~/code/triton$ python python/tutorials/09-persistent-matmul.py M=32, N=32, K=32 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ 
persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 274.997 3998.267 ROOT ├─ nan 0.031 _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_22gpu_kernel_impl_nocastIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE8_clEvEUlN3c104HalfEE_EEvS4_RKT_EUliE_EEviT1_ ├─ nan 0.027 _ZN2at6native54_GLOBAL__N__a236ace4_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE1_clEvEUlfE_EEvRNS_18TensorIteratorBaseET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SO_SH_EEvSJ_SK_RKSL_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SK_ ├─ 285.498 2647.706 cublas [M=8192, N=8192, K=512] │ └─ nan 2647.706 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas ├─ 217.884 315.394 matmul_kernel [M=8192, N=8192, K=512] ├─ 262.534 261.755 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512] ├─ 246.617 278.649 matmul_kernel_persistent [M=8192, N=8192, K=512] ├─ 262.525 261.764 matmul_kernel_tma_persistent [M=8192, N=8192, K=512] └─ 295.007 232.942 torch [M=8192, N=8192, K=512] └─ nan 232.942 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas Legend (Metric: tflop16/s (inc) Min: 217.88 Max: 295.01) █ 287.29 - 295.01 █ 271.87 - 287.29 █ 256.45 - 271.87 █ 241.02 - 256.45 █ 225.60 - 241.02 █ 217.88 - 225.60 name User code ◀ Only in left graph ▶ Only in right graph ```
1 parent 61b5674 commit 0cb0140

File tree

16 files changed

+867
-326
lines changed

16 files changed

+867
-326
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "mlir/Analysis/DataFlowFramework.h"
2+
#include "mlir/Dialect/UB/IR/UBOps.h"
23
#include "llvm/Support/Debug.h"
34
#include "llvm/Support/raw_ostream.h"
45

@@ -269,6 +270,28 @@ class ConstantOpAxisInfoVisitor final
269270
}
270271
};
271272

273+
class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
274+
public:
275+
using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
276+
277+
AxisInfo
278+
getAxisInfo(ub::PoisonOp op,
279+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
280+
constexpr int64_t largePowerOf2 = int64_t(1) << 32;
281+
// Poison values are never accessed, thus assume optimistic values.
282+
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
283+
unsigned rank = shape.getRank();
284+
return AxisInfo(
285+
/*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),
286+
/*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2),
287+
/*constancy=*/AxisInfo::DimVectorT(shape.getShape()));
288+
}
289+
290+
return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2},
291+
/*constancy=*/{1});
292+
}
293+
};
294+
272295
template <typename OpTy>
273296
class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
274297
public:
@@ -1012,6 +1035,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10121035
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
10131036
CastOpAxisInfoVisitor<triton::BitcastOp>>();
10141037
visitors.append<MakeRangeOpAxisInfoVisitor>();
1038+
visitors.append<PoisonOpAxisInfoVisitor>();
10151039
visitors.append<ConstantOpAxisInfoVisitor>();
10161040
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
10171041
AddSubOpAxisInfoVisitor<arith::AddIOp>,

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <algorithm>
44
#include <numeric>
55

6+
#include "mlir/Dialect/UB/IR/UBOps.h"
67
#include "mlir/IR/IRMapping.h"
78
#include "mlir/Support/LLVM.h"
89
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -97,16 +98,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
9798

9899
addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
99100
triton::TritonDialect, cf::ControlFlowDialect,
100-
scf::SCFDialect>([&](Operation *op) {
101-
bool hasLegalRegions = true;
102-
for (auto &region : op->getRegions()) {
103-
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
104-
}
105-
if (hasLegalRegions && typeConverter.isLegal(op)) {
106-
return true;
107-
}
108-
return false;
109-
});
101+
scf::SCFDialect, ub::UBDialect>(
102+
[&](Operation *op) {
103+
bool hasLegalRegions = true;
104+
for (auto &region : op->getRegions()) {
105+
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
106+
}
107+
if (hasLegalRegions && typeConverter.isLegal(op)) {
108+
return true;
109+
}
110+
return false;
111+
});
110112

111113
// We have requirements for the data layouts
112114
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "mlir/Dialect/Arith/IR/Arith.h"
44
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
5+
#include "mlir/Dialect/UB/IR/UBOps.h"
56
#include "mlir/IR/BuiltinAttributes.h"
67
#include "mlir/Pass/Pass.h"
78
#include "mlir/Transforms/DialectConversion.h"
@@ -859,6 +860,7 @@ class ConvertTritonToTritonGPU
859860
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
860861
populateSCFPatterns(typeConverter, patterns);
861862
populateCFPatterns(typeConverter, patterns);
863+
patterns.insert<GenericOpPattern<ub::PoisonOp>>(typeConverter, context);
862864

863865
auto inti = llvm::APSInt(32, false);
864866
auto i32_ty = IntegerType::get(mod->getContext(), 32);

0 commit comments

Comments
 (0)