@@ -155,27 +155,6 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
155155 // iterator_types is an auto-generated method.
156156}
157157
158- // / Helper to create a typical indexing map for MatmulOp. Returns a list of
159- // / AffineMap.
160- static SmallVector<AffineMap, 3 >
161- getDefaultIndexingMapsForMatmul (MLIRContext *context) {
162- AffineExpr d0, d1, d2;
163- SmallVector<AffineMap, 3 > indexingMaps;
164- bindDims (context, d0, d1, d2);
165- indexingMaps.push_back (AffineMap::get (3 , 0 , {d0, d2}, context));
166- indexingMaps.push_back (AffineMap::get (3 , 0 , {d2, d1}, context));
167- indexingMaps.push_back (AffineMap::get (3 , 0 , {d0, d1}, context));
168- return indexingMaps;
169- }
170-
171- // / Wrapper to return the typical indexing map array attribute for MatmulOp.
172- static SmallVector<Attribute>
173- getDefaultMatmulIndexingMapAttr (MLIRContext *context) {
174- return llvm::map_to_vector (
175- getDefaultIndexingMapsForMatmul (context),
176- [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
177- }
178-
179158// / Creates a structured operation given `inputs`, `outputs`, and `attributes`.
180159// / The result types are derived automatically if `resultTensorTypes` is none.
181160// / The body of the operation is filled using `regionBuilder`. All ods-gen
@@ -208,24 +187,18 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
208187 state.attributes .getAttrs (), regionBuilder);
209188}
210189
211- static void
212- buildMatmulOp (OpBuilder &b, OperationState &state ,
213- std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
214- ValueRange outputs, ArrayRef<NamedAttribute> attributes,
215- RegionBuilderFn regionBuilder,
216- std::optional< ArrayRef<AffineMap>> indexingMaps = std:: nullopt ) {
217- // Initialize indexingMaps, for MatmulOp.
190+ static void buildMatmulOp (OpBuilder &b, OperationState &state,
191+ std::optional<TypeRange> resultTensorTypes ,
192+ ValueRange inputs, ValueRange outputs ,
193+ ArrayRef<NamedAttribute> attributes,
194+ RegionBuilderFn regionBuilder,
195+ ArrayRef<AffineMap> indexingMaps) {
196+ // Initialize indexingMaps attribute , for MatmulOp.
218197 SmallVector<Attribute, 3 > indexingMapsAttrVal;
219- if (indexingMaps.has_value ()) {
220- for (mlir::AffineMap map : *indexingMaps) {
221- // Convert each AffineMap to an AffineMapAttr
222- indexingMapsAttrVal.push_back (AffineMapAttr::get (map));
223- }
224- state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
225- } else {
226- indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr (b.getContext ());
227- state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
228- }
198+ indexingMapsAttrVal = llvm::map_to_vector (
199+ MatmulOp::getDefaultIndexingMaps (b.getContext ()),
200+ [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
201+ state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
229202 return buildStructuredOp (b, state, resultTensorTypes, inputs, outputs,
230203 attributes, regionBuilder);
231204}
@@ -3457,7 +3430,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34573430 unsigned opIndex) {
34583431 SmallVector<AffineMap, 3 > opIndexingMaps = matmulOp.getIndexingMapsArray ();
34593432 SmallVector<AffineMap, 3 > defaultIndexingMaps =
3460- matmulOp.getDefaultIndexingMaps ();
3433+ matmulOp.getDefaultIndexingMaps (matmulOp-> getContext () );
34613434
34623435 auto opIndexingMap = opIndexingMaps[opIndex];
34633436 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3484,6 +3457,17 @@ namespace linalg {
34843457// MatMulOp
34853458// ===----------------------------------------------------------------------===//
34863459
3460+ // / Returns a list of AffineMap with the typical matmul indexing charactristic.
3461+ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps (MLIRContext *context) {
3462+ AffineExpr d0, d1, d2;
3463+ SmallVector<AffineMap, 3 > indexingMaps;
3464+ bindDims (context, d0, d1, d2);
3465+ indexingMaps.push_back (AffineMap::get (3 , 0 , {d0, d2}, context));
3466+ indexingMaps.push_back (AffineMap::get (3 , 0 , {d2, d1}, context));
3467+ indexingMaps.push_back (AffineMap::get (3 , 0 , {d0, d1}, context));
3468+ return indexingMaps;
3469+ }
3470+
34873471SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray () {
34883472 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
34893473 utils::IteratorType::parallel,
@@ -3501,7 +3485,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
35013485// / Check if the op has broadcast and/or transpose semantic. Returns true if
35023486// / the user defined indexing maps are not equal to default map.
35033487bool MatmulOp::hasUserDefinedMaps () {
3504- SmallVector<AffineMap, 3 > defaultMaps = getDefaultIndexingMaps ();
3488+ SmallVector<AffineMap, 3 > defaultMaps =
3489+ getDefaultIndexingMaps (this ->getContext ());
35053490 SmallVector<AffineMap, 3 > explicitMaps = getIndexingMapsArray ();
35063491 return defaultMaps != explicitMaps;
35073492}
@@ -3535,13 +3520,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
35353520 helper.yieldOutputs (yields);
35363521}
35373522
3538- // / Returns a list of AffineMap with the typical matmul indexing
3539- // / charactristic.
3540- SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps () {
3541- MLIRContext *context = this ->getContext ();
3542- return getDefaultIndexingMapsForMatmul (context);
3543- }
3544-
35453523// / Returns true if the given broadcast map \p bcastMap is valid for this op.
35463524bool MatmulOp::isValidLhsRhsBroadcastMap (AffineMap bcastMap) {
35473525 assert (bcastMap.getNumResults () == 1 && " Expected single result dim expr." );
@@ -3578,7 +3556,9 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
35783556 }
35793557 // Initialize indexingMaps, if not supplied explicitly.
35803558 if (indexingMapsAttr.empty ()) {
3581- indexingMapsAttr = getDefaultMatmulIndexingMapAttr (result.getContext ());
3559+ indexingMapsAttr = llvm::map_to_vector (
3560+ MatmulOp::getDefaultIndexingMaps (parser.getContext ()),
3561+ [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
35823562 }
35833563 result.addAttribute (" indexing_maps" ,
35843564 parser.getBuilder ().getArrayAttr (indexingMapsAttr));
@@ -3592,8 +3572,9 @@ void MatmulOp::print(OpAsmPrinter &p) {
35923572 printNamedStructuredOp (p, getOperation (), getInputs (), getOutputs (),
35933573 elidedAttrs);
35943574
3595- SmallVector<Attribute, 3 > indexingMaps =
3596- getDefaultMatmulIndexingMapAttr (getContext ());
3575+ SmallVector<Attribute, 3 > indexingMaps = llvm::map_to_vector (
3576+ MatmulOp::getDefaultIndexingMaps (getContext ()),
3577+ [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
35973578 if (!llvm::equal (getIndexingMaps (), indexingMaps)) {
35983579 p << " indexing_maps = [" ;
35993580 llvm::interleaveComma (getIndexingMaps (), p,
0 commit comments