Skip to content

Commit 21f36ae

Browse files
chsiggcopybara-github
authored andcommitted
Wrap tfrt::gpu::setEntryPoint() into a pass.
Work around custom op parsing assigning loc to operand instead of op after https://reviews.llvm.org/D124188. PiperOrigin-RevId: 444468906
1 parent 629b07e commit 21f36ae

File tree

7 files changed

+87
-63
lines changed

7 files changed

+87
-63
lines changed

backends/gpu/include/tfrt/gpu/passes/passes.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,12 @@ void populateGpuToTfrtGpuPasses(mlir::OpPassManager& pm);
9696
// Registers all tfrt gpu passes.
9797
void registerPasses();
9898

99-
// Adds a function to `module` which returns the entry point information for
100-
// the gpu executor.
101-
void setEntryPoint(mlir::ModuleOp module, wrapper::Platform platform,
102-
llvm::StringRef function_name,
103-
llvm::ArrayRef<int64_t> buffer_sizes);
99+
// Creates a pass which adds a function returning the entry point information
100+
// for the gpu executor.
101+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateSetEntryPointPass();
102+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateSetEntryPointPass(
103+
wrapper::Platform platform, mlir::StringRef function_name,
104+
mlir::ArrayRef<int64_t> buffer_sizes);
104105

105106
} // namespace gpu
106107
} // namespace tfrt

backends/gpu/lib/passes/gpu_to_tfrt_passes.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,6 +1742,7 @@ void registerPasses() {
17421742
PassRegistration<ReconcileCastsPass>();
17431743
PassRegistration<ConvertAsyncToTfrtPass>();
17441744
PassRegistration<HoistingPass>();
1745+
registerPass([] { return CreateSetEntryPointPass(); });
17451746

17461747
PassPipelineRegistration<>(
17471748
"gpu-to-tfrt-gpu",

backends/gpu/lib/passes/set_entry_point.cc

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,41 @@
3131
namespace tfrt {
3232
namespace gpu {
3333

34-
void setEntryPoint(ModuleOp module, wrapper::Platform platform,
35-
StringRef function_name, ArrayRef<int64_t> buffer_sizes) {
34+
namespace {
35+
36+
struct SetEntryPointPass
37+
: public PassWrapper<SetEntryPointPass, OperationPass<ModuleOp>> {
38+
public:
39+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SetEntryPointPass)
40+
41+
SetEntryPointPass() = default;
42+
SetEntryPointPass(const SetEntryPointPass &) {}
43+
44+
Option<wrapper::Platform> platform{
45+
*this, "platform",
46+
llvm::cl::values(
47+
llvm::cl::OptionEnumValue{
48+
"CUDA", static_cast<int>(wrapper::Platform::CUDA), ""},
49+
llvm::cl::OptionEnumValue{
50+
"ROCm", static_cast<int>(wrapper::Platform::ROCm), ""})};
51+
Option<std::string> function_name{*this, "function_name"};
52+
ListOption<int64_t> buffer_sizes{*this, "buffer_sizes"};
53+
54+
private:
55+
StringRef getArgument() const final { return "tfrt-set-entry-point"; }
56+
57+
void getDependentDialects(DialectRegistry &registry) const override {
58+
registry.insert<compiler::TFRTDialect, GpuDialect>();
59+
}
60+
61+
void runOnOperation() override;
62+
};
63+
64+
} // namespace
65+
66+
static void SetEntryPoint(ModuleOp module, wrapper::Platform platform,
67+
StringRef function_name,
68+
ArrayRef<int64_t> buffer_sizes) {
3669
OpBuilder builder(module.getContext());
3770

3871
// Create a function.
@@ -65,5 +98,46 @@ void setEntryPoint(ModuleOp module, wrapper::Platform platform,
6598
builder.create<compiler::ReturnOp>(loc, get_entry_point_op->getResult(0));
6699
}
67100

101+
void SetEntryPointPass::runOnOperation() {
102+
if (!platform.hasValue()) {
103+
getOperation()->emitError() << "Unspecified 'platform' option";
104+
return signalPassFailure();
105+
}
106+
107+
func::FuncOp func_op;
108+
if (function_name.hasValue()) {
109+
func_op = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
110+
getOperation(), StringAttr::get(&getContext(), function_name));
111+
if (!func_op) {
112+
getOperation()->emitError()
113+
<< "Function '" << function_name << "' not found";
114+
return signalPassFailure();
115+
}
116+
} else {
117+
auto funcs = getOperation().getOps<func::FuncOp>();
118+
if (funcs.empty() || ++funcs.begin() != funcs.end()) {
119+
getOperation()->emitError() << "Expected exactly one function";
120+
return signalPassFailure();
121+
}
122+
func_op = *funcs.begin();
123+
}
124+
125+
SetEntryPoint(getOperation(), platform, func_op.getSymName(), buffer_sizes);
126+
}
127+
128+
std::unique_ptr<OperationPass<ModuleOp>> CreateSetEntryPointPass() {
129+
return std::make_unique<SetEntryPointPass>();
130+
}
131+
132+
std::unique_ptr<OperationPass<ModuleOp>> CreateSetEntryPointPass(
133+
wrapper::Platform platform, StringRef function_name,
134+
ArrayRef<int64_t> buffer_sizes) {
135+
auto pass = std::make_unique<SetEntryPointPass>();
136+
pass->platform = platform;
137+
pass->function_name = function_name.str();
138+
pass->buffer_sizes = buffer_sizes;
139+
return pass;
140+
}
141+
68142
} // namespace gpu
69143
} // namespace tfrt

backends/gpu/mlir_tests/conversion/entry_point.mlir renamed to backends/gpu/mlir_tests/conversion/set_entry_point.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
// RUN: tfrt_gpu_opt \
16-
// RUN: -test-set-entry-point='platform=CUDA buffer_sizes=1,2,3' %s \
16+
// RUN: -tfrt-set-entry-point='platform=CUDA buffer_sizes=1,2,3' %s \
1717
// RUN: | FileCheck %s
1818

1919
// CHECK: func @get_tfrt_gpu_entry_point() -> !tfrt_gpu.entry_point {

backends/gpu/mlir_tests/cuda/error.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
// XFAIL: *
16-
// RUN: tfrt_gpu_opt -mlir-print-debuginfo \
17-
// RUN: -test-set-entry-point='platform=CUDA function_name=error' %s \
15+
// RUN: tfrt_gpu_opt %s -mlir-print-debuginfo -mlir-print-op-generic \
16+
// RUN: -tfrt-set-entry-point='platform=CUDA function_name=error' \
1817
// RUN: | tfrt_gpu_translate -mlir-to-bef \
1918
// RUN: | tfrt_gpu_executor
2019

backends/gpu/mlir_tests/cuda/executor.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
// RUN: tfrt_gpu_opt -mlir-print-debuginfo \
16-
// RUN: -test-set-entry-point='platform=CUDA buffer_sizes=64' %s \
16+
// RUN: -tfrt-set-entry-point='platform=CUDA buffer_sizes=64' %s \
1717
// RUN: | tfrt_gpu_translate -mlir-to-bef \
1818
// RUN: | tfrt_gpu_executor \
1919
// RUN: | FileCheck %s

backends/gpu/tools/tfrt_gpu_opt/tfrt_gpu_opt.cc

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -102,56 +102,6 @@ struct TestStreamifyConversionPass
102102
}
103103
};
104104

