Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ test-unit: all
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \
$(PYTEST) -vvv python/test/unit/plugins/test_dialect_plugin.py
$(PYTEST) -s -vvv python/test/unit/plugins/test_dialect_plugin.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \
$(PYTEST) -s -vvv python/test/unit/plugins/custom_ops.py
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon

.PHONY: test-gluon
test-gluon: all
Expand Down
3 changes: 2 additions & 1 deletion bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
if (auto result = TP.getPassHandles(passNames); !result)
llvm::report_fatal_error(result.takeError());

std::vector<std::string> args;
for (const char *passName : passNames)
if (auto result = TP.registerPass(passName); !result)
if (auto result = TP.registerPass(passName, args); !result)
llvm::report_fatal_error(result.takeError());

std::vector<const char *> dialectNames;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,29 @@ def DialectPlugin_MagicOp : DialectPlugin_Op<"magic", [Pure,
}];
}

def DialectPlugin_FMagicOp : DialectPlugin_Op<"fmagic", [Pure,
SameOperandsAndResultType]> {
let summary = "Illustrates how to define a custom operation.";
let description = [{
The `plugin.fmagic` operation illustrates how to define a new
operation in a dialect. It uses an operation trait to declare that it
has no side effects.

Example:

```mlir
%0 = arith.constant 2 : i32
// Apply the magic operation to tensor %0 and return %1
%1 = plugin.fmagic %0 : i32
```
}];

let arguments = (ins FloatLike:$input);
let results = (outs FloatLike:$res);

let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
}

#endif // DIALECTPLUGIN_OPS
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "DialectPlugin/DialectPluginDialect.h"
#include "DialectPlugin/DialectPluginOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>

namespace mlir {
Expand All @@ -12,12 +13,29 @@ class ModuleOp;

namespace triton {
namespace plugin {
class PluginTypeConverter : public TypeConverter {
public:
PluginTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp,
int numCTAs);
int getNumWarps() const { return numWarps; }
int getThreadsPerWarp() const { return threadsPerWarp; }
int getNumCTAs() const { return numCTAs; }

private:
MLIRContext *context;
int numWarps;
int threadsPerWarp;
int numCTAs;
};

#define GEN_PASS_DECL
#include "DialectPlugin/DialectPluginPasses.h.inc"

std::unique_ptr<OperationPass<ModuleOp>>
createConvertPluginGPUToLLVMPass(int32_t computeCapability = 80,
int32_t ptxVersion = 80);
std::unique_ptr<OperationPass<ModuleOp>> createConvertPluginGPUToTritonGPUPass(
int32_t num_warps = 4, int32_t threadsPerWarp = 32, int32_t numCTAs = 1);

#define GEN_PASS_REGISTRATION
#include "DialectPlugin/DialectPluginPasses.h.inc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,23 @@ def DialectPluginMagicOp: Pass<"convert-plugin-gpu-to-llvm", "mlir::ModuleOp"> {
];
}

def DialectPluginFMagicOp: Pass<"convert-plugin-gpu-to-triton-gpu", "mlir::ModuleOp"> {
let summary = "Converts PluginGPU Ops to TritonGPU Ops";
let constructor = "mlir::triton::plugin::createConvertPluginGPUToTritonGPUPass(4, 32, 1)";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::plugin::DialectPluginDialect"];
let options = [
Option<"num_warps", "num-warps",
"int32_t", /*default*/"4",
"Number of warps">,
Option<"threadsPerWarp", "threads-per-warp",
"int32_t", /*default*/"32",
"Threads per warp">,
Option<"numCTAs", "num-ctas",
"int32_t", /*default*/"1",
"Number of CTAs">,
];
}

#endif // DIALECTPLUGIN_PASS
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRDialectPlugin
MLIRPass
LLVMSupport
MLIRSupport
MLIRArithDialect
TritonNVIDIAGPUToLLVM
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#include "DialectPlugin/DialectPluginDialect.h"
#include "DialectPlugin/DialectPluginOps.h"
#include "DialectPlugin/DialectPluginTypes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Tools/PluginUtils.h"
#include <cstdlib>

using namespace mlir;
using namespace mlir::triton::plugin;
Expand All @@ -19,79 +24,137 @@ void DialectPluginDialect::initialize() {
registerTypes();
}

#include "DialectPlugin/DialectPluginDialect.h"
#include "DialectPlugin/DialectPluginPasses.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Tools/Plugins/DialectPlugin.h"

#include "DialectPlugin/DialectPluginDialect.h"
#include "DialectPlugin/DialectPluginPasses.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include "triton/Tools/PluginUtils.h"
#include "llvm/Config/llvm-config.h"

using namespace mlir;

