Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
df0864e761107b07e38f5503e0cbee0cebb4c5e8
61f8a7f618901797ee8663389a29722f29216a96
8 changes: 7 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ using namespace mlir::triton;
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
#define call(...) LLVM::createLLVMCallOp(rewriter, loc, __VA_ARGS__)

// Types
#define int_ty(width) rewriter.getIntegerType(width)
Expand Down Expand Up @@ -228,6 +228,12 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value);

LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
LLVMFuncOp funcOp, ValueRange args);
LLVM::CallIntrinsicOp
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
TypeRange types, ValueRange args);

// Is v an integer or floating-point scalar constant equal to 0?
bool isConstantZero(Value v);

Expand Down
4 changes: 4 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
auto newCallOp = rewriter.create<LLVM::CallOp>(
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promotedOperands, callOp->getAttrs());
newCallOp.getProperties().setOpBundleSizes(
rewriter.getDenseI32ArrayAttr({}));
newCallOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(promotedOperands.size()), 0});
return newCallOp;
}

Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ struct MulhiUIOpConversion
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
}

protected:
Expand Down Expand Up @@ -327,7 +327,7 @@ struct ExternElementwiseOpConversion
LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(
rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath());
return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
}
};

Expand Down
21 changes: 19 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -518,6 +517,24 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
builder.getIntegerAttr(ty, value));
}

LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
LLVMFuncOp funcOp, ValueRange args) {
auto op = builder.create<LLVM::CallOp>(loc, funcOp, args);
op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({}));
op.getProperties().setOperandSegmentSizes({static_cast<int>(args.size()), 0});
return op;
}

LLVM::CallIntrinsicOp
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
TypeRange types, ValueRange args) {
auto op = builder.create<LLVM::CallIntrinsicOp>(loc, types, args);
op.getProperties().setIntrin(builder.getStringAttr(intrinsic));
op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({}));
op.getProperties().setOperandSegmentSizes({static_cast<int>(args.size()), 0});
return op;
}

bool isConstantZero(Value v) {
if (auto constantOp = v.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = dyn_cast<IntegerAttr>(constantOp.getValue())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

namespace mlir {
namespace triton {
Expand Down Expand Up @@ -187,11 +188,11 @@ class CallOpConversion : public mlir::RewritePattern {
rewriter.create<LLVM::FPToSIOp>(loc, returnType, op->getResult(0));
} else if (calleeName == "__triton_hip_fast_fdividef") {
assert(operands.size() == 2);
auto name = StringAttr::get(callOp.getContext(), "llvm.amdgcn.rcp.f32");
LLVM::FastmathFlagsAttr defaultFlags{};
auto rcpOp = rewriter.create<LLVM::CallIntrinsicOp>(
loc, returnType, name, operands[1], defaultFlags);
const char *intrinsic = "llvm.amdgcn.rcp.f32";
auto rcpOp = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic,
returnType, operands[1]);

LLVM::FastmathFlagsAttr defaultFlags{};
replacementOp = rewriter.create<LLVM::FMulOp>(
loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags);
}
Expand Down
7 changes: 3 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "../PatternTritonGPUOpToLLVM.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

namespace mlir::triton::AMD {
namespace {
Expand Down Expand Up @@ -219,10 +220,8 @@ Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc,
if (32 / dElType.getIntOrFloatBitWidth() > 1 || dElType.isInteger(32)) {
operands.push_back(int_val(1, false));
}
auto wmmaIntrinsic = rewriter.create<mlir::LLVM::CallIntrinsicOp>(
loc, TypeRange{valC.getType()}, StringAttr::get(loc.getContext(), name),
operands, defaultFlags);

auto wmmaIntrinsic = LLVM::createLLVMIntrinsicCallOp(
rewriter, loc, name, valC.getType(), operands);
return wmmaIntrinsic.getResult(0);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ struct ExpOpConversionApprox
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {rewriter.create<LLVM::CallOp>(loc, funcOp, prod).getResult()};
return {LLVM::createLLVMCallOp(rewriter, loc, funcOp, prod).getResult()};
}
};

Expand Down Expand Up @@ -1276,7 +1276,7 @@ struct Exp2OpConversion
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
}

private:
Expand Down
20 changes: 9 additions & 11 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
InstructionKindMask maskValue, int sizeValue,
int groupIdValue) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.sched.group.barrier");
const char *intrinsicName = "llvm.amdgcn.sched.group.barrier";

Value mask =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
Expand All @@ -47,36 +47,34 @@ void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
Value groupId = LLVM::createConstantI32(loc, rewriter,
static_cast<int32_t>(groupIdValue));

LLVM::FastmathFlagsAttr defaultFlags{};
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{mask, size, groupId},
defaultFlags);
LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, TypeRange{},
ValueRange{mask, size, groupId});
}

// Insert intrinsic that controls the types of instructions that may be
// allowed to cross the intrinsic during instruction scheduling
Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc,
int64_t maskValue) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.sched.barrier");
const char *intrinsicName = "llvm.amdgcn.sched.barrier";
LLVM::FastmathFlagsAttr defaultFlags{};

Value mask =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{mask}, defaultFlags);
return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName,
TypeRange{}, ValueRange{mask});
}

// Insert an experimental intrinsic for instruction group level parallelism.
// The intrinsic takes a value that specifies the strategy.
Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.iglp.opt");
const char *intrinsicName = "llvm.amdgcn.iglp.opt";
LLVM::FastmathFlagsAttr defaultFlags{};
Value iglpValue =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(value));
return rewriter.create<LLVM::CallIntrinsicOp>(
loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags);
return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName,
TypeRange{}, ValueRange{iglpValue});
}

struct InstructionSchedHintsRewriter
Expand Down
9 changes: 3 additions & 6 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,9 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const {

Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type,
Value cmp) const {
auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot");
SmallVector<Value> operands = {cmp};
Value asmResult =
rewriter.create<LLVM::CallIntrinsicOp>(loc, type, stringAttr, operands)
->getResult(0);
return asmResult;
return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.ballot",
type, cmp)
->getResult(0);
}

void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
Expand Down
10 changes: 4 additions & 6 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,9 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
auto funcName = mangleFunc(getLoadNameRaw(cm), funcType);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
auto loadVal =
rewriter
.create<LLVM::CallOp>(loc, funcOp, ValueRange({ptr, pred, falseVal}))
.getResult();
return loadVal;
return LLVM::createLLVMCallOp(rewriter, loc, funcOp,
ValueRange({ptr, pred, falseVal}))
.getResult();
}

void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
Expand Down Expand Up @@ -276,7 +274,7 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
auto funcName = mangleFunc(getStoreNameRaw(cm), funcType);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
rewriter.create<LLVM::CallOp>(loc, funcOp, ValueRange({ptr, val, pred}));
LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange({ptr, val, pred}));
}

} // namespace mlir::LLVM::AMD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using namespace mlir::triton::gpu;

Expand Down Expand Up @@ -912,7 +913,7 @@ struct OpToExternCallConversion
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
}

private:
Expand Down