|
1 | 1 | #include "TritonAMDGPUTransforms/Passes.h" |
2 | 2 | #include "mlir/Support/LLVM.h" |
| 3 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
3 | 4 | #include "third_party/amd/include/Analysis/RangeAnalysis.h" |
4 | 5 | #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" |
5 | 6 | #include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" |
@@ -1058,9 +1059,20 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> { |
1058 | 1059 | continue; |
1059 | 1060 | StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), |
1060 | 1061 | globalPrefetch, localPrefetch, useAsyncCopy); |
1061 | | - if (failed(sp.pipelineLoop())) |
1062 | | - continue; |
| 1062 | + (void)sp.pipelineLoop(); |
1063 | 1063 | } |
| 1064 | + |
| 1065 | + DenseMap<Value, SetVector<Operation *>> assumptions = |
| 1066 | + tt::AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); |
| 1067 | + std::shared_ptr solver = createDataFlowSolver(); |
| 1068 | + solver->load<tt::AMD::TritonIntegerRangeAnalysis>(assumptions); |
| 1069 | + if (failed(solver->initializeAndRun(getOperation()))) |
| 1070 | + return signalPassFailure(); |
| 1071 | + |
| 1072 | + ModuleOp mod = getOperation(); |
| 1073 | + RewritePatternSet patterns(&getContext()); |
| 1074 | + tt::AMD::populateFoldTrueCmpIOpPatterns(patterns, solver); |
| 1075 | + (void)applyPatternsGreedily(mod, std::move(patterns)); |
1064 | 1076 | } |
1065 | 1077 | }; |
1066 | 1078 | } // namespace |
|
0 commit comments