Skip to content

Commit bc18624

Browse files
committed
[mlir] vector.type_cast: disallow memrefs with layout in verifier
Summary: These are not supported by any of the code using `type_cast`. In the general case, such casting would require memrefs to handle a non-contiguous vector representation or misaligned vectors (e.g., if the offset of the source memref is not divisible by vector size, since offset in the target memref is expressed in the number of elements). Differential Revision: https://reviews.llvm.org/D76349
1 parent 4b0f1e1 commit bc18624

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,10 @@ static void print(OpAsmPrinter &p, TypeCastOp op) {
14831483
}
14841484

14851485
static LogicalResult verify(TypeCastOp op) {
1486+
MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
1487+
if (!canonicalType.getAffineMaps().empty())
1488+
return op.emitOpError("expects operand to be a memref with no layout");
1489+
14861490
auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
14871491
if (op.getResultMemRefType() != resultType)
14881492
return op.emitOpError("expects result type to be: ") << resultType;

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,3 +1046,10 @@ func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
10461046
// expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}}
10471047
%0 = vector.reduction "add", %arg0 : vector<4x16xf32> into f32
10481048
}
1049+
1050+
// -----
1051+
1052+
func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
1053+
// expected-error@+1 {{expects operand to be a memref with no layout}}
1054+
%0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
1055+
}

0 commit comments

Comments
 (0)