Skip to content

Commit bdb14ff

Browse files
chsigg and copybara-github
authored and committed
Split wrapping ops in tfrt_gpu.streamify into a separate pass.
PiperOrigin-RevId: 444750859
1 parent 0a00ce9 commit bdb14ff

File tree

5 files changed

+107
-118
lines changed

5 files changed

+107
-118
lines changed

backends/gpu/include/tfrt/gpu/passes/passes.h

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,10 +20,12 @@
2020
#define TFRT_GPU_PASSES_PASSES_H_
2121

2222
#include <memory>
23+
#include <string>
2324

2425
#include "llvm/ADT/ArrayRef.h"
2526
#include "llvm/ADT/STLExtras.h"
2627
#include "llvm/ADT/StringRef.h"
28+
#include "mlir/Dialect/Func/IR/FuncOps.h"
2729
#include "mlir/Pass/Pass.h"
2830
#include "mlir/Transforms/DialectConversion.h"
2931

@@ -83,11 +85,23 @@ mlir::StringRef GetGpuModuleAttrName();
8385
// the corresponding unrealized_conversion_cast materializers.
8486
mlir::TypeConverter CreateMemrefToTfrtGpuConverter();
8587

86-
// Adds rewrite patterns that wraps consecutive legal ops as defined by
87-
// `target` into a tfrt_gpu.streamify op.
88-
void populateStreamifyConversionPatterns(mlir::RewritePatternSet& patterns,
89-
mlir::TypeConverter& converter,
90-
mlir::ConversionTarget& target);
88+
// Creates a pass which wraps ops into a tfrt_gpu.streamify op.
89+
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateStreamifyOpsPass(
90+
mlir::ArrayRef<std::string> op_names);
91+
92+
// Creates a pass which wraps the template argument ops into a
93+
// tfrt_gpu.streamify op.
94+
template <typename... OpTs>
95+
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
96+
CreateStreamifyOpsPass() {
97+
std::string op_names[] = {
98+
static_cast<std::string>(OpTs::getOperationName())...};
99+
return CreateStreamifyOpsPass(op_names);
100+
}
101+
102+
// Adds rewrite patterns which convert memref op to tfrt_gpu.
103+
void PopulateMemrefConversionPatterns(mlir::RewritePatternSet& patterns,
104+
mlir::TypeConverter& converter);
91105

92106
// Adds passes to convert from MLIR's gpu and async dialects to TFRT. Adds
93107
// !tfrt.chain result and !tfrt.chain, !tfrt_gpu.stream arguments to functions.

backends/gpu/lib/passes/gpu_async_patterns.cc

Lines changed: 70 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include <iterator>
16+
#include <string>
1617
#include <utility>
1718

