@@ -138,18 +138,21 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
138138
139139 // FIXME [Dot LL]
140140 // Do for all DotOperandEncodingAttr once we have LLs for all of them
141- static bool isSupportedDotOpLayout (RankedTensorType type) {
142- auto layout = type.getEncoding ();
143- auto bitwidth = type.getElementType ().getIntOrFloatBitWidth ();
144- if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
141+ static bool isSupportedDotOpLayout (RankedTensorType srcTy,
142+ RankedTensorType dstTy) {
143+ auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding ());
144+ auto dstLayout = dstTy.getEncoding ();
145+ auto bitwidth = dstTy.getElementType ().getIntOrFloatBitWidth ();
146+ auto rank = dstTy.getRank ();
147+ if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
148+ auto vecWidth = 32 / bitwidth;
145149 auto kWidth = dot.getKWidth ();
146- // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
147- // - kWidth == 8
148- // - kWidth == 4, bitwidth = 32
150+ auto kOrder = dot.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
149151 if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent ())) {
150- bool legacyLoweringIsBuggy =
151- kWidth >= 8 || (kWidth == 4 && bitwidth == 32 );
152- return legacyLoweringIsBuggy && mma.isAmpere ();
152+ auto needTrans = kOrder != srcLayout.getOrder ()[0 ];
153+ auto canUseLdmatrix =
154+ (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
155+ return !canUseLdmatrix && mma.isAmpere ();
153156 }
154157 if (isa<AMDMfmaEncodingAttr>(dot.getParent ()))
155158 return true ;
@@ -164,10 +167,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
164167 RankedTensorType dstTy = op.getType ();
165168 Attribute srcLayout = srcTy.getEncoding ();
166169 Attribute dstLayout = dstTy.getEncoding ();
167- if (isa<SharedEncodingAttr>(srcLayout) &&
168- (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
170+ assert (isa<SharedEncodingAttr>(srcLayout) && " Unexpected src layout " );
171+ if ( (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
169172 dstLayout) ||
170- isSupportedDotOpLayout (dstTy))) {
173+ isSupportedDotOpLayout (srcTy, dstTy))) {
171174 return lowerSharedToDistributed (op, adaptor, getTypeConverter (),
172175 rewriter);
173176 }
0 commit comments