static void addTritonPluginPass(mlir::PassManager *pm) {
static void addTritonPluginPass(mlir::PassManager *pm, int num_warps,
int threadsPerWarp, int numCTAs) {
pm->addPass(mlir::triton::plugin::createConvertPluginGPUToLLVMPass());
}

static void registerTritonPluginPass() {
static void registerTritonPluginPass(int num_warps, int threadsPerWarp,
int numCTAs) {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::triton::plugin::createConvertPluginGPUToLLVMPass();
});
}

static void addTritonPluginPass2(mlir::PassManager *pm, int num_warps,
int threadsPerWarp, int numCTAs) {
pm->addPass(mlir::triton::plugin::createConvertPluginGPUToTritonGPUPass(
num_warps, threadsPerWarp, numCTAs));
}

static void registerTritonPluginPass2(int num_warps, int threadsPerWarp,
int numCTAs) {
::mlir::registerPass([=]() -> std::unique_ptr<::mlir::Pass> {
return mlir::triton::plugin::createConvertPluginGPUToTritonGPUPass(
num_warps, threadsPerWarp, numCTAs);
});
}

static const char *ADD_PLUGIN_PASS_NAME = "plugingpu_conversion";
static std::unordered_map<std::string, void (*)(mlir::PassManager *)> passMap =
{{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}};
static std::unordered_map<std::string, void (*)()> registryMap = {
static const char *ADD_PLUGIN_FARITH_PASS_NAME = "plugingpu_farith_conversion";
static std::unordered_map<std::string,
void (*)(mlir::PassManager *, int, int, int)>
passMap = {{ADD_PLUGIN_FARITH_PASS_NAME, addTritonPluginPass2},
{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}};
static std::unordered_map<std::string, void (*)(int, int, int)> registryMap = {
{ADD_PLUGIN_FARITH_PASS_NAME, registerTritonPluginPass2},
{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}};
static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME};
static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME,
ADD_PLUGIN_FARITH_PASS_NAME};

// Key APIs:

TRITON_PLUGIN_API
tritonAddPluginPass(mlir::PassManager *pm, const char *passName) {
std::string passNameStr(passName);
tritonAddPluginPass(mlir::PassManager *pm, TRITON_PLUGIN_PASS_ARGS) {
int num_warps = 0;
int threadsPerWarp = 0;
int numCTAs = 0;
if (args.size() > 0) {
num_warps = std::atoi(args[0].c_str());
threadsPerWarp = std::atoi(args[1].c_str());
numCTAs = std::atoi(args[2].c_str());
}

std::string passNameStr(handle);
if (passMap.find(passNameStr) == passMap.end())
return TP_GENERIC_FAILURE;
passMap[passNameStr](pm);
passMap[passNameStr](pm, num_warps, threadsPerWarp, numCTAs);
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonRegisterPluginPass(const char *passName) {
std::string passNameStr(passName);
tritonRegisterPluginPass(TRITON_PLUGIN_PASS_ARGS) {
int num_warps = 0;
int threadsPerWarp = 0;
int numCTAs = 0;
if (args.size() > 0) {
num_warps = std::atoi(args[0].c_str());
threadsPerWarp = std::atoi(args[1].c_str());
numCTAs = std::atoi(args[2].c_str());
}

std::string passNameStr(handle);
if (registryMap.find(passNameStr) == registryMap.end())
return TP_GENERIC_FAILURE;
registryMap[passNameStr]();
registryMap[passNameStr](num_warps, threadsPerWarp, numCTAs);
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) {
if (!passCount)
tritonEnumeratePluginPasses(TRITON_PLUGIN_ENUMERATOR_ARGS) {
if (!count)
return TP_GENERIC_FAILURE;
auto count = passMap.size();
assert(count == registryMap.size() &&
assert(passMap.size() == registryMap.size() &&
"Expected register and add passes map size to match");
*passCount = count;
if (!passNames)
*count = passMap.size();
if (!handles)
return TP_SUCCESS;
unsigned i = 0;
for (auto passName : passNamesTable) {
passNames[i] = passName;
handles[i++] = passName;
}
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonEnumeratePluginDialects(uint32_t *dialectCount,
const char **dialectNames) {
*dialectCount = 1;
if (!dialectNames)
tritonEnumeratePluginDialects(TRITON_PLUGIN_ENUMERATOR_ARGS) {
*count = 1;
if (!handles)
return TP_SUCCESS;
dialectNames[0] = "DialectPlugin";
handles[0] = "DialectPlugin";
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonEnumeratePluginCustomOps(TRITON_PLUGIN_ENUMERATOR_ARGS) {
if (!count)
return TP_GENERIC_FAILURE;
*count = 1;
if (!handles)
return TP_SUCCESS;
handles[0] = "create_custom_op";
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonAddPluginCustomOp(TRITON_PLUGIN_CUSTOM_OP_ARGS) {
::mlir::Value &dst = operands[0];
::mlir::Value &src = operands[1];

dst = self.create<mlir::triton::plugin::FMagicOp>(src);
operands[0] = dst;
return TP_SUCCESS;
}

Expand Down
Loading
Loading