Commit 6afc767

Implement DotOperandEncodingAttr::getSizePerThread with block layout parent (triton-lang#5863)
For the XPU backend, the common-code logic is slightly changed, and some Triton lit tests run into an unimplemented function.

Signed-off-by: Anatoly Myachev <[email protected]>

1 parent: cdf49bf
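In effect, querying the per-thread tile size of a dot-operand layout whose parent is a blocked layout now forwards to the parent instead of aborting with a fatal error. A minimal sketch of the new behavior, assuming an MLIRContext with the TritonGPU dialect loaded (the layout parameters below are illustrative, not taken from this commit):

#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

void sketchGetSizePerThread(MLIRContext &ctx) {
  // An ordinary blocked layout; each thread owns a 1x4 tile.
  auto ctaLayout = triton::gpu::CTALayoutAttr::get(
      &ctx, /*CTAsPerCGA=*/{1, 1}, /*CTASplitNum=*/{1, 1},
      /*CTAOrder=*/{1, 0});
  auto blocked = triton::gpu::BlockedEncodingAttr::get(
      &ctx, /*sizePerThread=*/{1, 4}, /*threadsPerWarp=*/{8, 4},
      /*warpsPerCTA=*/{4, 1}, /*order=*/{1, 0}, ctaLayout);

  // Wrap it as the A operand (opIdx=0) of a dot, as the new unit test does.
  auto dotOp = triton::gpu::DotOperandEncodingAttr::get(&ctx, /*opIdx=*/0,
                                                        blocked, /*kWidth=*/0);

  // Before this commit: report_fatal_error. After: the parent's {1, 4}.
  SmallVector<unsigned> sizePerThread = dotOp.getSizePerThread();
  (void)sizePerThread;
}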

File tree: 2 files changed (+33, -12 lines)


lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 3 additions & 2 deletions
@@ -2213,10 +2213,11 @@ SmallVector<unsigned> DotOperandEncodingAttr::getSizePerThread() const {
   assert(parentLayout && "DotOperandEncodingAttr must have a parent");
   if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
     return parentMmaLayout.getSizePerThreadForOperand(getKWidth(), getOpIdx());
+  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(parentLayout)) {
+    return blocked.getSizePerThread();
   } else {
     llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
-        "supported yet");
+        "getSizePerThread not implemented for DotOperandEncodingAttr");
     return {};
   }
 }

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 30 additions & 10 deletions
@@ -516,6 +516,12 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
       distributedEncodings.push_back(blockedEncoding);
       distributedEncodings.push_back(
           triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding));
+      // Create an opIdx=0 and opIdx=1 encoding
+      for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
+        distributedEncodings.push_back(
+            triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx,
+                                                     blockedEncoding, 0));
+      }
     }
   }
 }
@@ -538,6 +544,12 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
     }
   }

+  auto is_dot_op_with_block_parent = [](Attribute layout) {
+    auto dot_layout = dyn_cast<triton::gpu::DotOperandEncodingAttr>(layout);
+    return dot_layout &&
+           isa<triton::gpu::BlockedEncodingAttr>(dot_layout.getParent());
+  };
+
   for (const auto &distributedEncoding : distributedEncodings) {
     for (auto shape : shapes) {
       if (auto sliceEncoding =
@@ -558,29 +570,37 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
       // Test that methods of DistributedEncoding return the same values
       Type eltTy = Float32Type::get(&ctx);

-      ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
+      }
       ASSERT_EQ(distributedEncoding.getTotalElemsPerThread(shape),
                 linearEncoding.getTotalElemsPerThread(shape));
       ASSERT_EQ(distributedEncoding.getElemsPerThread(shape),
                 linearEncoding.getElemsPerThread(shape));
-      ASSERT_EQ(distributedEncoding.getRepOrder(),
-                linearEncoding.getRepOrder());
-      ASSERT_EQ(distributedEncoding.getContigPerThread(),
-                linearEncoding.getContigPerThread());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getRepOrder(),
+                  linearEncoding.getRepOrder());
+        ASSERT_EQ(distributedEncoding.getContigPerThread(),
+                  linearEncoding.getContigPerThread());
+      }
       // DotOperandEncodingAttr::getWarpOrder() is not defined
       if (!isa<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) {
         ASSERT_EQ(distributedEncoding.getWarpOrder(),
                   linearEncoding.getWarpOrder());
       }
-      ASSERT_EQ(distributedEncoding.getThreadOrder(),
-                linearEncoding.getThreadOrder());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getThreadOrder(),
+                  linearEncoding.getThreadOrder());
+      }
       // For slice these do not equal the total number of lines / warps
       // See [Note. Divergence of methods wrt. legacy layouts]
       if (!isa<triton::gpu::SliceEncodingAttr>(distributedEncoding)) {
         ASSERT_EQ(distributedEncoding.getWarpsPerCTA(),
                   linearEncoding.getWarpsPerCTA());
-        ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
-                  linearEncoding.getThreadsPerWarp());
+        if (!is_dot_op_with_block_parent(distributedEncoding)) {
+          ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
+                    linearEncoding.getThreadsPerWarp());
+        }
       }
       // Canonicalisation for opIdx=0 takes just a [2 x 2] subtile as it takes
       // the second repetition along K as the second tile.
@@ -602,7 +622,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
         // If we are not using CGAs, the order is meaningless
         auto useCGA =
             baseEncoding.getCTAsPerCGA() != SmallVector<unsigned>(rank, 1);
-        if (useCGA) {
+        if (useCGA && !is_dot_op_with_block_parent(distributedEncoding)) {
           ASSERT_EQ(baseEncoding.getCTAOrder(), linearEncoding.getCTAOrder());
         }
       }
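As a usage note, the new path can also be checked directly; a minimal GoogleTest-style assertion, reusing the hypothetical blocked / dotOp attributes from the sketch above:

// With a blocked parent, the dot-operand layout now simply inherits the
// parent's per-thread tile shape instead of hitting report_fatal_error.
EXPECT_EQ(dotOp.getSizePerThread(), blocked.getSizePerThread());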