1819
#include "llvm/ADT/STLExtras.h"
@@ -23,6 +24,7 @@
2324
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2425
#include "mlir/IR/BlockAndValueMapping.h"
2526
#include "mlir/IR/PatternMatch.h"
27+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2628
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"
2729
#include "tfrt/gpu/kernels/gpu_ops.h"
2830
#include "tfrt/gpu/passes/passes.h"
@@ -49,20 +51,23 @@ void internal::StreamifyOpConversionSetChain(Value chain,
4951

5052
namespace {
5153

52-
// Wraps consecutive legal ops within a block into a
53-
// tfrt_gpu.streamify op.
54-
struct NestLegalOpsInStreamifyOpPattern
55-
: public OpRewritePattern<func::FuncOp> {
56-
NestLegalOpsInStreamifyOpPattern(MLIRContext *context,
57-
ConversionTarget &target)
58-
: OpRewritePattern(context), target(target) {}
54+
// Wraps consecutive ops of given names into a tfrt_gpu.streamify op.
55+
struct StreamifyOpsPattern : public OpRewritePattern<func::FuncOp> {
56+
StreamifyOpsPattern(MLIRContext *context, ArrayRef<std::string> op_names)
57+
: OpRewritePattern(context), target(*context) {
58+
for (const std::string &op_name : op_names) {
59+
target.setOpAction(OperationName(op_name, context),
60+
ConversionTarget::LegalizationAction::Illegal);
61+
}
62+
}
5963

6064
private:
6165
LogicalResult matchAndRewrite(func::FuncOp func_op,
6266
PatternRewriter &rewriter) const override;
6367
LogicalResult matchAndRewriteBlock(Block *block,
6468
PatternRewriter &rewriter) const;
65-
ConversionTarget &target;
69+
70+
ConversionTarget target;
6671
};
6772

6873
// Folds a memref.view of !tfrt_gpu.buffer with zero byte_shift.
@@ -121,13 +126,13 @@ struct ConvertOpTypesPattern : public OpConversionPattern<OpTy> {
121126

122127
} // namespace
123128

124-
LogicalResult NestLegalOpsInStreamifyOpPattern::matchAndRewrite(
129+
LogicalResult StreamifyOpsPattern::matchAndRewrite(
125130
func::FuncOp func_op, PatternRewriter &rewriter) const {
126131
rewriter.startRootUpdate(func_op);
127132
LogicalResult result = failure();
128133
func_op.walk([&](Block *block) {
129-
if (isa<StreamifyOp>(block->getParentOp())) return WalkResult::skip();
130-
if (succeeded(matchAndRewriteBlock(block, rewriter)))
134+
if (!isa<StreamifyOp>(block->getParentOp()) &&
135+
succeeded(matchAndRewriteBlock(block, rewriter)))
131136
result = success(); // At least one op has been nested.
132137
return WalkResult::advance();
133138
});
@@ -136,24 +141,24 @@ LogicalResult NestLegalOpsInStreamifyOpPattern::matchAndRewrite(
136141
return result;
137142
}
138143

139-
// Iterate over ops in block, and whenever we transition from a legal to an
140-
// illegal op, wrap preceding legal ops in !tfrt_gpu.streamify.
141-
LogicalResult NestLegalOpsInStreamifyOpPattern::matchAndRewriteBlock(
144+
// Iterate over ops in block, and whenever we transition from an illegal to a
145+
// legal op, wrap preceding illegal ops in !tfrt_gpu.streamify.
146+
LogicalResult StreamifyOpsPattern::matchAndRewriteBlock(
142147
Block *block, PatternRewriter &rewriter) const {
143148
LogicalResult result = failure();
144-
Operation *legal_begin = nullptr;
149+
Operation *illegal_begin = nullptr;
145150
for (Operation *op : llvm::make_pointer_range(block->getOperations())) {
146-
if (target.isLegal(op)) {
147-
if (!legal_begin) // Start of legal op sequence.
148-
legal_begin = op;
151+
if (target.isIllegal(op)) {
152+
if (!illegal_begin) // Start of illegal op sequence.
153+
illegal_begin = op;
149154
continue;
150155
}
151-
if (!legal_begin) // Continue in illegal op sequence.
156+
if (!illegal_begin) // Continue in legal op sequence.
152157
continue;
153158

154-
// Split block before first illegal 'op'.
159+
// Split block before first legal 'op'.
155160
Block *epilogue = rewriter.splitBlock(block, op->getIterator());
156-
auto op_range = make_range(legal_begin->getIterator(), block->end());
161+
auto op_range = make_range(illegal_begin->getIterator(), block->end());
157162

158163
// Collect results with uses outside of 'block'.
159164
SmallVector<Value, 4> results;
@@ -175,17 +180,17 @@ LogicalResult NestLegalOpsInStreamifyOpPattern::matchAndRewriteBlock(
175180
}
176181

177182
// Create !tfrt_gpu.streamify op with those results.
178-
rewriter.setInsertionPoint(legal_begin);
179-
Location loc = legal_begin->getLoc();
183+
rewriter.setInsertionPoint(illegal_begin);
184+
Location loc = illegal_begin->getLoc();
180185
auto streamify_op = rewriter.create<StreamifyOp>(loc, results);
181186

182-
// Move legal ops into !tfrt_gpu.streamify body and merge blocks again.
187+
// Move illegal ops into !tfrt_gpu.streamify body and merge blocks again.
183188
Block *body = streamify_op.getBody();
184189
body->getOperations().splice(body->begin(), block->getOperations(),
185190
op_range.begin(), op_range.end());
186191
rewriter.mergeBlocks(epilogue, block, streamify_op->getResults());
187192

188-
legal_begin = nullptr; // Start of illegal op sequence.
193+
illegal_begin = nullptr; // Start of legal op sequence.
189194
result = success();
190195
}
191196
return result;
@@ -275,16 +280,47 @@ LogicalResult ConvertOpTypesPattern<OpTy>::matchAndRewrite(
275280
return success();
276281
}
277282

278-
void populateStreamifyConversionPatterns(RewritePatternSet &patterns,
279-
TypeConverter &converter,
280-
ConversionTarget &target) {
281-
// Wrap tfrt.call ops to provide chain and stream which may be added to the
282-
// callee's arguments. This adds a chain and stream argument to all functions
283-
// containing such tfrt.call. If this turns out to be a problem, we need to
284-
// analyze the call graph and only wrap calls that execute ops implementing
285-
// the AsyncOpInterface.
286-
target.addLegalOp<compiler::CallOp, compiler::WhileOp, memref::LoadOp>();
283+
namespace {
284+
285+
struct StreamifyOpsPass
286+
: public mlir::PassWrapper<StreamifyOpsPass, OperationPass<func::FuncOp>> {
287+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StreamifyOpsPass)
288+
289+
StreamifyOpsPass() = default;
290+
StreamifyOpsPass(const StreamifyOpsPass &) {}
291+
292+
ListOption<std::string> op_names = {*this, "ops",
293+
llvm::cl::desc("illegal op names")};
294+
295+
private:
296+
StringRef getArgument() const final { return "tfrt-streamify-ops"; }
297+
298+
void getDependentDialects(DialectRegistry &registry) const override {
299+
registry.insert<compiler::TFRTDialect, GpuDialect>();
300+
}
301+
302+
void runOnOperation() override;
303+
};
304+
305+
} // namespace
306+
307+
void StreamifyOpsPass::runOnOperation() {
308+
RewritePatternSet patterns(&getContext());
309+
patterns.add<StreamifyOpsPattern>(&getContext(), op_names);
310+
if (failed(applyOpPatternsAndFold(getOperation(), std::move(patterns))))
311+
return signalPassFailure();
312+
}
313+
314+
std::unique_ptr<OperationPass<func::FuncOp>> CreateStreamifyOpsPass(
315+
ArrayRef<std::string> op_names) {
316+
auto pass = std::make_unique<StreamifyOpsPass>();
317+
std::vector<std::string> &vector = pass->op_names;
318+
llvm::copy(op_names, std::back_inserter(vector));
319+
return pass;
320+
}
287321

322+
void PopulateMemrefConversionPatterns(RewritePatternSet &patterns,
323+
TypeConverter &converter) {
288324
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
289325
converter);
290326
populateCallOpTypeConversionPattern(patterns, converter);
@@ -294,8 +330,6 @@ void populateStreamifyConversionPatterns(RewritePatternSet &patterns,
294330
ConvertOpTypesPattern<compiler::WhileOp>>(converter,
295331
patterns.getContext());
296332

297-
patterns.add<NestLegalOpsInStreamifyOpPattern>(patterns.getContext(), target);
298-
299333
patterns.add<FoldMemrefViewPattern, FoldMemrefReinterpretCastPattern,
300334
RewriteMemrefAllocPattern<memref::AllocOp>,
301335
RewriteMemrefAllocPattern<memref::AllocaOp>>(

backends/gpu/lib/passes/gpu_to_tfrt_passes.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1736,6 +1736,7 @@ void populateGpuToTfrtGpuPasses(OpPassManager &pm) {
17361736
}
17371737

17381738
void registerPasses() {
1739+
registerPass([] { return CreateStreamifyOpsPass({}); });
17391740
PassRegistration<AddChainAndStreamToFuncPass>();
17401741
PassRegistration<ConvertAsyncToChainAndEventPass>();
17411742
PassRegistration<ConvertGpuToTfrtGpuPass>();

backends/gpu/mlir_tests/conversion/async_conversion.mlir

Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
// RUN: tfrt_gpu_opt %s \
16-
// RUN: -test-streamify-conversion \
16+
// RUN: -tfrt-streamify-ops=ops=wrap.op,tfrt.call \
1717
// RUN: -allow-unregistered-dialect \
1818
// RUN: | FileCheck %s
1919

@@ -61,39 +61,42 @@ func.func @test_wrap_streamify() {
6161
func.func private @returns_values() -> (f32, f32)
6262
func.func private @takes_argument(%arg0: f32)
6363

64+
// Note on 'DISABLED' below: There is temporarily no pass in TFRT GPU that
65+
// exercises the memref conversion patterns. TODO(csigg): Enable checks again.
66+
6467
// CHECK-LABEL: @test_fold_memref_view
6568
func.func @test_fold_memref_view(%arg0: memref<64xi8>) -> memref<4x4xf32> {
6669
%zero = arith.constant 0 : index
67-
// CHECK-NOT: memref.view
68-
// CHECK: %[[buffer:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<64xi8> to !tfrt_gpu.buffer
69-
// CHECK: %[[memref:.*]] = builtin.unrealized_conversion_cast %[[buffer]] : !tfrt_gpu.buffer to memref<4x4xf32>
70+
// DISABLED-NOT: memref.view
71+
// DISABLED: %[[buffer:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<64xi8> to !tfrt_gpu.buffer
72+
// DISABLED: %[[memref:.*]] = builtin.unrealized_conversion_cast %[[buffer]] : !tfrt_gpu.buffer to memref<4x4xf32>
7073
%view = memref.view %arg0[%zero][] : memref<64xi8> to memref<4x4xf32>
71-
// CHECK: return %[[memref]]
74+
// DISABLED: return %[[memref]]
7275
func.return %view : memref<4x4xf32>
7376
}
7477

7578
// CHECK-LABEL: @test_fold_memref_cast
7679
func.func @test_fold_memref_cast(%arg0: memref<64xi8>) -> memref<8x8xi8> {
77-
// CHECK-NOT: memref.reinterpret_cast
78-
// CHECK: %[[buffer:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<64xi8> to !tfrt_gpu.buffer
79-
// CHECK: %[[memref:.*]] = builtin.unrealized_conversion_cast %[[buffer]] : !tfrt_gpu.buffer to memref<8x8xi8>
80+
// DISABLED-NOT: memref.reinterpret_cast
81+
// DISABLED: %[[buffer:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<64xi8> to !tfrt_gpu.buffer
82+
// DISABLED: %[[memref:.*]] = builtin.unrealized_conversion_cast %[[buffer]] : !tfrt_gpu.buffer to memref<8x8xi8>
8083
%cast = memref.reinterpret_cast %arg0
8184
to offset: [0], sizes: [8, 8], strides: [8, 1]
8285
: memref<64xi8> to memref<8x8xi8>
83-
// CHECK: return %[[memref]]
86+
// DISABLED: return %[[memref]]
8487
func.return %cast : memref<8x8xi8>
8588
}
8689

8790
// CHECK-LABEL: @test_rewrite_alloc
8891
func.func @test_rewrite_alloc() {
89-
// CHECK: %[[memref:.*]] = gpu.alloc () : memref<64xi8>
92+
// DISABLED: %[[memref:.*]] = gpu.alloc () : memref<64xi8>
9093
%memref = memref.alloc() : memref<64xi8>
91-
// CHECK: "other.op"() : () -> ()
94+
// DISABLED: "other.op"() : () -> ()
9295
"other.op"() : () -> ()
93-
// CHECK: gpu.dealloc %[[memref]] : memref<64xi8>
96+
// DISABLED: gpu.dealloc %[[memref]] : memref<64xi8>
9497
memref.dealloc %memref : memref<64xi8>
95-
// CHECK: %[[tmp:.*]] = gpu.alloc () : memref<64xi8>
98+
// DISABLED: %[[tmp:.*]] = gpu.alloc () : memref<64xi8>
9699
%temp = memref.alloca() : memref<64xi8>
97-
// CHECK: return
100+
// DISABLED: return
98101
func.return
99102
}

backends/gpu/tools/tfrt_gpu_opt/tfrt_gpu_opt.cc

Lines changed: 0 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -42,68 +42,6 @@
4242
#include "tfrt/support/error_util.h"
4343
#include "tfrt/test_kernels/opdefs/test_kernels.h"
4444

45-
namespace {
46-
47-
// Test pass to wrap tfrt_gpu ops in tfrt_gpu.streamify.
48-
struct TestStreamifyConversionPass
49-
: public mlir::PassWrapper<TestStreamifyConversionPass,
50-
OperationPass<mlir::func::FuncOp>> {
51-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStreamifyConversionPass)
52-
53-
StringRef getArgument() const final { return "test-streamify-conversion"; }
54-
55-
void getDependentDialects(DialectRegistry &registry) const override {
56-
tfrt::RegisterTFRTDialects(registry);
57-
tfrt::RegisterTFRTCompiledDialects(registry);
58-
registry.insert<tfrt::gpu::GpuDialect, mlir::arith::ArithmeticDialect,
59-
mlir::cf::ControlFlowDialect, mlir::gpu::GPUDialect,
60-
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
61-
tfrt::compiler::TFRTDialect>();
62-
}
63-
64-
void runOnOperation() override {
65-
TypeConverter converter;
66-
converter.addConversion([](Type type) { return type; });
67-
auto buffer_type = tfrt::gpu::BufferType::get(&getContext());
68-
converter.addConversion([&](BaseMemRefType) { return buffer_type; });
69-
converter.addTargetMaterialization([](OpBuilder &builder, Type type,
70-
ValueRange inputs,
71-
Location loc) -> Value {
72-
return builder.create<mlir::UnrealizedConversionCastOp>(loc, type, inputs)
73-
.getResult(0);
74-
});
75-
converter.addSourceMaterialization([](OpBuilder &builder, Type type,
76-
ValueRange inputs,
77-
Location loc) -> Value {
78-
return builder.create<mlir::UnrealizedConversionCastOp>(loc, type, inputs)
79-
.getResult(0);
80-
});
81-
82-
ConversionTarget wrap(getContext());
83-
wrap.addLegalDialect("wrap");
84-
85-
RewritePatternSet patterns(&getContext());
86-
tfrt::gpu::populateStreamifyConversionPatterns(patterns, converter, wrap);
87-
88-
ConversionTarget target(getContext());
89-
target
90-
.addLegalDialect<mlir::gpu::GPUDialect, tfrt::compiler::TFRTDialect>();
91-
target.addLegalDialect("other");
92-
target.addLegalOp<mlir::UnrealizedConversionCastOp>();
93-
target.addLegalOp<tfrt::gpu::StreamifyOp>();
94-
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
95-
[&](mlir::func::FuncOp op) {
96-
return none_of(op.getBody().getOps(),
97-
[&](Operation &op) { return wrap.isLegal(&op); });
98-
});
99-
if (failed(applyPartialConversion(getOperation(), target,
100-
std::move(patterns))))
101-
return signalPassFailure();
102-
}
103-
};
104-
105-
} // namespace
106-
10745
int main(int argc, char **argv) {
10846
mlir::DialectRegistry registry;
10947
tfrt::RegisterTFRTDialects(registry);
@@ -112,7 +50,6 @@ int main(int argc, char **argv) {
11250
mlir::gpu::GPUDialect, mlir::memref::MemRefDialect,
11351
tfrt::compiler::TFRTDialect, tfrt::gpu::GpuDialect,
11452
tfrt::test::TestDialect>();
115-
PassRegistration<TestStreamifyConversionPass>();
11653
tfrt::gpu::registerPasses();
11754

11855
return mlir::asMainReturnCode(

0 commit comments

Comments
 (0)