Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bin/triton-llvm-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ int main(int argc, char **argv) {
}
// If we are supposed to override the target triple or data layout, do so now.
if (!TargetTriple.empty())
M->setTargetTriple(Triple::normalize(TargetTriple));
M->setTargetTriple(Triple(Triple::normalize(TargetTriple)));
auto optPipeline = makeOptimizingPipeline();
if (auto err = optPipeline(M.get())) {
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
Expand Down
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3963d2148292145543cf83b13ff839a63995fdc2
2619c2ed584cdf3b38e6743ed3c785223f06e3f7
26 changes: 14 additions & 12 deletions lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ struct MoveSplatAfterElementwisePattern
MoveSplatAfterElementwisePattern(MLIRContext *context)
: OpTraitRewritePattern(context) {}

LogicalResult match(Operation *op) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (!isMemoryEffectFree(op)) {
return failure();
}
Expand All @@ -57,10 +58,10 @@ struct MoveSplatAfterElementwisePattern
return failure();
}
}
return success(op->getNumOperands() > 0);
}

void rewrite(Operation *op, PatternRewriter &rewriter) const override {
if (op->getNumOperands() <= 0)
return failure();

auto loc = op->getLoc();
auto operands = op->getOperands();

Expand Down Expand Up @@ -96,6 +97,7 @@ struct MoveSplatAfterElementwisePattern
newOp->getResult(iRes));
rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
}
return success();
}
};

Expand All @@ -108,7 +110,8 @@ struct MoveBroadcastAfterElementwisePattern
MoveBroadcastAfterElementwisePattern(MLIRContext *context)
: OpTraitRewritePattern(context) {}

LogicalResult match(Operation *op) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (!isMemoryEffectFree(op)) {
return failure();
}
Expand Down Expand Up @@ -137,14 +140,12 @@ struct MoveBroadcastAfterElementwisePattern
return failure();
}
}
return success(seenBroadcast);
}
if (!seenBroadcast)
return failure();

void rewrite(Operation *op, PatternRewriter &rewriter) const override {
auto loc = op->getLoc();

// Find broadcast op
auto operands = op->getOperands();
BroadcastOp broadcastOp;
for (auto operand : operands) {
broadcastOp = operand.getDefiningOp<BroadcastOp>();
Expand All @@ -154,7 +155,7 @@ struct MoveBroadcastAfterElementwisePattern
}

auto srcTy = broadcastOp.getSrc().getType();
auto srcShape = srcTy.getShape();
auto bcSrcShape = srcTy.getShape();
auto srcEncoding = srcTy.getEncoding();

// Reshape operands to match srcShape
Expand All @@ -167,7 +168,7 @@ struct MoveBroadcastAfterElementwisePattern
}
auto elemTy =
dyn_cast<RankedTensorType>(operand.getType()).getElementType();
auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding);
auto newTy = RankedTensorType::get(bcSrcShape, elemTy, srcEncoding);
if (auto splatOp = llvm::dyn_cast<SplatOp>(definingOp)) {
auto newSplat = rewriter.create<SplatOp>(loc, newTy, splatOp.getSrc());
newOperands.push_back(newSplat);
Expand All @@ -192,7 +193,7 @@ struct MoveBroadcastAfterElementwisePattern
for (auto resultTy : resultTypes) {
auto elemTy = dyn_cast<RankedTensorType>(resultTy).getElementType();
newResultTypes.push_back(
RankedTensorType::get(srcShape, elemTy, srcEncoding));
RankedTensorType::get(bcSrcShape, elemTy, srcEncoding));
}

// Create new op and broadcast results
Expand All @@ -203,6 +204,7 @@ struct MoveBroadcastAfterElementwisePattern
newOp->getResult(iRes));
rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
}
return success();
}
};

Expand Down
10 changes: 5 additions & 5 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ std::unique_ptr<TargetMachine>
createTargetMachine(llvm::Module *module, std::string proc,
bool enable_fp_fusion, const std::string &features) {
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
auto target = llvm::TargetRegistry::lookupTarget(
module->getTargetTriple().str(), error);
llvm::TargetOptions opt;
bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT");
if (enable_fp_fusion)
Expand All @@ -59,7 +59,7 @@ createTargetMachine(llvm::Module *module, std::string proc,
opt.MCOptions.AsmVerbose = true;
opt.MCOptions.PreserveAsmComments = true;
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
module->getTargetTriple().str(), proc, features, opt, llvm::Reloc::PIC_,
std::nullopt,
disableLLVMOpt ? llvm::CodeGenOptLevel::None
: llvm::CodeGenOptLevel::Aggressive)};
Expand Down Expand Up @@ -132,7 +132,7 @@ std::string translateLLVMIRToASM(llvm::Module &module,
// module->print(llvm::outs(), nullptr);

// create machine
module.setTargetTriple(triple);
module.setTargetTriple(Triple(triple));
auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features);
// set data layout
module.setDataLayout(machine->createDataLayout());
Expand Down Expand Up @@ -459,7 +459,7 @@ void init_triton_llvm(py::module &&m) {
std::string message = "Failed to parse library at " + path;
throw std::invalid_argument(message);
}
libMod->setTargetTriple(dstMod->getTargetTriple());
libMod->setTargetTriple(Triple(dstMod->getTargetTriple()));
libMod->setDataLayout(dstMod->getDataLayout());

std::unordered_set<std::string> externalFns;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,16 @@ static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
assert(v.size() == 2);
return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Fp8>(loc, rewriter, v[0], v[1]);
return cvtScaleFp8ToFp32<ROCDL::CvtScale32PkF32Fp8>(loc, rewriter, v[0],
v[1]);
}

static SmallVector<Value> Fp8E5M2_to_Fp32(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
assert(v.size() == 2);
return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Bf8>(loc, rewriter, v[0], v[1]);
return cvtScaleFp8ToFp32<ROCDL::CvtScale32PkF32Bf8>(loc, rewriter, v[0],
v[1]);
}

template <typename convertOp>
Expand Down
5 changes: 3 additions & 2 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ void init_triton_amd(py::module &&m) {
context.loadAllAvailableDialects();
});

m.def("attach_target_triple",
[](llvm::Module *module) { module->setTargetTriple(amdTargetTriple); });
m.def("attach_target_triple", [](llvm::Module *module) {
module->setTargetTriple(llvm::Triple(amdTargetTriple));
});

// Set target architecture ISA version
m.def("set_isa_version", [](llvm::Module *module, const std::string &arch) {
Expand Down
Loading