@@ -143,44 +143,6 @@ class FuseTransMMAV3Plus : public OpRewritePattern<LocalAllocOp> {
   }
 };
 
-static Attribute inferSrcEncodingMemDescReshape(ArrayRef<int64_t> srcShape,
-                                                MemDescType dstType) {
-  auto dstEncoding = dstType.getEncoding();
-  auto dstShape = dstType.getShape();
-  auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(dstEncoding);
-  if (!mmaEncoding)
-    return {};
-  // TODO: supporting reshape of CTA layouts is non-trivial.
-  if (getNumCTAs(mmaEncoding) > 1)
-    return {};
-  int innerDimDst =
-      mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
-  int innerDimSrc =
-      mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
-  // For now disallow reshape of the inner dimension.
-  if (innerDimDst != innerDimSrc)
-    return {};
-
-  // CTALayout can be all 1's because we bailed on multi-CTA layouts above.
-  auto CTALayout = CTALayoutAttr::get(
-      dstEncoding.getContext(),
-      /*CTAsPerCGA=*/SmallVector<unsigned>(srcShape.size(), 1),
-      /*CTASplitNum=*/SmallVector<unsigned>(srcShape.size(), 1),
-      /*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(srcShape.size())));
-  auto srcEncoding = NVMMASharedEncodingAttr::get(
-      dstEncoding.getContext(), mmaEncoding.getSwizzlingByteWidth(),
-      mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(),
-      mmaEncoding.getFp4Padded(), CTALayout);
-  // Big guns, check linear layouts are equivalent
-  auto srcLL = toLinearLayout(srcShape, srcEncoding);
-  auto dstLL = toLinearLayout(dstShape, dstEncoding);
-  auto ctx = dstEncoding.getContext();
-  if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
-    return {};
-  }
-  return srcEncoding;
-}
-
 // Rewrite
 //
 // alloc(reshape(), #shared1) ->
@@ -204,18 +166,21 @@ class ReshapeMemDesc : public OpRewritePattern<LocalAllocOp> {
     auto allocEncoding = allocType.getEncoding();
 
     RankedTensorType srcTy = reshapeOp.getSrc().getType();
-    auto newAllocEncoding =
-        inferSrcEncodingMemDescReshape(srcTy.getShape(), allocType);
-    if (!newAllocEncoding)
+    auto srcShape = srcTy.getShape();
+    auto dstShape = allocType.getShape();
+
+    // We use the fact that forward and backward inference are the same for
+    // MemDescReshapeOp to infer the source MemDescType that would produce
+    // `allocType` after a reshape.
+    MemDescType innerTy;
+    if (failed(MemDescReshapeOp::inferReturnTypes(
+            getContext(), allocOp.getLoc(), allocType, srcShape, innerTy)))
       return failure();
 
-    MemDescType innerTy =
-        MemDescType::get(srcTy.getShape(), srcTy.getElementType(),
-                         newAllocEncoding, allocType.getMemorySpace());
     auto newAlloc = rewriter.create<LocalAllocOp>(allocOp.getLoc(), innerTy,
                                                   reshapeOp.getSrc());
-    rewriter.replaceOpWithNewOp<MemDescReshapeOp>(allocOp, allocOp.getType(),
-                                                  newAlloc);
+    rewriter.replaceOpWithNewOp<MemDescReshapeOp>(allocOp, newAlloc,
+                                                  allocOp.getType().getShape());
     return success();
   }
 };
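
Note on the new inference call: the replacement code leans on the comment's observation that forward and backward inference coincide for MemDescReshapeOp, so the forward inferReturnTypes can be reused to answer the backward question of which source type reshapes into allocType. The standalone helper below is a hedged sketch of that idea, not part of the patch; its name is hypothetical and it assumes the same inferReturnTypes signature the patch calls.

// Hypothetical helper (not in the patch): recover the source MemDescType
// that would reshape into `allocType` when the source has `srcShape`.
// Because the inference is symmetric, running the forward inference from
// `allocType` towards `srcShape` yields that type.
static MemDescType inferReshapeSourceType(MLIRContext *ctx, Location loc,
                                          MemDescType allocType,
                                          ArrayRef<int64_t> srcShape) {
  MemDescType innerTy;
  if (failed(MemDescReshapeOp::inferReturnTypes(ctx, loc, allocType, srcShape,
                                                innerTy)))
    return {}; // Inference failed; the caller should bail out of the rewrite.
  return innerTy;
}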