Skip to content

Commit e350beb

Browse files
authored
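Bumps the pinned LLVM revision to pull in llvm/llvm-project#130300 for gfx950 type conversion ops. The new LLVM revision also stores the parsed `Triple` in the module (so `setTargetTriple` now takes a `Triple` and `getTargetTriple` returns one) and deprecates the separate `match` and `rewrite` pattern functions in favor of `matchAndRewrite`; this commit updates the affected call sites and patterns accordingly.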
1 parent 9ce44c8 commit e350beb

File tree

6 files changed

+28
-23
lines changed

6 files changed

+28
-23
lines changed

bin/triton-llvm-opt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ int main(int argc, char **argv) {
9191
}
9292
// If we are supposed to override the target triple or data layout, do so now.
9393
if (!TargetTriple.empty())
94-
M->setTargetTriple(Triple::normalize(TargetTriple));
94+
M->setTargetTriple(Triple(Triple::normalize(TargetTriple)));
9595
auto optPipeline = makeOptimizingPipeline();
9696
if (auto err = optPipeline(M.get())) {
9797
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3963d2148292145543cf83b13ff839a63995fdc2
1+
2619c2ed584cdf3b38e6743ed3c785223f06e3f7

lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ struct MoveSplatAfterElementwisePattern
4343
MoveSplatAfterElementwisePattern(MLIRContext *context)
4444
: OpTraitRewritePattern(context) {}
4545

46-
LogicalResult match(Operation *op) const override {
46+
LogicalResult matchAndRewrite(Operation *op,
47+
PatternRewriter &rewriter) const override {
4748
if (!isMemoryEffectFree(op)) {
4849
return failure();
4950
}
@@ -57,10 +58,10 @@ struct MoveSplatAfterElementwisePattern
5758
return failure();
5859
}
5960
}
60-
return success(op->getNumOperands() > 0);
61-
}
6261

63-
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
62+
if (op->getNumOperands() <= 0)
63+
return failure();
64+
6465
auto loc = op->getLoc();
6566
auto operands = op->getOperands();
6667

@@ -96,6 +97,7 @@ struct MoveSplatAfterElementwisePattern
9697
newOp->getResult(iRes));
9798
rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
9899
}
100+
return success();
99101
}
100102
};
101103

@@ -108,7 +110,8 @@ struct MoveBroadcastAfterElementwisePattern
108110
MoveBroadcastAfterElementwisePattern(MLIRContext *context)
109111
: OpTraitRewritePattern(context) {}
110112

111-
LogicalResult match(Operation *op) const override {
113+
LogicalResult matchAndRewrite(Operation *op,
114+
PatternRewriter &rewriter) const override {
112115
if (!isMemoryEffectFree(op)) {
113116
return failure();
114117
}
@@ -137,14 +140,12 @@ struct MoveBroadcastAfterElementwisePattern
137140
return failure();
138141
}
139142
}
140-
return success(seenBroadcast);
141-
}
143+
if (!seenBroadcast)
144+
return failure();
142145

143-
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
144146
auto loc = op->getLoc();
145147

146148
// Find broadcast op
147-
auto operands = op->getOperands();
148149
BroadcastOp broadcastOp;
149150
for (auto operand : operands) {
150151
broadcastOp = operand.getDefiningOp<BroadcastOp>();
@@ -154,7 +155,7 @@ struct MoveBroadcastAfterElementwisePattern
154155
}
155156

156157
auto srcTy = broadcastOp.getSrc().getType();
157-
auto srcShape = srcTy.getShape();
158+
auto bcSrcShape = srcTy.getShape();
158159
auto srcEncoding = srcTy.getEncoding();
159160

160161
// Reshape operands to match srcShape
@@ -167,7 +168,7 @@ struct MoveBroadcastAfterElementwisePattern
167168
}
168169
auto elemTy =
169170
dyn_cast<RankedTensorType>(operand.getType()).getElementType();
170-
auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding);
171+
auto newTy = RankedTensorType::get(bcSrcShape, elemTy, srcEncoding);
171172
if (auto splatOp = llvm::dyn_cast<SplatOp>(definingOp)) {
172173
auto newSplat = rewriter.create<SplatOp>(loc, newTy, splatOp.getSrc());
173174
newOperands.push_back(newSplat);
@@ -192,7 +193,7 @@ struct MoveBroadcastAfterElementwisePattern
192193
for (auto resultTy : resultTypes) {
193194
auto elemTy = dyn_cast<RankedTensorType>(resultTy).getElementType();
194195
newResultTypes.push_back(
195-
RankedTensorType::get(srcShape, elemTy, srcEncoding));
196+
RankedTensorType::get(bcSrcShape, elemTy, srcEncoding));
196197
}
197198

198199
// Create new op and broadcast results
@@ -203,6 +204,7 @@ struct MoveBroadcastAfterElementwisePattern
203204
newOp->getResult(iRes));
204205
rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
205206
}
207+
return success();
206208
}
207209
};
208210

python/src/llvm.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ std::unique_ptr<TargetMachine>
4646
createTargetMachine(llvm::Module *module, std::string proc,
4747
bool enable_fp_fusion, const std::string &features) {
4848
std::string error;
49-
auto target =
50-
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
49+
auto target = llvm::TargetRegistry::lookupTarget(
50+
module->getTargetTriple().str(), error);
5151
llvm::TargetOptions opt;
5252
bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT");
5353
if (enable_fp_fusion)
@@ -59,7 +59,7 @@ createTargetMachine(llvm::Module *module, std::string proc,
5959
opt.MCOptions.AsmVerbose = true;
6060
opt.MCOptions.PreserveAsmComments = true;
6161
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
62-
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
62+
module->getTargetTriple().str(), proc, features, opt, llvm::Reloc::PIC_,
6363
std::nullopt,
6464
disableLLVMOpt ? llvm::CodeGenOptLevel::None
6565
: llvm::CodeGenOptLevel::Aggressive)};
@@ -132,7 +132,7 @@ std::string translateLLVMIRToASM(llvm::Module &module,
132132
// module->print(llvm::outs(), nullptr);
133133

134134
// create machine
135-
module.setTargetTriple(triple);
135+
module.setTargetTriple(Triple(triple));
136136
auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features);
137137
// set data layout
138138
module.setDataLayout(machine->createDataLayout());
@@ -459,7 +459,7 @@ void init_triton_llvm(py::module &&m) {
459459
std::string message = "Failed to parse library at " + path;
460460
throw std::invalid_argument(message);
461461
}
462-
libMod->setTargetTriple(dstMod->getTargetTriple());
462+
libMod->setTargetTriple(Triple(dstMod->getTargetTriple()));
463463
libMod->setDataLayout(dstMod->getDataLayout());
464464

465465
std::unordered_set<std::string> externalFns;

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,16 @@ static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc,
298298
ConversionPatternRewriter &rewriter,
299299
const SmallVector<Value> &v) {
300300
assert(v.size() == 2);
301-
return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Fp8>(loc, rewriter, v[0], v[1]);
301+
return cvtScaleFp8ToFp32<ROCDL::CvtScale32PkF32Fp8>(loc, rewriter, v[0],
302+
v[1]);
302303
}
303304

304305
static SmallVector<Value> Fp8E5M2_to_Fp32(Location loc,
305306
ConversionPatternRewriter &rewriter,
306307
const SmallVector<Value> &v) {
307308
assert(v.size() == 2);
308-
return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Bf8>(loc, rewriter, v[0], v[1]);
309+
return cvtScaleFp8ToFp32<ROCDL::CvtScale32PkF32Bf8>(loc, rewriter, v[0],
310+
v[1]);
309311
}
310312

311313
template <typename convertOp>

third_party/amd/python/triton_amd.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ void init_triton_amd(py::module &&m) {
122122
context.loadAllAvailableDialects();
123123
});
124124

125-
m.def("attach_target_triple",
126-
[](llvm::Module *module) { module->setTargetTriple(amdTargetTriple); });
125+
m.def("attach_target_triple", [](llvm::Module *module) {
126+
module->setTargetTriple(llvm::Triple(amdTargetTriple));
127+
});
127128

128129
// Set target architecture ISA version
129130
m.def("set_isa_version", [](llvm::Module *module, const std::string &arch) {

0 commit comments

Comments
 (0)