1414#define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
1515
1616#include " mlir/IR/Attributes.h"
17+ #include " llvm/IR/Module.h"
1718
1819namespace llvm {
1920class IRBuilderBase ;
@@ -52,7 +53,11 @@ class TargetOptions {
5253 StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
5354 StringRef cmdOptions = {},
5455 CompilationTarget compilationTarget = getDefaultCompilationTarget(),
55- function_ref<SymbolTable *()> getSymbolTableCallback = {});
56+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
57+ function_ref<void (llvm::Module &)> initialLlvmIRCallback = {},
58+ function_ref<void (llvm::Module &)> linkedLlvmIRCallback = {},
59+ function_ref<void (llvm::Module &)> optimizedLlvmIRCallback = {},
60+ function_ref<void (StringRef)> isaCallback = {});
5661
5762 // / Returns the typeID.
5863 TypeID getTypeID () const ;
@@ -80,6 +85,22 @@ class TargetOptions {
8085 // / table.
8186 SymbolTable *getSymbolTable () const ;
8287
88+ // / Returns the callback invoked with the initial LLVM IR for the device
89+ // / module.
90+ function_ref<void (llvm::Module &)> getInitialLlvmIRCallback () const ;
91+
92+ // / Returns the callback invoked with LLVM IR for the device module
93+ // / after linking the device libraries.
94+ function_ref<void (llvm::Module &)> getLinkedLlvmIRCallback () const ;
95+
96+ // / Returns the callback invoked with LLVM IR for the device module after
97+ // / LLVM optimizations but before codegen.
98+ function_ref<void (llvm::Module &)> getOptimizedLlvmIRCallback () const ;
99+
100+ // / Returns the callback invoked with the target ISA for the device,
101+ // / for example PTX assembly.
102+ function_ref<void (StringRef)> getISACallback () const ;
103+
83104 // / Returns the default compilation target: `CompilationTarget::Fatbin`.
84105 static CompilationTarget getDefaultCompilationTarget ();
85106
@@ -90,7 +111,11 @@ class TargetOptions {
90111 TypeID typeID, StringRef toolkitPath = {},
91112 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
92113 CompilationTarget compilationTarget = getDefaultCompilationTarget(),
93- function_ref<SymbolTable *()> getSymbolTableCallback = {});
114+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
115+ function_ref<void (llvm::Module &)> initialLlvmIRCallback = {},
116+ function_ref<void (llvm::Module &)> linkedLlvmIRCallback = {},
117+ function_ref<void (llvm::Module &)> optimizedLlvmIRCallback = {},
118+ function_ref<void (StringRef)> isaCallback = {});
94119
95120 // / Path to the target toolkit.
96121 std::string toolkitPath;
@@ -109,6 +134,21 @@ class TargetOptions {
109134 // / being serialized.
110135 function_ref<SymbolTable *()> getSymbolTableCallback;
111136
137+ // / Callback invoked with the initial LLVM IR for the device module.
138+ function_ref<void (llvm::Module &)> initialLlvmIRCallback;
139+
140+ // / Callback invoked with LLVM IR for the device module after
141+ // / linking the device libraries.
142+ function_ref<void (llvm::Module &)> linkedLlvmIRCallback;
143+
144+ // / Callback invoked with LLVM IR for the device module after
145+ // / LLVM optimizations but before codegen.
146+ function_ref<void (llvm::Module &)> optimizedLlvmIRCallback;
147+
148+ // / Callback invoked with the target ISA for the device,
149+ // / for example PTX assembly.
150+ function_ref<void (StringRef)> isaCallback;
151+
112152private:
113153 TypeID typeID;
114154};
0 commit comments