105-
class TestSetEntryPointPass
106-
: public mlir::PassWrapper<TestSetEntryPointPass, OperationPass<ModuleOp>> {
107-
public:
108-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSetEntryPointPass)
109-
110-
TestSetEntryPointPass() = default;
111-
TestSetEntryPointPass(const TestSetEntryPointPass &pass) {}
112-
void getDependentDialects(DialectRegistry &registry) const override {
113-
tfrt::RegisterTFRTDialects(registry);
114-
tfrt::RegisterTFRTCompiledDialects(registry);
115-
registry.insert<tfrt::gpu::GpuDialect, mlir::arith::ArithmeticDialect,
116-
mlir::cf::ControlFlowDialect, mlir::gpu::GPUDialect,
117-
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
118-
tfrt::compiler::TFRTDialect>();
119-
}
120-
121-
StringRef getArgument() const final { return "test-set-entry-point"; }
122-
123-
void runOnOperation() override {
124-
auto platform = tfrt::gpu::wrapper::ParsePlatform(platform_);
125-
if (!platform) return emitError(toString(platform.takeError()));
126-
127-
mlir::func::FuncOp func_op;
128-
if (function_name_.hasValue()) {
129-
func_op = mlir::SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(
130-
getOperation(), mlir::StringAttr::get(&getContext(), function_name_));
131-
if (!func_op)
132-
return emitError("Function '" + function_name_ + "' not found");
133-
} else {
134-
auto funcs = getOperation().getOps<mlir::func::FuncOp>();
135-
if (funcs.empty() || ++funcs.begin() != funcs.end())
136-
return emitError("Expected exactly one function");
137-
func_op = *funcs.begin();
138-
}
139-
140-
tfrt::gpu::setEntryPoint(getOperation(), *platform, func_op.getSymName(),
141-
buffer_sizes_);
142-
}
143-
144-
private:
145-
void emitError(StringRef message) {
146-
getOperation()->emitError() << message;
147-
signalPassFailure();
148-
}
149-
150-
Option<std::string> platform_{*this, "platform"};
151-
Option<std::string> function_name_{*this, "function_name"};
152-
ListOption<int64_t> buffer_sizes_{*this, "buffer_sizes"};
153-
};
154-
155105
} // namespace
156106

157107
int main(int argc, char **argv) {
@@ -163,7 +113,6 @@ int main(int argc, char **argv) {
163113
tfrt::compiler::TFRTDialect, tfrt::gpu::GpuDialect,
164114
tfrt::test::TestDialect>();
165115
PassRegistration<TestStreamifyConversionPass>();
166-
PassRegistration<TestSetEntryPointPass>();
167116
tfrt::gpu::registerPasses();
168117

169118
return mlir::asMainReturnCode(

0 commit comments

Comments
 (0)