524 changes: 342 additions & 182 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h

Large diffs are not rendered by default.
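
The substance of the change is uniform across the files below: free helper macros such as `i32_val`, `gep`, and `barrier`, which silently expanded against `loc` and `rewriter` in the enclosing scope, become methods on a `TritonLLVMOpBuilder` constructed from both. The builder itself lives in the `Utility.h` diff above, which is not rendered, so the following is a minimal sketch of the before/after pattern; `anyZeroBefore` and `anyZeroAfter` are illustrative names, not functions from this PR.

```cpp
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using namespace mlir;

// Before: pre-PR macro style. or_/icmp_eq/int_val only compile if names
// spelled exactly `loc` and `rewriter` are in scope for the macros to grab.
static Value anyZeroBefore(Location loc, ConversionPatternRewriter &rewriter,
                           ArrayRef<Value> elems, Value zero) {
  Value cond = int_val(1, 0); // i1 false
  for (Value e : elems)
    cond = or_(cond, icmp_eq(e, zero));
  return cond;
}

// After: the same ops, built through an explicit builder that carries the
// location and rewriter, as the hunks below do with `b`.
static Value anyZeroAfter(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> elems, Value zero) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Value cond = b.int_val(1, 0);
  for (Value e : elems)
    cond = b.or_(cond, b.icmp_eq(e, zero));
  return cond;
}
```
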

13 changes: 7 additions & 6 deletions lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
@@ -18,17 +18,18 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto ctx = rewriter.getContext();
auto typeConverter = getTypeConverter();
auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter);
auto elemTy = elems[0].getType();
Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0);
Value condition = b.int_val(elemTy.getIntOrFloatBitWidth(), 0);
for (auto elem : elems) {
if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) {
condition =
or_(condition,
icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, elemTy, rewriter.getZeroAttr(elemTy))));
condition = b.or_(
condition,
b.icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, elemTy, rewriter.getZeroAttr(elemTy))));
} else {
assert(false && "Unsupported type for assert");
return failure();
@@ -41,7 +42,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
// tensor in those two operations may have different layout we need to
// make sure all the threads are done executing the assert before going to
// the next op.
barrier();
b.barrier();
}
rewriter.eraseOp(op);
return success();
10 changes: 6 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
@@ -13,6 +13,8 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
auto loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
if (funcOp->hasAttr("nvvm.kernel")) {
// A GPU kernel
if (op.getNumOperands() > 0) {
@@ -34,10 +36,9 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
funcOp.getResultTypes());
Value packedResults =
rewriter.create<LLVM::UndefOp>(op.getLoc(), packedResultsTy);
auto loc = op.getLoc();
for (auto it : llvm::enumerate(adaptor.getOperands())) {
packedResults = insert_val(packedResultsTy, packedResults, it.value(),
it.index());
packedResults = b.insert_val(packedResultsTy, packedResults,
it.value(), it.index());
}
newOp = rewriter.create<LLVM::ReturnOp>(op.getLoc(), packedResults);
}
@@ -78,6 +79,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
// Get the last argument of the caller, which is the current stack pointer
// of shared memory and append it to the operands of the callOp.
auto loc = callOp.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto caller = callOp->getParentOfType<FunctionOpInterface>();
auto promotedOperands = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
@@ -95,7 +97,7 @@
Value opOffsetVal;
if (opOffsetAttr) {
auto opOffset = opOffsetAttr.getValue().getZExtValue();
opOffsetVal = i32_val(opOffset);
opOffsetVal = b.i32_val(opOffset);
}

promotedOperands.push_back(
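
The `CallOpConversion` hunk above implements a small calling convention, as its comment states: the caller's last LLVM argument is the current shared-memory stack pointer, and each call site forwards it to the callee, bumped by a statically known offset when the offset attribute is present. A hedged sketch of that append step, assuming the builder API from this PR; `appendSmemArg`, the use of `LLVM::LLVMFuncOp`, and the byte-wise `i8` GEP are illustrative stand-ins, not the exact helpers this file uses.

```cpp
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using namespace mlir;

// Hypothetical helper: append the caller's shared-memory stack pointer,
// advanced by this call site's scratch offset, as an extra trailing operand.
static SmallVector<Value> appendSmemArg(ConversionPatternRewriter &rewriter,
                                        Location loc, LLVM::LLVMFuncOp caller,
                                        SmallVector<Value> operands,
                                        uint64_t opOffset) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  // By convention, the last argument of every converted function is the
  // current shared-memory base for its frame.
  Value smemBase = caller.getArgument(caller.getNumArguments() - 1);
  // Advance past the scratch region reserved for this call site
  // (assumed here to be a byte offset).
  Value calleeBase =
      b.gep(smemBase.getType(), i8_ty, smemBase, b.i32_val(opOffset));
  operands.push_back(calleeBase);
  return operands;
}
```
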
94 changes: 50 additions & 44 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -62,6 +62,7 @@ struct ConvertLayoutOpConversion
ArrayRef<unsigned> origRepShape,
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
Value smemBase) const {
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
auto rank = type.getRank();
@@ -110,29 +111,29 @@
Value offset = LLVM::linearize(rewriter, loc, multiDimOffsetWrapped,
paddedRepShape, outOrd);
auto elemPtrTy = smemBase.getType();
Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset);
Value ptr = b.gep(elemPtrTy, llvmElemTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
if (stNotRd) {
Value valVec = undef(vecTy);
Value valVec = b.undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
if (isInt1)
currVal = zext(llvmElemTy, currVal);
currVal = b.zext(llvmElemTy, currVal);
else if (isPtr)
currVal = ptrtoint(llvmElemTy, currVal);
valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
currVal = b.ptrtoint(llvmElemTy, currVal);
valVec = b.insert_element(vecTy, valVec, currVal, b.i32_val(v));
}
store(valVec, ptr);
b.store(valVec, ptr);
} else {
Value valVec = load(vecTy, ptr);
Value valVec = b.load(vecTy, ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(llvmElemTy, valVec, i32_val(v));
Value currVal = b.extract_element(llvmElemTy, valVec, b.i32_val(v));
if (isInt1)
currVal = icmp_ne(currVal,
rewriter.create<LLVM::ConstantOp>(
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
currVal = b.icmp_ne(
currVal, rewriter.create<LLVM::ConstantOp>(
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
else if (isPtr)
currVal = inttoptr(llvmElemTyOrig, currVal);
currVal = b.inttoptr(llvmElemTyOrig, currVal);
vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
}
}
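
A pattern that recurs both in `processReplica` above and in the munge/unmunge loops further down: shared memory is not addressed in `i1` or pointer element types, so stores widen `i1` to the shared-memory element type (`i8` in the code above) and cast pointers to integers, while loads reverse both. A minimal sketch with illustrative helper names, assuming the builder API from this PR and that `i1` is stored as `i8` as in the surrounding code.

```cpp
// Hypothetical helpers mirroring the munge/unmunge steps above.
static Value mungeForStore(TritonLLVMOpBuilder &b, Value v, Type smemElemTy,
                           bool isInt1, bool isPtr) {
  if (isInt1)
    return b.zext(smemElemTy, v);     // widen i1 before the vector store
  if (isPtr)
    return b.ptrtoint(smemElemTy, v); // pointers round-trip through ints
  return v;
}

static Value unmungeAfterLoad(TritonLLVMOpBuilder &b, Value v, Type origElemTy,
                              bool isInt1, bool isPtr) {
  if (isInt1)
    return b.icmp_ne(v, b.int_val(8, 0)); // compare back down to i1
  if (isPtr)
    return b.inttoptr(origElemTy, v);
  return v;
}
```
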
@@ -146,6 +147,7 @@
ConversionPatternRewriter &rewriter,
const TargetInfoBase &targetInfo) const {
auto loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto typeConverter = getTypeConverter();
RankedTensorType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
@@ -205,12 +207,12 @@
auto multiDimRepId =
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
if (repId != 0) {
barrier();
b.barrier();
}
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
vals, smemBase);
barrier();
b.barrier();
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
multiDimRepId, outVec, paddedRepShape, origRepShape,
outOrd, outVals, smemBase);
@@ -355,6 +357,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();

@@ -399,9 +402,9 @@
// Munge input values
for (const auto &it : llvm::enumerate(inVals)) {
if (isSubByteInt) {
inVals[it.index()] = zext(llvmElemTy, it.value());
inVals[it.index()] = b.zext(llvmElemTy, it.value());
} else if (isPtr) {
inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value());
}
}

@@ -417,9 +420,9 @@
// Unmunge output values
for (const auto &it : llvm::enumerate(outVals)) {
if (isSubByteInt) {
outVals[it.index()] = trunc(llvmElemTyOrig, it.value());
outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value());
} else if (isPtr) {
outVals[it.index()] = inttoptr(llvmElemTyOrig, it.value());
outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value());
}
}

@@ -443,6 +446,7 @@
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);

StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
@@ -452,9 +456,9 @@
StringAttr kIteration = str_attr("iteration");

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(srcLayout.getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);
Value threadsPerWarp = b.i32_val(srcLayout.getInDimSize(kLane));
Value laneId = b.urem(threadId, threadsPerWarp);
Value warpId = b.udiv(threadId, threadsPerWarp);

auto scratchConfig =
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
@@ -541,37 +545,38 @@
{kWarp, 0},
{kBlock, 0}})[0]
.second;
Value offset = xor_(regBase, i32_val(regIdx));
Value offset = b.xor_(regBase, b.i32_val(regIdx));
if (paddedSize > 0) {
assert(llvm::isPowerOf2_32(paddedStride));
assert(llvm::isPowerOf2_32(paddedSize));
auto rshiftVal = llvm::Log2_32(paddedStride);
auto lshiftVal = llvm::Log2_32(paddedSize);
offset = add(shl(lshr(offset, i32_val(rshiftVal)), i32_val(lshiftVal)),
offset);
offset = b.add(
b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
offset);
}
auto vecAddr = gep(sharedPtrTy, elemTy, smemBase, offset);
auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset);
vecAddr.setInbounds(true);
return vecAddr;
};

auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
{{kRegister, i32_val(0)},
{{kRegister, b.i32_val(0)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, i32_val(0)}})[0]
{kBlock, b.i32_val(0)}})[0]
.second;
auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout,
{{kRegister, i32_val(0)},
{{kRegister, b.i32_val(0)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, i32_val(0)}})[0]
{kBlock, b.i32_val(0)}})[0]
.second;
// register idx -> Value
llvm::MapVector<int, Value> outVals;
for (int i = 0; i < iterations; i++) {
if (i != 0)
barrier();
b.barrier();

auto &inRegs = inRegsForIter[i];
auto &outRegs = outRegsForIter[i];
Expand All @@ -591,19 +596,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
} else {
targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
/*pred=*/true_val());
/*pred=*/b.true_val());
}
}

barrier();
b.barrier();

for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) {
auto outRegSlice = outRegs[j];
auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice);
Value valsVec =
targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
vec_ty(elemTy, scratchConfig.outVec),
/*pred=*/true_val());
/*pred=*/b.true_val());
for (Value v : unpackLLVector(loc, valsVec, rewriter))
outVals[outRegSlice++] = v;
}
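
The `getVecAddr` lambda above folds padding into the linear offset with two shifts: since `paddedStride` and `paddedSize` are powers of two, `offset + ((offset >> log2(paddedStride)) << log2(paddedSize))` equals `offset + (offset / paddedStride) * paddedSize`, i.e. `paddedSize` extra elements after every `paddedStride` elements, which staggers rows across shared-memory banks. A scalar model of that arithmetic (the function name is illustrative):

```cpp
#include <cstdint>

// Scalar model of the padded-offset computation in getVecAddr.
uint32_t padOffset(uint32_t offset, uint32_t paddedStride,
                   uint32_t paddedSize) {
  uint32_t rshift = __builtin_ctz(paddedStride); // log2, both powers of two
  uint32_t lshift = __builtin_ctz(paddedSize);
  return offset + ((offset >> rshift) << lshift);
}

// With paddedStride = 32 and paddedSize = 4:
//   padOffset(0, 32, 4)  == 0
//   padOffset(31, 32, 4) == 31
//   padOffset(32, 32, 4) == 36   (second row starts 4 elements later)
//   padOffset(64, 32, 4) == 72
```
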
@@ -646,6 +651,7 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
Location loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
@@ -657,8 +663,8 @@
SmallVector<Value> shflOuts(Cp.getInDimSize(kRegister));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(Cp.getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value threadsPerWarp = b.i32_val(Cp.getInDimSize(kLane));
Value laneId = b.urem(threadId, threadsPerWarp);

// Emit one shuffle per destination register.
for (int i : llvm::seq(shflOuts.size())) {
@@ -667,22 +673,22 @@
// At the same time, for each register, P1 returns the source value index
// to provide as the shuffle value.
auto out = applyLinearLayout(loc, rewriter, P1,
{{kLane, laneId}, {kRegister, i32_val(i)}});
{{kLane, laneId}, {kRegister, b.i32_val(i)}});
assert(out.size() == 1);
Value srcRegIdx = out.front().second;
// The size of the input lane dimension is the number of selects to emit.
// TODO(jeff): For dtypes smaller than i32, we can use byte permutes and
// shuffle multiple values at a time.
Value shflSrc = undef(srcValues.front().getType());
Value shflSrc = b.undef(srcValues.front().getType());
for (int j : llvm::seq(reducedP1.getInDimSize(kLane))) {
int32_t check =
reducedP1.apply({{kLane, j}, {kRegister, i}}).front().second;
shflSrc =
select(icmp_eq(srcRegIdx, i32_val(check)), srcValues[check], shflSrc);
shflSrc = b.select(b.icmp_eq(srcRegIdx, b.i32_val(check)),
srcValues[check], shflSrc);
}

out = applyLinearLayout(loc, rewriter, Cp,
{{kLane, laneId}, {kRegister, i32_val(i)}});
{{kLane, laneId}, {kRegister, b.i32_val(i)}});
assert(out.size() == 1);
Value shflIdx = out.front().second;
shflOuts[i] = targetInfo.shuffleIdx(rewriter, loc, shflSrc, shflIdx);
Expand All @@ -693,16 +699,16 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
// selects.
SmallVector<Value> results(shflOuts.size());
for (int i : llvm::seq(results.size())) {
Value result = undef(srcValues.front().getType());
Value result = b.undef(srcValues.front().getType());

auto out = applyLinearLayout(loc, rewriter, P2inv,
{{kLane, laneId}, {kRegister, i32_val(i)}});
{{kLane, laneId}, {kRegister, b.i32_val(i)}});
Value resultIdx = out.front().second;
for (int j : llvm::seq(reducedP2.getInDimSize(kLane))) {
int32_t check =
reducedP2.apply({{kLane, j}, {kRegister, i}}).front().second;
result =
select(icmp_eq(resultIdx, i32_val(check)), shflOuts[check], result);
result = b.select(b.icmp_eq(resultIdx, b.i32_val(check)), shflOuts[check],
result);
}
results[i] = result;
}
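
The `transferWithinWarp` path above decomposes a register-permuting conversion into a pre-shuffle select (P1 picks, per lane, which source register to send), one `shuffleIdx` per destination register (Cp supplies the source lane), and a post-shuffle select (P2inv). Because the chosen register index is a per-lane SSA value, a chain of selects stands in for dynamic array indexing. A hedged sketch of that select-then-shuffle idiom; the function and its parameters are illustrative, and `shuffleIdx` stands in for `targetInfo.shuffleIdx`.

```cpp
// One destination register's worth of the idiom used above.
static Value gatherViaShuffle(
    TritonLLVMOpBuilder &b, ArrayRef<Value> srcRegs, Value srcRegIdx,
    Value srcLaneIdx, ArrayRef<int32_t> candidates,
    llvm::function_ref<Value(Value, Value)> shuffleIdx) {
  // Select the register this lane must contribute. srcRegIdx is a per-lane
  // runtime value, so it is compared against each statically known candidate.
  Value toSend = b.undef(srcRegs.front().getType());
  for (int32_t c : candidates)
    toSend = b.select(b.icmp_eq(srcRegIdx, b.i32_val(c)), srcRegs[c], toSend);
  // A single warp shuffle then moves the value to the lane that needs it.
  return shuffleIdx(toSend, srcLaneIdx);
}
```

The same select chain, driven by P2inv instead of P1, is what reorders the shuffled values back into destination registers at the end of the function.
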