@@ -430,6 +430,9 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
This operation takes and produces an optional token to indicate TMEM read
and write on its accumulator operand. When the tokens are present, they can
be used to check aliasing and modref on the accumulator memory.

+The `isUnsigned` attribute is only relevant when performing an integer MMA operation.
+If true, the integer values are treated as unsigned; otherwise they are treated as signed.
}];

let arguments = (ins
@@ -442,7 +445,8 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
Variadic<TTG_MemDescType>:$barriers,
Variadic<I1>:$barrier_preds,
UnitAttr:$is_async,
-UnitAttr:$two_ctas
+UnitAttr:$two_ctas,
+UnitAttr:$is_unsigned
);
let results = (outs Optional<TTG_AsyncToken>:$token);

@@ -452,7 +456,8 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
"Value":$pred, CArg<"bool", "false">:$two_ctas,
CArg<"ValueRange", "{}">:$barriers,
CArg<"ValueRange", "{}">:$barrier_preds,
-CArg<"bool", "false">:$is_async)>
+CArg<"bool", "false">:$is_async,
+CArg<"bool", "false">:$is_unsigned)>
];

let assemblyFormat = [{
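Taken together, the TableGen changes add `is_unsigned` as a trailing `UnitAttr` and thread a defaulted `bool` through the custom builder. A minimal sketch of a call site, assuming the builder signature added in this PR (the names `loc`, `tokenType`, `a`, `b`, `d`, `accDep`, `useD`, and `pred` are illustrative placeholders, not values from this diff):

```cpp
// Sketch: build an unsigned i8 x i8 -> i32 MMA. All SSA values must already
// be defined with the appropriate memdesc/token types.
auto mma = builder.create<triton::nvidia_gpu::TCGen5MMAOp>(
    loc, tokenType, a, b, d, accDep, useD, pred,
    /*useTwoCTAs=*/false, /*barriers=*/ValueRange{},
    /*barrierPreds=*/ValueRange{}, /*isAsync=*/false,
    /*isUnsigned=*/true); // printed as the `is_unsigned` unit attribute
```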
lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp (39 additions, 6 deletions)
@@ -25,6 +25,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
@@ -305,12 +306,15 @@ getMMAv5DTypeKindAndAcc(Type t) {
{MMADTypeKind::f16, {Float16Type::get(ctx), Float32Type::get(ctx)}}};
}
// TODO: float6 and explicit float4 types are not supported yet.
-// TODO: tcgen05.mma supports ui8/si8 -> s32 MMA, but Triton does not.
// FIXME: i8 is used to represent float4 types.
-if (isa<Float8E4M3FNType, Float8E5M2Type>(t) || t.isInteger(8)) {
+if (isa<FloatType>(t) && llvm::is_contained(std::array<unsigned, 3>{4, 6, 8},
+                                            t.getIntOrFloatBitWidth())) {
return {
{MMADTypeKind::f8f6f4, {Float16Type::get(ctx), Float32Type::get(ctx)}}};
}
+if (t.isInteger(8)) {
+  return {{MMADTypeKind::i8, {IntegerType::get(ctx, 32)}}};
+}
return std::nullopt;
}
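The new `MMADTypeKind::i8` entry records that tcgen05 integer MMA accumulates only into 32-bit integers, matching the ui8/si8 -> s32 restriction noted in the TODO removed above. A hedged sketch of what a caller observes, assuming the function returns an optional pair of kind and allowed accumulator element types, as the surrounding code suggests:

```cpp
// i8 operands admit exactly one accumulator element type: i32.
auto kindAndAcc = getMMAv5DTypeKindAndAcc(IntegerType::get(ctx, 8));
assert(kindAndAcc && kindAndAcc->first == MMADTypeKind::i8);
assert(kindAndAcc->second.size() == 1 && kindAndAcc->second[0].isInteger(32));
```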

@@ -404,21 +408,50 @@ void TCGen5MMAOp::setPredicate(Value pred) { getPredMutable().assign(pred); }
void TCGen5MMAOp::build(OpBuilder &builder, OperationState &state, Type token,
Value a, Value b, Value d, Value accDep, Value useD,
Value pred, bool useTwoCTAs, ValueRange barriers,
-ValueRange barrierPreds, bool isAsync) {
+ValueRange barrierPreds, bool isAsync,
+bool isUnsigned) {
if (!barriers.empty()) {
isAsync = true;
}
build(builder, state, token, a, b, d, accDep, useD, pred, barriers,
barrierPreds, isAsync ? builder.getUnitAttr() : UnitAttr(),
-useTwoCTAs ? builder.getUnitAttr() : UnitAttr());
+useTwoCTAs ? builder.getUnitAttr() : UnitAttr(),
+isUnsigned ? builder.getUnitAttr() : UnitAttr());
}

bool TCGen5MMAOp::isAsync() { return getIsAsync(); }

// -- TCGen5MMAScaledOp --

+static Type getScaledMMAOperandType(Type elementType,
+                                    ScaleDotElemType scaleType) {
+  MLIRContext *ctx = elementType.getContext();
+  if (isa<FloatType>(elementType))
+    return elementType;
+  switch (scaleType) {
+  case ScaleDotElemType::E4M3:
+    return Float8E4M3FNType::get(ctx);
+  case ScaleDotElemType::E5M2:
+    return Float8E5M2Type::get(ctx);
+  case ScaleDotElemType::E2M3:
+    return Float6E2M3FNType::get(ctx);
+  case ScaleDotElemType::E3M2:
+    return Float6E3M2FNType::get(ctx);
+  case ScaleDotElemType::E2M1:
+    return Float4E2M1FNType::get(ctx);
+  case ScaleDotElemType::BF16:
+    return BFloat16Type::get(ctx);
+  case ScaleDotElemType::FP16:
+    return Float16Type::get(ctx);
+  }
+  llvm_unreachable("Unsupported type.");
+}

LogicalResult TCGen5MMAScaledOp::verify() {
-Type atype = getA().getType().getElementType();
-Type btype = getB().getType().getElementType();
+Type atype =
+    getScaledMMAOperandType(getA().getType().getElementType(), getAType());
+Type btype =
+    getScaledMMAOperandType(getB().getType().getElementType(), getBType());
Type dtype = getD().getType().getElementType();
if (failed(verifyMMADType(*this, atype, btype, dtype)))
return failure();
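With i8 now a legal MMA kind, mapping through `ScaleDotElemType` matters for the scaled op: fp4 operands are stored in i8 memdescs (see the FIXME above), so without recovering the logical element type the verifier would misclassify them as integer MMAs. A small sketch of the intended behavior, using illustrative values:

```cpp
// An E2M1 (fp4) operand stored as i8: the verifier should classify it as a
// float kind (f8f6f4), not as MMADTypeKind::i8.
Type storage = IntegerType::get(ctx, 8);
Type logical = getScaledMMAOperandType(storage, ScaleDotElemType::E2M1);
assert(isa<Float4E2M1FNType>(logical));
```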
test/Conversion/tritongpu_to_llvm_blackwell.mlir (76 additions, 0 deletions)
@@ -34,6 +34,82 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
// CHECK-LABEL: @tc_gen5_int8_unsigned_mma
// CHECK: %[[WID:.+]] = nvgpu.warp_id
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32
// CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1
// CHECK: llvm.cond_br %[[P1]]
// CHECK: %[[E:.+]] = nvvm.elect.sync -> i1

// Verify the instruction descriptor has the expected value for unsigned i8.
// CHECK: llvm.mlir.constant(136314912 : i32) : i32

// CHECK-COUNT-4: @$5 tcgen05.mma.cta_group::1.kind::i8 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[E]]
// CHECK: %[[PRED:.+]] = llvm.and %arg6, %[[E]]
// CHECK: @$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l" %[[PRED]]
tt.func @tc_gen5_int8_unsigned_mma(%a: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
%b: !ttg.memdesc<128x128xi8, #shared1, #ttg.shared_memory>,
%c: !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable>,
%useAcc: i1,
%pred: i1,
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
%barrierPred: i1) {
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async, is_unsigned} :
!ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
!ttg.memdesc<128x128xi8, #shared1, #ttg.shared_memory>,
!ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable>,
!ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>
tt.return
}
}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
// CHECK-LABEL: @tc_gen5_int8_signed_mma
// CHECK: %[[WID:.+]] = nvgpu.warp_id
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32
// CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1
// CHECK: llvm.cond_br %[[P1]]
// CHECK: %[[E:.+]] = nvvm.elect.sync -> i1

// Verify the instruction descriptor has the expected value for signed i8.
// CHECK: llvm.mlir.constant(136316064 : i32) : i32

// CHECK-COUNT-4: @$5 tcgen05.mma.cta_group::1.kind::i8 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[E]]
// CHECK: %[[PRED:.+]] = llvm.and %arg6, %[[E]]
// CHECK: @$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l" %[[PRED]]
tt.func @tc_gen5_int8_signed_mma(%a: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
%b: !ttg.memdesc<128x128xi8, #shared1, #ttg.shared_memory>,
%c: !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable>,
%useAcc: i1,
%pred: i1,
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
%barrierPred: i1) {
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async} :
!ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
!ttg.memdesc<128x128xi8, #shared1, #ttg.shared_memory>,
!ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable>,
!ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>
tt.return
}
}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}>
Expand Down
test/TritonNvidiaGPU/ops.mlir (24 additions, 0 deletions)
@@ -34,6 +34,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
tt.return
}

// CHECK-LABEL: @tcgen5_int8
// CHECK: ttng.tc_gen5_mma {{.*}} {is_async, is_unsigned}
// CHECK: ttng.tc_gen5_mma {{.*}} {is_unsigned}
tt.func @tcgen5_int8(
%a: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
%b: !ttg.memdesc<128x256xi8, #shared1, #ttg.shared_memory>,
%c: !ttg.memdesc<128x256xi32, #shared1, #ttg.shared_memory, mutable>,
%accUse: i1,
%pred: i1,
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>,
%barrierPred: i1) {
ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%barrierPred] {is_async, is_unsigned} :
!ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
!ttg.memdesc<128x256xi8, #shared1, #ttg.shared_memory>,
!ttg.memdesc<128x256xi32, #shared1, #ttg.shared_memory, mutable>,
!ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>

ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred {is_unsigned} :
!ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
!ttg.memdesc<128x256xi8, #shared1, #ttg.shared_memory>,
!ttg.memdesc<128x256xi32, #shared1, #ttg.shared_memory, mutable>
tt.return
}

// CHECK-LABEL: @async_tma_gather
// CHECK-SAME: [[DESC:%arg[0-9]+]]:
// CHECK-SAME: [[X_OFFSETS:%arg[0-9]+]]:
@@ -126,7 +126,7 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
uint32_t shift : 2;
};
};
-auto getTypeEncoding = [](Type type) {
+auto getTypeEncoding = [&](Type type) {
if (type.isF16())
return 0;
if (type.isBF16())
@@ -137,6 +137,10 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
return 0;
if (llvm::isa<Float8E5M2Type>(type))
return 1;
+// For 8-bit integer types, the encoding is 1 for signed and 0 for unsigned
+// arithmetic.
+if (type.isInteger(8))
+  return op.getIsUnsigned() ? 0 : 1;
llvm_unreachable("Unsupported type.");
};
static_assert(sizeof(TCGen5InstructionDescriptor) == 4,
@@ -150,8 +154,12 @@
desc.aType = getTypeEncoding(op.getA().getType().getElementType());
desc.bType = getTypeEncoding(op.getB().getType().getElementType());
Type dstElType = op.getD().getType().getElementType();
-assert(dstElType.isF16() || dstElType.isF32());
-desc.dType = dstElType.isF16() ? 0 : 1;
+assert(dstElType.isF16() || dstElType.isF32() || dstElType.isInteger(32));
+if (dstElType.isInteger(32)) {
+  desc.dType = 2;
+} else {
+  desc.dType = dstElType.isF16() ? 0 : 1;
+}
return b.int_val(32, desc.descriptor);
}
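The descriptor constants checked by the new Blackwell tests follow directly from this bitfield. A hedged worked example, assuming the usual tcgen05 instruction-descriptor layout with `dType` at bits 4-5, `aType` at bits 7-9, and `bType` at bits 10-12 (the struct's earlier fields are not shown in this excerpt):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Unsigned i8 MMA descriptor from the test: aType = bType = 0,
  // and bit 5 set means dType = 2 (i32 accumulator).
  uint32_t unsignedDesc = 136314912; // 0x08200020
  // Signed i8 sets aType = 1 (bit 7) and bType = 1 (bit 10).
  assert(unsignedDesc + (1u << 7) + (1u << 10) == 136316064);
  return 0;
}
```

So the signed and unsigned descriptors differ only in the two operand-type fields; the shape and i32 accumulator bits are identical.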

@@ -258,14 +266,19 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
std::string opcode =
"tcgen05.mma.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + ".kind::";
Type srcElementTy = op.getA().getType().getElementType();
-if (srcElementTy.isF16() || srcElementTy.isBF16())
+if (srcElementTy.isF16() || srcElementTy.isBF16()) {
opcode += "f16";
-else if (srcElementTy.isF32())
+} else if (srcElementTy.isF32()) {
opcode += "tf32";
-else if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementTy))
+} else if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementTy)) {
opcode += "f8f6f4";
-else
+} else if (op.getD().getType().getElementType().isInteger(32)) {
+  // PTX uses "i8" for integer operations (both signed and unsigned); the
+  // signed/unsigned distinction is encoded in the instruction descriptor.
+  opcode += "i8";
+} else {
assert(0 && "Unsupported type.");
+}
auto *accOp = ptxBuilder.newAddrOperand(d.base, "r", *d.offset);
assert(a.offset.has_value() == aInTMem);
auto *aOp = aInTMem ? ptxBuilder.newAddrOperand(a.base, "r", *a.offset)