@@ -516,6 +516,12 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
         distributedEncodings.push_back(blockedEncoding);
         distributedEncodings.push_back(
             triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding));
+        // Create an opIdx=0 and opIdx=1 encoding
+        for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
+          distributedEncodings.push_back(
+              triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx,
+                                                       blockedEncoding, 0));
+        }
       }
     }
   }
@@ -538,6 +544,12 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
     }
   }
 
+  auto is_dot_op_with_block_parent = [](Attribute layout) {
+    auto dot_layout = dyn_cast<triton::gpu::DotOperandEncodingAttr>(layout);
+    return dot_layout &&
+           isa<triton::gpu::BlockedEncodingAttr>(dot_layout.getParent());
+  };
+
   for (const auto &distributedEncoding : distributedEncodings) {
     for (auto shape : shapes) {
       if (auto sliceEncoding =
@@ -558,29 +570,37 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
       // Test that methods of DistributedEncoding return the same values
       Type eltTy = Float32Type::get(&ctx);
 
-      ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
+      }
       ASSERT_EQ(distributedEncoding.getTotalElemsPerThread(shape),
                 linearEncoding.getTotalElemsPerThread(shape));
       ASSERT_EQ(distributedEncoding.getElemsPerThread(shape),
                 linearEncoding.getElemsPerThread(shape));
-      ASSERT_EQ(distributedEncoding.getRepOrder(),
-                linearEncoding.getRepOrder());
-      ASSERT_EQ(distributedEncoding.getContigPerThread(),
-                linearEncoding.getContigPerThread());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getRepOrder(),
+                  linearEncoding.getRepOrder());
+        ASSERT_EQ(distributedEncoding.getContigPerThread(),
+                  linearEncoding.getContigPerThread());
+      }
       // DotOperandEncodingAttr::getWarpOrder() is not defined
       if (!isa<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) {
         ASSERT_EQ(distributedEncoding.getWarpOrder(),
                   linearEncoding.getWarpOrder());
       }
-      ASSERT_EQ(distributedEncoding.getThreadOrder(),
-                linearEncoding.getThreadOrder());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getThreadOrder(),
+                  linearEncoding.getThreadOrder());
+      }
       // For slice these do not equal the total number of lines / warps
       // See [Note. Divergence of methods wrt. legacy layouts]
       if (!isa<triton::gpu::SliceEncodingAttr>(distributedEncoding)) {
         ASSERT_EQ(distributedEncoding.getWarpsPerCTA(),
                   linearEncoding.getWarpsPerCTA());
-        ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
-                  linearEncoding.getThreadsPerWarp());
+        if (!is_dot_op_with_block_parent(distributedEncoding)) {
+          ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
+                    linearEncoding.getThreadsPerWarp());
+        }
       }
       // Canonicalisation for opIdx=0 takes just a [2 x 2] subtile as it takes
       // the second repetition along K as the second tile.
@@ -602,7 +622,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
         // If we are not using CGAs, the order is meaningless
         auto useCGA =
             baseEncoding.getCTAsPerCGA() != SmallVector<unsigned>(rank, 1);
-        if (useCGA) {
+        if (useCGA && !is_dot_op_with_block_parent(distributedEncoding)) {
           ASSERT_EQ(baseEncoding.getCTAOrder(), linearEncoding.getCTAOrder());
         }
       }
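
For context (not part of the diff): a minimal usage sketch of the guard introduced above. It assumes the same test scope as this file — `&ctx`, a `blockedEncoding`, gtest macros, and the `is_dot_op_with_block_parent` lambda added at new line 547 — and only reuses constructs that already appear in the diff; the trailing `0` argument is copied verbatim from the hunk rather than named.

// The guard matches only dot-operand layouts whose parent is a blocked
// encoding; the blocked encoding itself still runs every assertion.
auto dotOperand = triton::gpu::DotOperandEncodingAttr::get(
    &ctx, /*opIdx=*/0, blockedEncoding, 0);
ASSERT_TRUE(is_dot_op_with_block_parent(dotOperand));
ASSERT_FALSE(is_dot_op_with_block_parent(blockedEncoding));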