diff --git a/bin/triton-llvm-opt.cpp b/bin/triton-llvm-opt.cpp index 1ec804cb5042..3beeeabdc188 100644 --- a/bin/triton-llvm-opt.cpp +++ b/bin/triton-llvm-opt.cpp @@ -91,7 +91,7 @@ int main(int argc, char **argv) { } // If we are supposed to override the target triple or data layout, do so now. if (!TargetTriple.empty()) - M->setTargetTriple(Triple::normalize(TargetTriple)); + M->setTargetTriple(Triple(Triple::normalize(TargetTriple))); auto optPipeline = makeOptimizingPipeline(); if (auto err = optPipeline(M.get())) { llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 371dd55656bb..18b33a2a9f4c 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -3963d2148292145543cf83b13ff839a63995fdc2 +2619c2ed584cdf3b38e6743ed3c785223f06e3f7 diff --git a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp index db8166085296..a445e282aece 100644 --- a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp +++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -43,7 +43,8 @@ struct MoveSplatAfterElementwisePattern MoveSplatAfterElementwisePattern(MLIRContext *context) : OpTraitRewritePattern(context) {} - LogicalResult match(Operation *op) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { if (!isMemoryEffectFree(op)) { return failure(); } @@ -57,10 +58,10 @@ struct MoveSplatAfterElementwisePattern return failure(); } } - return success(op->getNumOperands() > 0); - } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + if (op->getNumOperands() <= 0) + return failure(); + auto loc = op->getLoc(); auto operands = op->getOperands(); @@ -96,6 +97,7 @@ struct MoveSplatAfterElementwisePattern newOp->getResult(iRes)); rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); } + return success(); } }; @@ -108,7 +110,8 @@ struct MoveBroadcastAfterElementwisePattern MoveBroadcastAfterElementwisePattern(MLIRContext *context) : OpTraitRewritePattern(context) {} - LogicalResult match(Operation *op) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { if (!isMemoryEffectFree(op)) { return failure(); } @@ -137,14 +140,12 @@ struct MoveBroadcastAfterElementwisePattern return failure(); } } - return success(seenBroadcast); - } + if (!seenBroadcast) + return failure(); - void rewrite(Operation *op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); // Find broadcast op - auto operands = op->getOperands(); BroadcastOp broadcastOp; for (auto operand : operands) { broadcastOp = operand.getDefiningOp(); @@ -154,7 +155,7 @@ struct MoveBroadcastAfterElementwisePattern } auto srcTy = broadcastOp.getSrc().getType(); - auto srcShape = srcTy.getShape(); + auto bcSrcShape = srcTy.getShape(); auto srcEncoding = srcTy.getEncoding(); // Reshape operands to match srcShape @@ -167,7 +168,7 @@ struct MoveBroadcastAfterElementwisePattern } auto elemTy = dyn_cast(operand.getType()).getElementType(); - auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding); + auto newTy = RankedTensorType::get(bcSrcShape, elemTy, srcEncoding); if (auto splatOp = llvm::dyn_cast(definingOp)) { auto newSplat = rewriter.create(loc, newTy, splatOp.getSrc()); newOperands.push_back(newSplat); @@ -192,7 +193,7 @@ struct MoveBroadcastAfterElementwisePattern for (auto resultTy : resultTypes) { auto elemTy = dyn_cast(resultTy).getElementType(); newResultTypes.push_back( - RankedTensorType::get(srcShape, elemTy, srcEncoding)); + RankedTensorType::get(bcSrcShape, elemTy, srcEncoding)); } // Create new op and broadcast results @@ -203,6 +204,7 @@ struct MoveBroadcastAfterElementwisePattern newOp->getResult(iRes)); rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); } + return success(); } }; diff --git a/python/src/llvm.cc b/python/src/llvm.cc index c86bf671a7df..ed8ec309699e 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -46,8 +46,8 @@ std::unique_ptr createTargetMachine(llvm::Module *module, std::string proc, bool enable_fp_fusion, const std::string &features) { std::string error; - auto target = - llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + auto target = llvm::TargetRegistry::lookupTarget( + module->getTargetTriple().str(), error); llvm::TargetOptions opt; bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); if (enable_fp_fusion) @@ -59,7 +59,7 @@ createTargetMachine(llvm::Module *module, std::string proc, opt.MCOptions.AsmVerbose = true; opt.MCOptions.PreserveAsmComments = true; std::unique_ptr machine{target->createTargetMachine( - module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + module->getTargetTriple().str(), proc, features, opt, llvm::Reloc::PIC_, std::nullopt, disableLLVMOpt ? llvm::CodeGenOptLevel::None : llvm::CodeGenOptLevel::Aggressive)}; @@ -132,7 +132,7 @@ std::string translateLLVMIRToASM(llvm::Module &module, // module->print(llvm::outs(), nullptr); // create machine - module.setTargetTriple(triple); + module.setTargetTriple(Triple(triple)); auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); // set data layout module.setDataLayout(machine->createDataLayout()); @@ -459,7 +459,7 @@ void init_triton_llvm(py::module &&m) { std::string message = "Failed to parse library at " + path; throw std::invalid_argument(message); } - libMod->setTargetTriple(dstMod->getTargetTriple()); + libMod->setTargetTriple(Triple(dstMod->getTargetTriple())); libMod->setDataLayout(dstMod->getDataLayout()); std::unordered_set externalFns; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index c2f9c87b8e58..e30f69863b9a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -298,14 +298,16 @@ static SmallVector Fp8E4M3FN_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { assert(v.size() == 2); - return cvtScaleFp8ToFp32(loc, rewriter, v[0], v[1]); + return cvtScaleFp8ToFp32(loc, rewriter, v[0], + v[1]); } static SmallVector Fp8E5M2_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { assert(v.size() == 2); - return cvtScaleFp8ToFp32(loc, rewriter, v[0], v[1]); + return cvtScaleFp8ToFp32(loc, rewriter, v[0], + v[1]); } template diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 551da3f33c74..b97b5cb85292 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -122,8 +122,9 @@ void init_triton_amd(py::module &&m) { context.loadAllAvailableDialects(); }); - m.def("attach_target_triple", - [](llvm::Module *module) { module->setTargetTriple(amdTargetTriple); }); + m.def("attach_target_triple", [](llvm::Module *module) { + module->setTargetTriple(llvm::Triple(amdTargetTriple)); + }); // Set target architecture ISA version m.def("set_isa_version", [](llvm::Module *module, const std::string &arch) {