13
13
// limitations under the License.
14
14
15
15
#include < iterator>
16
+ #include < string>
16
17
#include < utility>
17
18
18
19
#include " llvm/ADT/STLExtras.h"
23
24
#include " mlir/Dialect/MemRef/IR/MemRef.h"
24
25
#include " mlir/IR/BlockAndValueMapping.h"
25
26
#include " mlir/IR/PatternMatch.h"
27
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
26
28
#include " tfrt/basic_kernels/opdefs/basic_kernels.h"
27
29
#include " tfrt/gpu/kernels/gpu_ops.h"
28
30
#include " tfrt/gpu/passes/passes.h"
@@ -49,20 +51,23 @@ void internal::StreamifyOpConversionSetChain(Value chain,
49
51
50
52
namespace {
51
53
52
- // Wraps consecutive legal ops within a block into a
53
- // tfrt_gpu.streamify op.
54
- struct NestLegalOpsInStreamifyOpPattern
55
- : public OpRewritePattern<func::FuncOp> {
56
- NestLegalOpsInStreamifyOpPattern (MLIRContext *context,
57
- ConversionTarget &target)
58
- : OpRewritePattern(context), target(target) {}
54
+ // Wraps consecutive ops of given names into a tfrt_gpu.streamify op.
55
+ struct StreamifyOpsPattern : public OpRewritePattern <func::FuncOp> {
56
+ StreamifyOpsPattern (MLIRContext *context, ArrayRef<std::string> op_names)
57
+ : OpRewritePattern(context), target(*context) {
58
+ for (const std::string &op_name : op_names) {
59
+ target.setOpAction (OperationName (op_name, context),
60
+ ConversionTarget::LegalizationAction::Illegal);
61
+ }
62
+ }
59
63
60
64
private:
61
65
LogicalResult matchAndRewrite (func::FuncOp func_op,
62
66
PatternRewriter &rewriter) const override ;
63
67
LogicalResult matchAndRewriteBlock (Block *block,
64
68
PatternRewriter &rewriter) const ;
65
- ConversionTarget ⌖
69
+
70
+ ConversionTarget target;
66
71
};
67
72
68
73
// Folds a memref.view of !tfrt_gpu.buffer with zero byte_shift.
@@ -121,13 +126,13 @@ struct ConvertOpTypesPattern : public OpConversionPattern<OpTy> {
121
126
122
127
} // namespace
123
128
124
- LogicalResult NestLegalOpsInStreamifyOpPattern ::matchAndRewrite (
129
+ LogicalResult StreamifyOpsPattern ::matchAndRewrite (
125
130
func::FuncOp func_op, PatternRewriter &rewriter) const {
126
131
rewriter.startRootUpdate (func_op);
127
132
LogicalResult result = failure ();
128
133
func_op.walk ([&](Block *block) {
129
- if (isa<StreamifyOp>(block->getParentOp ())) return WalkResult::skip ();
130
- if ( succeeded (matchAndRewriteBlock (block, rewriter)))
134
+ if (! isa<StreamifyOp>(block->getParentOp ()) &&
135
+ succeeded (matchAndRewriteBlock (block, rewriter)))
131
136
result = success (); // At least one op has been nested.
132
137
return WalkResult::advance ();
133
138
});
@@ -136,24 +141,24 @@ LogicalResult NestLegalOpsInStreamifyOpPattern::matchAndRewrite(
136
141
return result;
137
142
}
138
143
139
- // Iterate over ops in block, and whenever we transition from a legal to an
140
- // illegal op, wrap preceding legal ops in !tfrt_gpu.streamify.
141
- LogicalResult NestLegalOpsInStreamifyOpPattern ::matchAndRewriteBlock (
144
+ // Iterate over ops in block, and whenever we transition from an illegal to a
145
+ // legal op, wrap preceding illegal ops in !tfrt_gpu.streamify.
146
+ LogicalResult StreamifyOpsPattern ::matchAndRewriteBlock (
142
147
Block *block, PatternRewriter &rewriter) const {
143
148
LogicalResult result = failure ();
144
- Operation *legal_begin = nullptr ;
149
+ Operation *illegal_begin = nullptr ;
145
150
for (Operation *op : llvm::make_pointer_range (block->getOperations ())) {
146
- if (target.isLegal (op)) {
147
- if (!legal_begin ) // Start of legal op sequence.
148
- legal_begin = op;
151
+ if (target.isIllegal (op)) {
152
+ if (!illegal_begin ) // Start of illegal op sequence.
153
+ illegal_begin = op;
149
154
continue ;
150
155
}
151
- if (!legal_begin ) // Continue in illegal op sequence.
156
+ if (!illegal_begin ) // Continue in legal op sequence.
152
157
continue ;
153
158
154
- // Split block before first illegal 'op'.
159
+ // Split block before first legal 'op'.
155
160
Block *epilogue = rewriter.splitBlock (block, op->getIterator ());
156
- auto op_range = make_range (legal_begin ->getIterator (), block->end ());
161
+ auto op_range = make_range (illegal_begin ->getIterator (), block->end ());
157
162
158
163
// Collect results with uses outside of 'block'.
159
164
SmallVector<Value, 4 > results;
@@ -175,17 +180,17 @@ LogicalResult NestLegalOpsInStreamifyOpPattern::matchAndRewriteBlock(
175
180
}
176
181
177
182
// Create !tfrt_gpu.streamify op with those results.
178
- rewriter.setInsertionPoint (legal_begin );
179
- Location loc = legal_begin ->getLoc ();
183
+ rewriter.setInsertionPoint (illegal_begin );
184
+ Location loc = illegal_begin ->getLoc ();
180
185
auto streamify_op = rewriter.create <StreamifyOp>(loc, results);
181
186
182
- // Move legal ops into !tfrt_gpu.streamify body and merge blocks again.
187
+ // Move illegal ops into !tfrt_gpu.streamify body and merge blocks again.
183
188
Block *body = streamify_op.getBody ();
184
189
body->getOperations ().splice (body->begin (), block->getOperations (),
185
190
op_range.begin (), op_range.end ());
186
191
rewriter.mergeBlocks (epilogue, block, streamify_op->getResults ());
187
192
188
- legal_begin = nullptr ; // Start of illegal op sequence.
193
+ illegal_begin = nullptr ; // Start of legal op sequence.
189
194
result = success ();
190
195
}
191
196
return result;
@@ -275,16 +280,47 @@ LogicalResult ConvertOpTypesPattern<OpTy>::matchAndRewrite(
275
280
return success ();
276
281
}
277
282
278
- void populateStreamifyConversionPatterns (RewritePatternSet &patterns,
279
- TypeConverter &converter,
280
- ConversionTarget &target) {
281
- // Wrap tfrt.call ops to provide chain and stream which may be added to the
282
- // callee's arguments. This adds a chain and stream argument to all functions
283
- // containing such tfrt.call. If this turns out to be a problem, we need to
284
- // analyze the call graph and only wrap calls that execute ops implementing
285
- // the AsyncOpInterface.
286
- target.addLegalOp <compiler::CallOp, compiler::WhileOp, memref::LoadOp>();
283
+ namespace {
284
+
285
+ struct StreamifyOpsPass
286
+ : public mlir::PassWrapper<StreamifyOpsPass, OperationPass<func::FuncOp>> {
287
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (StreamifyOpsPass)
288
+
289
+ StreamifyOpsPass () = default ;
290
+ StreamifyOpsPass (const StreamifyOpsPass &) {}
291
+
292
+ ListOption<std::string> op_names = {*this , " ops" ,
293
+ llvm::cl::desc (" illegal op names" )};
294
+
295
+ private:
296
+ StringRef getArgument () const final { return " tfrt-streamify-ops" ; }
297
+
298
+ void getDependentDialects (DialectRegistry ®istry) const override {
299
+ registry.insert <compiler::TFRTDialect, GpuDialect>();
300
+ }
301
+
302
+ void runOnOperation () override ;
303
+ };
304
+
305
+ } // namespace
306
+
307
+ void StreamifyOpsPass::runOnOperation () {
308
+ RewritePatternSet patterns (&getContext ());
309
+ patterns.add <StreamifyOpsPattern>(&getContext (), op_names);
310
+ if (failed (applyOpPatternsAndFold (getOperation (), std::move (patterns))))
311
+ return signalPassFailure ();
312
+ }
313
+
314
+ std::unique_ptr<OperationPass<func::FuncOp>> CreateStreamifyOpsPass (
315
+ ArrayRef<std::string> op_names) {
316
+ auto pass = std::make_unique<StreamifyOpsPass>();
317
+ std::vector<std::string> &vector = pass->op_names ;
318
+ llvm::copy (op_names, std::back_inserter (vector));
319
+ return pass;
320
+ }
287
321
322
+ void PopulateMemrefConversionPatterns (RewritePatternSet &patterns,
323
+ TypeConverter &converter) {
288
324
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
289
325
converter);
290
326
populateCallOpTypeConversionPattern (patterns, converter);
@@ -294,8 +330,6 @@ void populateStreamifyConversionPatterns(RewritePatternSet &patterns,
294
330
ConvertOpTypesPattern<compiler::WhileOp>>(converter,
295
331
patterns.getContext ());
296
332
297
- patterns.add <NestLegalOpsInStreamifyOpPattern>(patterns.getContext (), target);
298
-
299
333
patterns.add <FoldMemrefViewPattern, FoldMemrefReinterpretCastPattern,
300
334
RewriteMemrefAllocPattern<memref::AllocOp>,
301
335
RewriteMemrefAllocPattern<memref::AllocaOp>>(
0 commit comments