Skip to content

Commit dc78696

Browse files
committed
add fold-true-cmpi pattern to StreamPipeline.cpp
1 parent dd1f8d5 commit dc78696

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
1 1
#include "TritonAMDGPUTransforms/Passes.h"
2 2
#include "mlir/Support/LLVM.h"
3+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3 4
#include "third_party/amd/include/Analysis/RangeAnalysis.h"
4 5
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
5 6
#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h"
@@ -1058,9 +1059,20 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
1058 1059
continue;
1059 1060
StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages),
1060 1061
globalPrefetch, localPrefetch, useAsyncCopy);
1061-
if (failed(sp.pipelineLoop()))
1062-
continue;
1062+
(void)sp.pipelineLoop();
1063 1063
}
1064+
1065+
DenseMap<Value, SetVector<Operation *>> assumptions =
1066+
tt::AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation());
1067+
std::shared_ptr solver = createDataFlowSolver();
1068+
solver->load<tt::AMD::TritonIntegerRangeAnalysis>(assumptions);
1069+
if (failed(solver->initializeAndRun(getOperation())))
1070+
return signalPassFailure();
1071+
1072+
ModuleOp mod = getOperation();
1073+
RewritePatternSet patterns(&getContext());
1074+
tt::AMD::populateFoldTrueCmpIOpPatterns(patterns, solver);
1075+
(void)applyPatternsGreedily(mod, std::move(patterns));
1064 1076
}
1065 1077
};
1066 1078
} // namespace

0 commit comments

Comments
 (0)