diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
index b43d5b4b3f09..77697b6310f8 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -430,6 +430,9 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
     This operation takes and produces an optional token to indicate TMEM read
     and write on its accumulator operand. When the tokens are present, they can
     be used to check aliasing and modref on the accumulator memory.
+
+    The `isUnsigned` attribute is only relevant when performing an integer MMA operation.
+    If true, the integer values are treated as unsigned; otherwise they are treated as signed.
   }];

   let arguments = (ins
@@ -442,7 +445,8 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
     Variadic<TTG_MemDescType>:$barriers,
     Variadic<I1>:$barrier_preds,
     UnitAttr:$is_async,
-    UnitAttr:$two_ctas
+    UnitAttr:$two_ctas,
+    UnitAttr:$is_unsigned
   );

   let results = (outs Optional<TTG_AsyncToken>:$token);
@@ -452,7 +456,8 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
       "Value":$pred, CArg<"bool", "false">:$two_ctas,
       CArg<"ValueRange", "{}">:$barriers,
       CArg<"ValueRange", "{}">:$barrier_preds,
-      CArg<"bool", "false">:$is_async)>
+      CArg<"bool", "false">:$is_async,
+      CArg<"bool", "false">:$is_unsigned)>
   ];

   let assemblyFormat = [{
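Note: the new `is_unsigned` `UnitAttr` rides along the existing builder, so C++ callers opt in with a single extra boolean. A minimal call-site sketch follows; `loc`, `tokenType`, and the operand `Value`s are assumed to be produced by the surrounding lowering code and are not part of this patch:

```cpp
// Hypothetical call site for the extended builder; only the trailing
// /*isUnsigned=*/true is new in this change.
auto mma = builder.create<triton::nvidia_gpu::TCGen5MMAOp>(
    loc, tokenType, a, b, d, accDep, useD, pred,
    /*useTwoCTAs=*/false, /*barriers=*/ValueRange{},
    /*barrierPreds=*/ValueRange{}, /*isAsync=*/false,
    /*isUnsigned=*/true);
```

Because it is a `UnitAttr`, omitting the flag (the default) keeps the existing signed behavior, and `{is_unsigned}` only appears in the printed IR when the attribute is set.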
diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
index ac8a4b7c2c9d..c91e4ff78b7e 100644
--- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
@@ -305,12 +306,15 @@ getMMAv5DTypeKindAndAcc(Type t) {
     return {
         {MMADTypeKind::f16, {Float16Type::get(ctx), Float32Type::get(ctx)}}};
   }
   // TODO: float6 and explicit float4 types are not supported yet.
-  // TODO: tcgen05.mma supports ui8/si8 -> s32 MMA, but Triton does not.
   // FIXME: i8 is used to represent float4 types.
-  if (isa<FloatType>(t) || t.isInteger(8)) {
+  if (isa<FloatType>(t) && llvm::is_contained(std::array{4, 6, 8},
+                                              t.getIntOrFloatBitWidth())) {
     return {
         {MMADTypeKind::f8f6f4, {Float16Type::get(ctx), Float32Type::get(ctx)}}};
   }
+  if (t.isInteger(8)) {
+    return {{MMADTypeKind::i8, {IntegerType::get(ctx, 32)}}};
+  }
   return std::nullopt;
 }
@@ -404,21 +408,50 @@ void TCGen5MMAOp::setPredicate(Value pred) { getPredMutable().assign(pred); }
 void TCGen5MMAOp::build(OpBuilder &builder, OperationState &state, Type token,
                         Value a, Value b, Value d, Value accDep, Value useD,
                         Value pred, bool useTwoCTAs, ValueRange barriers,
-                        ValueRange barrierPreds, bool isAsync) {
+                        ValueRange barrierPreds, bool isAsync,
+                        bool isUnsigned) {
   if (!barriers.empty()) {
     isAsync = true;
   }
   build(builder, state, token, a, b, d, accDep, useD, pred, barriers,
         barrierPreds, isAsync ? builder.getUnitAttr() : UnitAttr(),
-        useTwoCTAs ? builder.getUnitAttr() : UnitAttr());
+        useTwoCTAs ? builder.getUnitAttr() : UnitAttr(),
+        isUnsigned ? builder.getUnitAttr() : UnitAttr());
 }

 bool TCGen5MMAOp::isAsync() { return getIsAsync(); }

 // -- TCGen5MMAScaledOp --
+
+static Type getScaledMMAOperandType(Type elementType,
+                                    ScaleDotElemType scaleType) {
+  MLIRContext *ctx = elementType.getContext();
+  if (isa<FloatType>(elementType))
+    return elementType;
+  switch (scaleType) {
+  case ScaleDotElemType::E4M3:
+    return Float8E4M3FNType::get(ctx);
+  case ScaleDotElemType::E5M2:
+    return Float8E5M2Type::get(ctx);
+  case ScaleDotElemType::E2M3:
+    return Float6E2M3FNType::get(ctx);
+  case ScaleDotElemType::E3M2:
+    return Float6E3M2FNType::get(ctx);
+  case ScaleDotElemType::E2M1:
+    return Float4E2M1FNType::get(ctx);
+  case ScaleDotElemType::BF16:
+    return BFloat16Type::get(ctx);
+  case ScaleDotElemType::FP16:
+    return Float16Type::get(ctx);
+  }
+  llvm_unreachable("Unsupported type.");
+};
+
 LogicalResult TCGen5MMAScaledOp::verify() {
-  Type atype = getA().getType().getElementType();
-  Type btype = getB().getType().getElementType();
+  Type atype =
+      getScaledMMAOperandType(getA().getType().getElementType(), getAType());
+  Type btype =
+      getScaledMMAOperandType(getB().getType().getElementType(), getBType());
   Type dtype = getD().getType().getElementType();
   if (failed(verifyMMADType(*this, atype, btype, dtype)))
     return failure();
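The net effect of the `getMMAv5DTypeKindAndAcc` change is that i8 operands get their own `MMADTypeKind::i8` bucket whose only legal accumulator is i32, while fp8/fp6/fp4 float types keep accumulating in f16/f32. A standalone sketch of that legality rule, with the kinds stubbed as plain enums since `MMADTypeKind` is internal to Ops.cpp; rows not visible in this hunk (e.g. tf32) are assumptions:

```cpp
#include <cassert>

// Stub of the kind/accumulator relationship implied by the patch; the real
// code returns MLIR Type objects, this sketch only mirrors the legality rule.
enum class Kind { f16, tf32, f8f6f4, i8 };
enum class Acc { f16, f32, i32 };

static bool accumulatorAllowed(Kind k, Acc acc) {
  switch (k) {
  case Kind::tf32: // assumed: tf32 accumulates in f32
    return acc == Acc::f32;
  case Kind::f16:
  case Kind::f8f6f4:
    return acc == Acc::f16 || acc == Acc::f32;
  case Kind::i8: // new: integer MMA accumulates in i32 only
    return acc == Acc::i32;
  }
  return false;
}

int main() {
  assert(accumulatorAllowed(Kind::i8, Acc::i32));
  assert(!accumulatorAllowed(Kind::i8, Acc::f32));
  return 0;
}
```

The `getScaledMMAOperandType` helper matters here because scaled MMA operands may be stored as i8 while logically holding fp4 values; remapping through `ScaleDotElemType` before verification keeps those operands from being misclassified as the new integer kind.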
diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir
index 3a1579b292b0..642f7aac312a 100644
--- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir
+++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir
@@ -34,6 +34,82 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {

 // -----

+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
+  // CHECK-LABEL: @tc_gen5_int8_unsigned_mma
+  // CHECK: %[[WID:.+]] = nvgpu.warp_id
+  // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32
+  // CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1
+  // CHECK: llvm.cond_br %[[P1]]
+  // CHECK: %[[E:.+]] = nvvm.elect.sync -> i1
+
+  // Verify the descriptor has the expected value for unsigned i8.
+  // CHECK: llvm.mlir.constant(136314912 : i32) : i32
+
+  // CHECK-COUNT-4: @$5 tcgen05.mma.cta_group::1.kind::i8 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[E]]
+  // CHECK: %[[PRED:.+]] = llvm.and %arg6, %[[E]]
+  // CHECK: @$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l" %[[PRED]]
+  tt.func @tc_gen5_int8_unsigned_mma(%a: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+                                     %b: !ttg.memdesc<128x128xi8, #shared1, #ttg.shared_memory>,
+                                     %c: !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable>,
+                                     %useAcc: i1,
+                                     %pred: i1,
+                                     %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
+                                     %barrierPred: i1) {
+    ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async, is_unsigned} :
+      !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x128xi8, #shared1, #ttg.shared_memory>,
+      !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable>,
+      !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>
+    tt.return
+  }
+}
+
+// -----
+
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
+  // CHECK-LABEL: @tc_gen5_int8_signed_mma
+  // CHECK: %[[WID:.+]] = nvgpu.warp_id
+  // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32
+  // CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1
+  // CHECK: llvm.cond_br %[[P1]]
+  // CHECK: %[[E:.+]] = nvvm.elect.sync -> i1
+
+  // Verify the descriptor has the expected value for signed i8.
+  // CHECK: llvm.mlir.constant(136316064 : i32) : i32
+
+  // CHECK-COUNT-4: @$5 tcgen05.mma.cta_group::1.kind::i8 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[E]]
+  // CHECK: %[[PRED:.+]] = llvm.and %arg6, %[[E]]
+  // CHECK: @$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l" %[[PRED]]
+  tt.func @tc_gen5_int8_signed_mma(%a: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+                                   %b: !ttg.memdesc<128x128xi8, #shared1, #ttg.shared_memory>,
+                                   %c: !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable>,
+                                   %useAcc: i1,
+                                   %pred: i1,
+                                   %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
+                                   %barrierPred: i1) {
+    ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async} :
+      !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x128xi8, #shared1, #ttg.shared_memory>,
+      !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable>,
+      !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>
+    tt.return
+  }
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}>
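The two magic constants in these tests differ only in the a-type/b-type fields of the tcgen05 instruction descriptor. Assuming the field layout matches the PTX instruction-descriptor encoding (d-type at bits 4-5, a-type starting at bit 7, b-type at bit 10, consistent with the bitfield struct in MMAv5.cpp below), the relationship can be checked standalone:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t unsignedDesc = 136314912; // 0x08200020, i8 unsigned
  const uint32_t signedDesc = 136316064;   // 0x082004A0, i8 signed
  // dType occupies bits 4-5; value 2 selects the i32 accumulator.
  assert(((unsignedDesc >> 4) & 0x3) == 2);
  // Signed i8 sets aType = 1 (bit 7) and bType = 1 (bit 10); everything
  // else is identical between the two descriptors.
  assert(signedDesc == (unsignedDesc | (1u << 7) | (1u << 10)));
  return 0;
}
```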
diff --git a/test/TritonNvidiaGPU/ops.mlir b/test/TritonNvidiaGPU/ops.mlir
index dcff6b84a13f..4ef937d66c0d 100644
--- a/test/TritonNvidiaGPU/ops.mlir
+++ b/test/TritonNvidiaGPU/ops.mlir
@@ -34,6 +34,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     tt.return
   }

+  // CHECK-LABEL: @tcgen5_int8
+  // CHECK: ttng.tc_gen5_mma {{.*}} {is_async, is_unsigned}
+  // CHECK: ttng.tc_gen5_mma {{.*}} {is_unsigned}
+  tt.func @tcgen5_int8(
+    %a: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+    %b: !ttg.memdesc<128x256xi8, #shared1, #ttg.shared_memory>,
+    %c: !ttg.memdesc<128x256xi32, #shared1, #ttg.shared_memory, mutable>,
+    %accUse: i1,
+    %pred: i1,
+    %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>,
+    %barrierPred: i1) {
+    ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%barrierPred] {is_async, is_unsigned} :
+      !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xi8, #shared1, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xi32, #shared1, #ttg.shared_memory, mutable>,
+      !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
+
+    ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred {is_unsigned} :
+      !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xi8, #shared1, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xi32, #shared1, #ttg.shared_memory, mutable>
+    tt.return
+  }
+
   // CHECK-LABEL: @async_tma_gather
   // CHECK-SAME: [[DESC:%arg[0-9]+]]:
   // CHECK-SAME: [[X_OFFSETS:%arg[0-9]+]]:
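These round-trip tests exercise the printed form; in C++, the tablegen-generated `UnitAttr` accessors return plain bools, which is what downstream passes branch on. A sketch, assuming `op` is a `ttng.tc_gen5_mma` op obtained from a walk or rewrite pattern (the surrounding function is hypothetical, includes omitted):

```cpp
// Hypothetical inspection helper: UnitAttr accessors report presence of the
// attribute, so a missing `is_unsigned` means signed arithmetic.
void classifyMMA(triton::nvidia_gpu::TCGen5MMAOp op) {
  if (op.getIsAsync() && op.getIsUnsigned()) {
    // First op in the test above: {is_async, is_unsigned}.
  } else if (op.getIsUnsigned()) {
    // Second op in the test above: synchronous, unsigned.
  }
}
```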
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
index f50afe2fb5d6..800eea70ba42 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
@@ -126,7 +126,7 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
       uint32_t shift : 2;
     };
   };
-  auto getTypeEncoding = [](Type type) {
+  auto getTypeEncoding = [&](Type type) {
    if (type.isF16())
      return 0;
    if (type.isBF16())
@@ -137,6 +137,10 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
      return 0;
    if (llvm::isa<Float8E5M2Type>(type))
      return 1;
+    // For 8-bit integer types, signed arithmetic is 1, unsigned arithmetic
+    // is 0.
+    if (type.isInteger(8))
+      return op.getIsUnsigned() ? 0 : 1;
    llvm_unreachable("Unsupported type.");
   };
   static_assert(sizeof(TCGen5InstructionDescriptor) == 4,
@@ -150,8 +154,12 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
   desc.aType = getTypeEncoding(op.getA().getType().getElementType());
   desc.bType = getTypeEncoding(op.getB().getType().getElementType());
   Type dstElType = op.getD().getType().getElementType();
-  assert(dstElType.isF16() || dstElType.isF32());
-  desc.dType = dstElType.isF16() ? 0 : 1;
+  assert(dstElType.isF16() || dstElType.isF32() || dstElType.isInteger(32));
+  if (dstElType.isInteger(32)) {
+    desc.dType = 2;
+  } else {
+    desc.dType = dstElType.isF16() ? 0 : 1;
+  }
   return b.int_val(32, desc.descriptor);
 }

@@ -258,14 +266,19 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
   std::string opcode =
       "tcgen05.mma.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + ".kind::";
   Type srcElementTy = op.getA().getType().getElementType();
-  if (srcElementTy.isF16() || srcElementTy.isBF16())
+  if (srcElementTy.isF16() || srcElementTy.isBF16()) {
     opcode += "f16";
-  else if (srcElementTy.isF32())
+  } else if (srcElementTy.isF32()) {
     opcode += "tf32";
-  else if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementTy))
+  } else if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementTy)) {
     opcode += "f8f6f4";
-  else
+  } else if (op.getD().getType().getElementType().isInteger(32)) {
+    // PTX uses "i8" for integer operations (both signed and unsigned); the
+    // signed/unsigned distinction is encoded in the instruction descriptor.
+    opcode += "i8";
+  } else {
     assert(0 && "Unsupported type.");
+  }
   auto *accOp = ptxBuilder.newAddrOperand(d.base, "r", *d.offset);
   assert(a.offset.has_value() == aInTMem);
   auto *aOp = aInTMem ? ptxBuilder.newAddrOperand(a.base, "r", *a.offset)
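To make the new opcode-selection order explicit (floating-point kinds are keyed off the A element type, while the integer kind is keyed off the i32 accumulator), here is a standalone rendition with MLIR types stubbed as enums. It mirrors the patched branch in `createGen5MMA` but is a sketch, not the shipped code:

```cpp
#include <cassert>
#include <string>

// Types are stubbed as enums since this sketch doesn't link against MLIR.
enum class ElemTy { F16, BF16, F32, F8, I8 };
enum class AccTy { F16, F32, I32 };

static std::string selectKind(ElemTy src, AccTy acc, bool twoCTAs) {
  std::string opcode =
      "tcgen05.mma.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + ".kind::";
  if (src == ElemTy::F16 || src == ElemTy::BF16)
    return opcode + "f16";
  if (src == ElemTy::F32)
    return opcode + "tf32";
  if (src == ElemTy::F8)
    return opcode + "f8f6f4";
  if (acc == AccTy::I32)
    return opcode + "i8"; // signedness lives in the descriptor, not here
  assert(false && "Unsupported type.");
  return opcode;
}

int main() {
  assert(selectKind(ElemTy::I8, AccTy::I32, false) ==
         "tcgen05.mma.cta_group::1.kind::i8");
  return 0;
}
```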