Skip to content

Commit 3516eac

Browse files
authored
[LAYOUTS] Fix backwards propagation for Fp4ToFp (#8438)
We also tighten its verifier.
1 parent 855ca6c commit 3516eac

File tree

4 files changed

+73
-2
lines changed

4 files changed

+73
-2
lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,37 @@ LogicalResult Fp4ToFpOp::verifyFp4ToFp(mlir::Operation *op,
442442
<< ", dst=" << resShape[i] << ", axis=" << axis << ")";
443443
}
444444
}
445+
if (bool(resTy.getEncoding()) != bool(srcTy.getEncoding()))
446+
return op->emitError()
447+
<< "source and result must both have an encoding, or neither";
448+
if (!resTy.getEncoding()) {
449+
return success();
450+
}
451+
auto srcLl = toLinearLayout(srcTy);
452+
auto resLl = toLinearLayout(resTy);
453+
auto *ctx = srcTy.getContext();
454+
auto regDim = StringAttr::get(ctx, "register");
455+
auto outDims = standardOutDimNames(ctx, rank);
456+
457+
// We use backward inference here as it is striclty more general
458+
Attribute inferSrc;
459+
auto dialect =
460+
resTy.getEncoding()
461+
.getDialect()
462+
.getRegisteredInterface<triton::DialectInferLayoutInterface>();
463+
assert(dialect);
464+
if (failed(dialect->inferFp4ToFpOpEncoding(
465+
resTy.getShape(), axis, resTy.getEncoding(), inferSrc,
466+
/*fwdInference*/ false, std::nullopt))) {
467+
return op->emitError() << "failed to infer encoding";
468+
}
469+
if (!areLayoutsEquivalent(srcTy.getShape(),
470+
cast<LayoutEncodingTrait>(inferSrc),
471+
cast<LayoutEncodingTrait>(srcTy.getEncoding())))
472+
return op->emitError()
473+
<< "Src and Dst encodings are not compatible:\n"
474+
<< toLinearLayout(srcTy.getShape(), inferSrc).toString() << "\n"
475+
<< srcLl.toString();
445476
return success();
446477
}
447478

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ static Attribute inferDstEncoding(triton::gpu::Fp4ToFpOp op, Attribute srcEnc) {
431431

432432
static Attribute inferSrcEncoding(triton::gpu::Fp4ToFpOp op, Attribute dstEnc) {
433433
Attribute srcEnc;
434-
auto shape = op.getSrc().getType().getShape();
434+
auto shape = op.getType().getShape();
435435
if (succeeded(
436436
dstEnc.getDialect()
437437
.getRegisteredInterface<triton::DialectInferLayoutInterface>()

test/TritonGPU/combine.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
#layout2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
77
#layout3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
88

9+
#layout4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
10+
#layout5 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
11+
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[16, 0], [32, 0]], block = []}>
12+
913

1014
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
1115

@@ -78,6 +82,16 @@ tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
7882
tt.return
7983
}
8084

85+
// CHECK-LABEL: fp4_keep_convert
86+
tt.func @fp4_keep_convert() -> tensor<64x64xf16, #linear> {
87+
%0 = arith.constant dense<0> : tensor<64x32xi8, #layout4>
88+
%fp4 = ttg.fp4_to_fp %0 {axis = 1 : i32} : tensor<64x32xi8, #layout4> -> tensor<64x64xf16, #layout5>
89+
%converted = ttg.convert_layout %fp4 : tensor<64x64xf16, #layout5> -> tensor<64x64xf16, #linear>
90+
// CHECK: ttg.fp4_to_fp
91+
// CHECK-NOT: ttg.convert_layout
92+
tt.return %converted : tensor<64x64xf16, #linear>
93+
}
94+
8195
// Hoist the convert on top of ext to make it cheaper.
8296
// CHECK-LABEL: hoist_above_ext
8397
tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tensor<1024xf32, #layout1> {

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,33 @@ LogicalResult ScaledUpcastFp4Op::verify() {
432432
return emitError() << "scale and output should have the same shape";
433433

434434
// Reuse Fp4ToFpOp's verifier to check types of input and output
435-
return triton::gpu::Fp4ToFpOp::verifyFp4ToFp(*this, inputTy, outputTy, axis);
435+
auto rank = inputTy.getRank();
436+
437+
if (rank != outputTy.getRank())
438+
return emitError() << "source rank " << rank << " != result rank "
439+
<< outputTy.getRank();
440+
441+
auto srcShape = inputTy.getShape();
442+
auto resShape = outputTy.getShape();
443+
444+
if (!(0 <= axis && axis < rank))
445+
return emitError() << "axis " << axis << " out of range for rank " << rank;
446+
447+
for (int i = 0; i < rank; ++i) {
448+
if (i == axis) {
449+
if (resShape[i] != srcShape[i] * 2)
450+
return emitError() << "axis " << axis
451+
<< " dimension must be 2x source dimension (src="
452+
<< srcShape[i] << ", dst=" << resShape[i] << ")";
453+
} else {
454+
if (resShape[i] != srcShape[i])
455+
return emitError() << "dimension " << i
456+
<< " mismatch (src=" << srcShape[i]
457+
<< ", dst=" << resShape[i] << ", axis=" << axis
458+
<< ")";
459+
}
460+
}
461+
return success();
436462
}
437463

438464
Attribute ScaledUpcastFp4Op::inferDstEncoding(unsigned opIdx,

0 commit comments

Comments
 (0)