2 changes: 1 addition & 1 deletion bin/triton-llvm-opt.cpp
@@ -91,7 +91,7 @@ int main(int argc, char **argv) {
   }
   // If we are supposed to override the target triple or data layout, do so now.
   if (!TargetTriple.empty())
-    M->setTargetTriple(Triple::normalize(TargetTriple));
+    M->setTargetTriple(Triple(Triple::normalize(TargetTriple)));
   auto optPipeline = makeOptimizingPipeline();
   if (auto err = optPipeline(M.get())) {
     llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
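Background for this one-line change: in the LLVM revision this PR moves to, the module's target triple is stored as an llvm::Triple object rather than a string, so Module::setTargetTriple now expects a Triple value. A minimal sketch of the call-site migration (the helper and variable names are illustrative, not from this PR):

#include "llvm/IR/Module.h"
#include "llvm/TargetParser/Triple.h"

// Hypothetical helper: apply a user-supplied triple string to a module.
void overrideTriple(llvm::Module &M, const std::string &tripleStr) {
  // Previously the normalized string was passed directly:
  //   M.setTargetTriple(llvm::Triple::normalize(tripleStr));
  // The updated API takes an llvm::Triple, so wrap the normalized string:
  M.setTargetTriple(llvm::Triple(llvm::Triple::normalize(tripleStr)));
}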
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
-3963d2148292145543cf83b13ff839a63995fdc2
+2619c2ed584cdf3b38e6743ed3c785223f06e3f7
24 changes: 12 additions & 12 deletions lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp
@@ -43,7 +43,7 @@ struct MoveSplatAfterElementwisePattern
   MoveSplatAfterElementwisePattern(MLIRContext *context)
       : OpTraitRewritePattern(context) {}
 
-  LogicalResult match(Operation *op) const override {
+  LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
     if (!isMemoryEffectFree(op)) {
       return failure();
     }
@@ -57,10 +57,10 @@ struct MoveSplatAfterElementwisePattern
         return failure();
       }
     }
-    return success(op->getNumOperands() > 0);
-  }
+    //return success(op->getNumOperands() > 0);
+    if(op->getNumOperands() <= 0)
+      return failure();
 
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
     auto operands = op->getOperands();

@@ -96,6 +96,7 @@ struct MoveSplatAfterElementwisePattern
                                                 newOp->getResult(iRes));
       rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
     }
+    return success();
   }
 };

@@ -108,7 +109,7 @@ struct MoveBroadcastAfterElementwisePattern
   MoveBroadcastAfterElementwisePattern(MLIRContext *context)
       : OpTraitRewritePattern(context) {}
 
-  LogicalResult match(Operation *op) const override {
+  LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
     if (!isMemoryEffectFree(op)) {
       return failure();
     }
@@ -137,14 +138,12 @@ struct MoveBroadcastAfterElementwisePattern
         return failure();
       }
     }
-    return success(seenBroadcast);
-  }
+    if (!seenBroadcast)
+      return failure();
 
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
 
     // Find broadcast op
-    auto operands = op->getOperands();
     BroadcastOp broadcastOp;
     for (auto operand : operands) {
       broadcastOp = operand.getDefiningOp<BroadcastOp>();
@@ -154,7 +153,7 @@ struct MoveBroadcastAfterElementwisePattern
     }
 
     auto srcTy = broadcastOp.getSrc().getType();
-    auto srcShape = srcTy.getShape();
+    auto bcSrcShape = srcTy.getShape();
     auto srcEncoding = srcTy.getEncoding();
 
     // Reshape operands to match srcShape
@@ -167,7 +166,7 @@ struct MoveBroadcastAfterElementwisePattern
       }
       auto elemTy =
           dyn_cast<RankedTensorType>(operand.getType()).getElementType();
-      auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding);
+      auto newTy = RankedTensorType::get(bcSrcShape, elemTy, srcEncoding);
       if (auto splatOp = llvm::dyn_cast<SplatOp>(definingOp)) {
         auto newSplat = rewriter.create<SplatOp>(loc, newTy, splatOp.getSrc());
         newOperands.push_back(newSplat);
@@ -192,7 +191,7 @@ struct MoveBroadcastAfterElementwisePattern
     for (auto resultTy : resultTypes) {
       auto elemTy = dyn_cast<RankedTensorType>(resultTy).getElementType();
       newResultTypes.push_back(
-          RankedTensorType::get(srcShape, elemTy, srcEncoding));
+          RankedTensorType::get(bcSrcShape, elemTy, srcEncoding));
     }
 
     // Create new op and broadcast results
@@ -203,6 +202,7 @@ struct MoveBroadcastAfterElementwisePattern
                                                 newOp->getResult(iRes));
       rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
     }
+    return success();
   }
 };

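The other edits in this file track MLIR removing the split match()/rewrite() entry points on RewritePattern: each pattern now implements a single matchAndRewrite() that runs all of its legality checks before mutating any IR and bails out with failure() otherwise. A condensed sketch of the resulting shape, assuming the OpTrait::Elementwise trait these patterns target (checks and rewrite body abbreviated, not the full pattern):

#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

struct ExampleElementwisePattern
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  ExampleElementwisePattern(MLIRContext *context)
      : OpTraitRewritePattern(context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Former match() logic: reject ops we must not reorder, before any rewrite.
    if (!isMemoryEffectFree(op) || op->getNumOperands() == 0)
      return failure();
    // Former rewrite() logic runs only after every check has passed.
    // ... create the reordered ops with `rewriter` here ...
    return success();
  }
};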
8 changes: 4 additions & 4 deletions python/src/llvm.cc
@@ -47,7 +47,7 @@ createTargetMachine(llvm::Module *module, std::string proc,
                     bool enable_fp_fusion, const std::string &features) {
   std::string error;
   auto target =
-      llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
+      llvm::TargetRegistry::lookupTarget(module->getTargetTriple().str(), error);
   llvm::TargetOptions opt;
   bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT");
   if (enable_fp_fusion)
@@ -59,7 +59,7 @@ createTargetMachine(llvm::Module *module, std::string proc,
   opt.MCOptions.AsmVerbose = true;
   opt.MCOptions.PreserveAsmComments = true;
   std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
-      module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
+      module->getTargetTriple().str(), proc, features, opt, llvm::Reloc::PIC_,
       std::nullopt,
       disableLLVMOpt ? llvm::CodeGenOptLevel::None
                      : llvm::CodeGenOptLevel::Aggressive)};
@@ -132,7 +132,7 @@ std::string translateLLVMIRToASM(llvm::Module &module,
   // module->print(llvm::outs(), nullptr);
 
   // create machine
-  module.setTargetTriple(triple);
+  module.setTargetTriple(Triple(triple));
   auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features);
   // set data layout
   module.setDataLayout(machine->createDataLayout());
@@ -459,7 +459,7 @@ void init_triton_llvm(py::module &&m) {
           std::string message = "Failed to parse library at " + path;
           throw std::invalid_argument(message);
         }
-        libMod->setTargetTriple(dstMod->getTargetTriple());
+        libMod->setTargetTriple(Triple(dstMod->getTargetTriple()));
         libMod->setDataLayout(dstMod->getDataLayout());
 
         std::unordered_set<std::string> externalFns;
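The .str() calls above are needed because Module::getTargetTriple() now returns an llvm::Triple, while TargetRegistry::lookupTarget (and target-machine creation) in the pinned LLVM still accept the triple as a string. A small sketch of the lookup side, with error handling trimmed and the function name invented for illustration:

#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"

const llvm::Target *lookupTargetFor(const llvm::Module &module) {
  std::string error;
  // getTargetTriple() yields an llvm::Triple; convert back to a string
  // for the registry lookup, mirroring the change above.
  return llvm::TargetRegistry::lookupTarget(module.getTargetTriple().str(),
                                            error);
}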
@@ -298,14 +298,14 @@ static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc,
                                             ConversionPatternRewriter &rewriter,
                                             const SmallVector<Value> &v) {
   assert(v.size() == 2);
-  return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Fp8>(loc, rewriter, v[0], v[1]);
+  return cvtScaleFp8ToFp32<ROCDL::CvtScale32PkF32Fp8>(loc, rewriter, v[0], v[1]);
 }
 
 static SmallVector<Value> Fp8E5M2_to_Fp32(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           const SmallVector<Value> &v) {
   assert(v.size() == 2);
-  return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Bf8>(loc, rewriter, v[0], v[1]);
+  return cvtScaleFp8ToFp32<ROCDL::CvtScale32PkF32Bf8>(loc, rewriter, v[0], v[1]);
 }
 
 template <typename convertOp>
2 changes: 1 addition & 1 deletion third_party/amd/python/triton_amd.cc
@@ -123,7 +123,7 @@ void init_triton_amd(py::module &&m) {
   });
 
   m.def("attach_target_triple",
-        [](llvm::Module *module) { module->setTargetTriple(amdTargetTriple); });
+        [](llvm::Module *module) { module->setTargetTriple(llvm::Triple(amdTargetTriple)); });
 
   // Set target architecture ISA version
   m.def("set_isa_version", [](llvm::Module *module, const std::string &arch) {