Commit 40c1838
[AMD] Support scale is none in DotScaledOp for gfx950 (#5931)
This PR supports the case where one or both scales are None in DotScaledOp on gfx950. If a scale is None, a constant scale tensor with value 1.0 is created.
1 parent 73a724f commit 40c1838
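As a user-level sketch (not code from this commit; the kernel name, pointer arguments, and shapes are illustrative), a gfx950 kernel can now pass None for either scale operand of tl.dot_scaled and the backend substitutes a unit scale:

import triton
import triton.language as tl


@triton.jit
def fp4_dot_unit_b_scale(a_ptr, b_ptr, a_scale_ptr, out_ptr,  #
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                         BLOCK_K: tl.constexpr, VEC_SIZE: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K // 2)  # two packed e2m1 values per byte along K
    offs_scale_k = tl.arange(0, BLOCK_K // VEC_SIZE)
    a = tl.load(a_ptr + offs_m[:, None] * (BLOCK_K // 2) + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    scale_a = tl.load(a_scale_ptr + offs_m[:, None] * (BLOCK_K // VEC_SIZE) + offs_scale_k[None, :])
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # B has no scale: passing None is now accepted and behaves like an all-1.0 scale.
    acc = tl.dot_scaled(a, scale_a, "e2m1", b, None, "e2m1", acc)
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)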

5 files changed: +185 −39 lines changed

python/test/unit/language/test_matmul.py

Lines changed: 33 additions & 10 deletions
@@ -682,8 +682,10 @@ def block_scale_fp4_matmul( #
     # Two e2m1 values per K
     offs_k = tl.arange(0, BLOCK_K // 2)
     offs_scale_k = tl.arange(0, BLOCK_K // VEC_SIZE)
-    a_scale_ptr = a_scale + offs_am[:, None] * stride_scale + offs_scale_k[None, :]
-    b_scale_ptr = b_scale + offs_bn[:, None] * stride_scale + offs_scale_k[None, :]
+    if a_scale is not None:
+        a_scale_ptr = a_scale + offs_am[:, None] * stride_scale + offs_scale_k[None, :]
+    if b_scale is not None:
+        b_scale_ptr = b_scale + offs_bn[:, None] * stride_scale + offs_scale_k[None, :]
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty)
@@ -692,13 +694,21 @@ def block_scale_fp4_matmul( #
         valid_k = offs_k < k_remaining
         a = tl.load(a_ptrs, mask=valid_k[None, :], other=0)
         b = tl.load(b_ptrs, mask=valid_k[:, None], other=0)
-        scale_a = tl.load(a_scale_ptr)
-        scale_b = tl.load(b_scale_ptr)
+        if a_scale is not None:
+            scale_a = tl.load(a_scale_ptr)
+        else:
+            scale_a = None
+        if b_scale is not None:
+            scale_b = tl.load(b_scale_ptr)
+        else:
+            scale_b = None
         accumulator = tl.dot_scaled(a, scale_a, "e2m1", b, scale_b, "e2m1", accumulator)
         a_ptrs += (BLOCK_K // 2) * stride_ak
         b_ptrs += (BLOCK_K // 2) * stride_bk
-        a_scale_ptr += BLOCK_K // VEC_SIZE
-        b_scale_ptr += BLOCK_K // VEC_SIZE
+        if a_scale is not None:
+            a_scale_ptr += BLOCK_K // VEC_SIZE
+        if b_scale is not None:
+            b_scale_ptr += BLOCK_K // VEC_SIZE
     offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
@@ -709,12 +719,18 @@ def block_scale_fp4_matmul( #
 @pytest.mark.parametrize("M, N, K", [(1024, 512, 256), (2, 4, 64)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128),
                                                        (128, 256, 256), (128, 128, 64), (128, 64, 128)])
+@pytest.mark.parametrize("with_a_scale", [True, False])
+@pytest.mark.parametrize("with_b_scale", [True, False])
 @pytest.mark.parametrize(("scale_type", "VEC_SIZE"), [("float8_e8m0fnu", 32), ("float8_e4m3fn", 16)],
                          ids=["mxfp4", "nvfp4"])
 @pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else []))
-def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_type, nonKDim, device):
-    if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
-        pytest.skip("Requires compute capability >= 10")
+def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, with_a_scale, with_b_scale, scale_type, nonKDim,
+                         device):
+    if is_cuda():
+        if torch.cuda.get_device_capability()[0] < 10:
+            pytest.skip("Requires compute capability >= 10")
+        if not (with_a_scale and with_b_scale):
+            pytest.skip("None aScale/bScale is only tested on AMD backend for now")
     elif is_hip():
         if not is_hip_mi350():
             pytest.skip("Scaled fp4 matmul is only natively supported on MI350")
@@ -750,14 +766,21 @@ def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_typ

     a_scale_ref = a_scale_ref.to(torch.float32).repeat_interleave(VEC_SIZE, dim=1)[:M, :K]
     b_scale_ref = b_scale_ref.to(torch.float32).repeat_interleave(VEC_SIZE, dim=1).T.contiguous()[:K, :N]
+    stride_scale = a_scale.stride(0)
+    if not with_a_scale:
+        a_scale = None
+        a_scale_ref = 1.0
+    if not with_b_scale:
+        b_scale = None
+        b_scale_ref = 1.0
     ref_out = torch.matmul(a_mxfp4.to(torch.float32) * a_scale_ref, b_ref * b_scale_ref)

     output = a.new_empty((M, N), dtype=torch.float32)
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
     kernel_kwargs = {}
     if is_hip():
         kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
-    block_scale_fp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1),
+    block_scale_fp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
                                  b.stride(0), b.stride(1), output.stride(0), output.stride(1), VEC_SIZE, BLOCK_M,
                                  BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, **kernel_kwargs)

python/triton/language/semantic.py

Lines changed: 2 additions & 2 deletions
@@ -1649,8 +1649,8 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te
     allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"}
     assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}"
     assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}"
-    rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None
-    lhs_scale_is_none = isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None
+    rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None)
+    lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None)
     lhs = _bitcast_to_fp_type(lhs, lhs_format, builder)
     rhs = _bitcast_to_fp_type(rhs, rhs_format, builder)
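A small sketch (plain Python, hypothetical helper name) of why the check above needs both clauses: a missing scale can reach dot_scaled either as a raw Python None or as a tl.constexpr wrapping None, and the old isinstance-only check caught only the latter.

import triton.language as tl


def scale_is_none(scale):
    # Mirrors the updated check in semantic.dot_scaled.
    return scale is None or (isinstance(scale, tl.constexpr) and scale.value is None)


assert scale_is_none(None)                # raw None: missed by the old check
assert scale_is_none(tl.constexpr(None))  # constexpr-wrapped None: handled before too
assert not scale_is_none(tl.constexpr(3))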

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx950 matrix-instruction-size=0' | FileCheck %s --check-prefixes CHECK

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK{LITERAL}: #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 0], [32, 0]], block = []}>
// CHECK{LITERAL}: #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[32, 0], [0, 0]], block = []}>
// CHECK-LABEL: mfma_dot_scaled_mxfp4_mxfp4
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @mfma_dot_scaled_mxfp4_mxfp4(
      %arg0: tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
      %arg1: tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
      %arg2: tensor<128x4xi8>,
      %arg3: tensor<128x4xi8>,
      %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
  ) {
    // CHECK-NOT: arith.constant dense<127> : tensor<128x4xi8, #linear>
    // CHECK-NOT: arith.constant dense<127> : tensor<128x4xi8, #linear1>
    // CHECK-NOT: tt.fp_to_fp
    // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
    // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
    // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
    // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x4xi8, #linear>
    // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x4xi8, #linear1>
    // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e2m1 rhs = e2m1
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, tensor<128x4xi8> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<128x4xi8> -> tensor<128x128xf32, #blocked>
    tt.store %arg4, %1 : tensor<128x128x!tt.ptr<f32>, #blocked>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK-LABEL: mfma_dot_scaled_mxfp4_fp4
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @mfma_dot_scaled_mxfp4_fp4(
      %arg0: tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
      %arg1: tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
      %arg2: tensor<128x4xi8>,
      %arg3: tensor<128x128x!tt.ptr<f32>, #blocked>
  ) {
    // CHECK-NOT: tt.fp_to_fp
    // CHECK: %[[CST1:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear>
    // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x4xi8, #linear1>
    // CHECK: tt.dot_scaled {{.*}} scale %[[SCALE0]], {{.*}} scale %[[CST1]], {{.*}} lhs = e2m1 rhs = e2m1
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %1 = tt.dot_scaled %arg0 scale %arg2, %arg1, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, tensor<128x4xi8> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked>
    tt.store %arg3, %1 : tensor<128x128x!tt.ptr<f32>, #blocked>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK-LABEL: mfma_dot_scaled_fp4_mxfp4
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @mfma_dot_scaled_fp4_mxfp4(
      %arg0: tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
      %arg1: tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
      %arg2: tensor<128x4xi8>,
      %arg3: tensor<128x128x!tt.ptr<f32>, #blocked>
  ) {
    // CHECK-NOT: tt.fp_to_fp
    // CHECK: %[[CST0:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear>
    // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x4xi8, #linear1>
    // CHECK: tt.dot_scaled {{.*}} scale %[[CST0]], {{.*}} scale %[[SCALE1]], {{.*}} lhs = e2m1 rhs = e2m1
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %1 = tt.dot_scaled %arg0, %arg1 scale %arg2, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<128x4xi8> -> tensor<128x128xf32, #blocked>
    tt.store %arg3, %1 : tensor<128x128x!tt.ptr<f32>, #blocked>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK-LABEL: mfma_dot_scaled_fp4_fp4
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @mfma_dot_scaled_fp4_fp4(
      %arg0: tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
      %arg1: tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
      %arg2: tensor<128x128x!tt.ptr<f32>, #blocked>
  ) {
    // CHECK-NOT: tt.fp_to_fp
    // CHECK-DAG: %[[CST0:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear>
    // CHECK-DAG: %[[CST1:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear1>
    // CHECK: tt.dot_scaled {{.*}} scale %[[CST1]], {{.*}} scale %[[CST0]], {{.*}} lhs = e2m1 rhs = e2m1
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %1 = tt.dot_scaled %arg0, %arg1, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked>
    tt.store %arg2, %1 : tensor<128x128x!tt.ptr<f32>, #blocked>
    tt.return
  }
}

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 26 additions & 11 deletions
@@ -372,7 +372,8 @@ struct DotOpMFMAConversionHelper {
   /// rawElems is a vector of kWidth elements. We need to prepare vector(s) of
   /// kBase elements for each mfma instruction
   SmallVector<Value> extractOperands(Value rawElems, int kWidth, int kBase,
-                                     Type type, bool preserveBF16) const {
+                                     Type type, bool preserveBF16,
+                                     bool isConstantScale = false) const {
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     int kpack = kWidth / kBase;
     SmallVector<Value> results;
@@ -393,9 +394,20 @@ struct DotOpMFMAConversionHelper {
         }
       }
       if (type.getIntOrFloatBitWidth() == 8) {
-        if (1 == kBase)
+        if (1 == kBase) {
           // This is only for the scale operands of scaled mfma on MI350
-          results.push_back(b.zext(i32_ty, b.bitcast(vec, i8_ty)));
+          if (isConstantScale) {
+            // If the scale is constant(created by arith::ConstantOp), it will
+            // be put in a sgpr instead of vgpr. In that case, instead of
+            // vgpr[7:0], the instruction reads sgpr[30:23] as the scale value.
+            // So we need to manually left shift the scale by 23 bits to meet
+            // the requirement.
+            results.push_back(b.shl(
+                i32_ty, b.zext(i32_ty, b.bitcast(vec, i8_ty)), b.i32_val(23)));
+          } else {
+            results.push_back(b.zext(i32_ty, b.bitcast(vec, i8_ty)));
+          }
+        }
         if (4 == kBase)
           // This is for int8 on pre- MI300 GPUs
           results.push_back(b.bitcast(vec, i32_ty));
@@ -413,10 +425,9 @@

   /// Converts dot operand structure to value table and converts types
   /// appropriate for mfma instructions
-  virtual SmallVector<ValueTable>
-  getValuesFromDotOperandLayoutStruct(Value value, int batch, int n0, int n1,
-                                      int kWidth, int kBase, Type type,
-                                      bool allowXF32, bool preserveBF16) const {
+  virtual SmallVector<ValueTable> getValuesFromDotOperandLayoutStruct(
+      Value value, int batch, int n0, int n1, int kWidth, int kBase, Type type,
+      bool allowXF32, bool preserveBF16, bool isConstantScale = false) const {
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
     auto elems = unpackLLElements(loc, value, rewriter);
     int kpack = kWidth / kBase;
@@ -445,8 +456,8 @@
             vals = extractOperands(rawElems, kWidth, kBase, f32_ty,
                                    preserveBF16);
           } else if (type.getIntOrFloatBitWidth() == 8) {
-            vals =
-                extractOperands(rawElems, kWidth, kBase, i8_ty, preserveBF16);
+            vals = extractOperands(rawElems, kWidth, kBase, i8_ty,
+                                   preserveBF16, isConstantScale);
           } else if (type.isBF16()) {
             vals = extractOperands(rawElems, kWidth, kBase, bf16_ty,
                                    preserveBF16);
@@ -506,6 +517,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     Value b = op.getRhs();
     Value aScale = op.getLhsScale();
     Value bScale = op.getRhsScale();
+    bool isAScaleConstant = aScale.getDefiningOp<arith::ConstantOp>();
+    bool isBScaleConstant = bScale.getDefiningOp<arith::ConstantOp>();
     Value d = op.getD();
     auto aTensorTy = cast<RankedTensorType>(a.getType());
     auto bTensorTy = cast<RankedTensorType>(b.getType());
@@ -581,10 +594,12 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     // operands.
     auto operandAScale = getValuesFromDotOperandLayoutStruct(
         loadedAScale, numRepB, numRepM, numRepK, scaleKWidth, scaleKBase,
-        aScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
+        aScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false,
+        isAScaleConstant);
     auto operandBScale = getValuesFromDotOperandLayoutStruct(
         loadedBScale, numRepB, numRepN, numRepK, scaleKWidth, scaleKBase,
-        bScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
+        bScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false,
+        isBScaleConstant);

     auto dstElemTy = dTensorTy.getElementType();
     auto fc = unpackLLElements(loc, loadedC, rewriter);

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 30 additions & 16 deletions
@@ -724,10 +724,6 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
     TensorValue bScale = dotOp.getRhsScale();
     auto oldShape = oldRetType.getShape();

-    if (!aScale || !bScale)
-      return rewriter.notifyMatchFailure(dotOp,
-                                         "expect scales for both A and B");
-
     ScaleDotElemType aElemType = dotOp.getLhsType();
     ScaleDotElemType bElemType = dotOp.getRhsType();
     auto supportsTypes = [](ScaleDotElemType elemType) {
@@ -872,14 +868,25 @@ class ScaledBlockedToScaledMFMAF8F6F4 final

     auto convertScaleLayout = [&](TensorValue val, TensorValue scale,
                                   DotOperandEncodingAttr enc,
-                                  int idx) -> TensorValue {
-      auto dotLL = enc.toLinearLayout(val.getType().getShape());
+                                  int idx) -> Value {
+      auto valShape = val.getType().getShape();
+
+      auto dotLL = enc.toLinearLayout(valShape);
       LinearLayout::BasesT scaleBases = dotLL.getBases();
       auto &warpBases = scaleBases[kWarp];

       LinearLayout newLL = createLinearLayout(idx, warpBases);

-      auto shape = scale.getType().getShape();
+      SmallVector<int64_t> shape;
+      if (!scale) {
+        int64_t nonKDim = idx == 0 ? valShape[0] : valShape[1];
+        int64_t k = idx == 0 ? valShape[1] : valShape[0];
+        ScaleDotElemType &elemType = idx == 0 ? aElemType : bElemType;
+        int packSize = elemType == ScaleDotElemType::E2M1 ? 2 : 1;
+        shape = {nonKDim, k * packSize / 32};
+      } else {
+        shape = llvm::to_vector(scale.getType().getShape());
+      }

       // Adjust register-level layout to fill the shape, at this level, both
       // aScale and bScale should align with A operand.
@@ -891,18 +898,25 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
       }
       newLL = newLL.transposeOuts(standardOutDims);
       Attribute newScaleEncoding = ttg::LinearEncodingAttr::get(ctx, newLL);
-
-      auto newScaleType = RankedTensorType::get(
-          shape, scale.getType().getElementType(), newScaleEncoding);
-      return rewriter.create<ttg::ConvertLayoutOp>(scale.getLoc(), newScaleType,
-                                                   scale);
+      // Scale's data type is always i8
+      auto newScaleType = RankedTensorType::get(shape, i8_ty, newScaleEncoding);
+
+      if (!scale) {
+        // 0x7F is 1.0 in E8M0
+        return rewriter.create<arith::ConstantOp>(
+            dotOp->getLoc(), newScaleType,
+            DenseElementsAttr::get(newScaleType, llvm::APInt(8, 0x7F)));
+      } else {
+        return rewriter.create<ttg::ConvertLayoutOp>(scale.getLoc(),
+                                                     newScaleType, scale);
+      }
     };
-    aScale = convertScaleLayout(a, aScale, newAEncoding, 0);
-    bScale = convertScaleLayout(b, bScale, newBEncoding, 1);
+    auto newAScale = convertScaleLayout(a, aScale, newAEncoding, 0);
+    auto newBScale = convertScaleLayout(b, bScale, newBEncoding, 1);

     auto newDot = rewriter.create<triton::DotScaledOp>(
-        dotOp.getLoc(), newRetType, a, b, newAcc, aScale, bScale, aElemType,
-        bElemType, dotOp.getFastMath());
+        dotOp.getLoc(), newRetType, a, b, newAcc, newAScale, newBScale,
+        aElemType, bElemType, dotOp.getFastMath());

     rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
                                                       newDot);