Skip to content

Commit cd0d095

Browse files
committed
[mlir][tensor] Check ops generated by InsertSliceOpCastFolder are valid
Fixes llvm/llvm-project#53099 Differential Revision: https://reviews.llvm.org/D119663
1 parent a9029a3 commit cd0d095

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,16 +1305,29 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
13051305
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
13061306
}
13071307

1308+
static SliceVerificationResult
1309+
verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
1310+
ArrayAttr staticOffsets, ArrayAttr staticSizes,
1311+
ArrayAttr staticStrides,
1312+
ShapedType *expectedType = nullptr) {
1313+
// insert_slice is the inverse of extract_slice, use the same type inference.
1314+
auto expected = ExtractSliceOp::inferRankReducedResultType(
1315+
srcType.getRank(), dstType.cast<RankedTensorType>(),
1316+
extractFromI64ArrayAttr(staticOffsets),
1317+
extractFromI64ArrayAttr(staticSizes),
1318+
extractFromI64ArrayAttr(staticStrides))
1319+
.cast<ShapedType>();
1320+
if (expectedType)
1321+
*expectedType = expected;
1322+
return isRankReducedType(expected, srcType);
1323+
}
1324+
13081325
/// Verifier for InsertSliceOp.
13091326
LogicalResult InsertSliceOp::verify() {
1310-
// insert_slice is the inverse of extract_slice, use the same type inference.
1311-
auto expectedType = ExtractSliceOp::inferRankReducedResultType(
1312-
getSourceType().getRank(), getType(),
1313-
extractFromI64ArrayAttr(static_offsets()),
1314-
extractFromI64ArrayAttr(static_sizes()),
1315-
extractFromI64ArrayAttr(static_strides()));
1327+
ShapedType expectedType;
13161328
auto result =
1317-
isRankReducedType(expectedType.cast<ShapedType>(), getSourceType());
1329+
verifyInsertSliceOp(getSourceType(), getType(), static_offsets(),
1330+
static_sizes(), static_strides(), &expectedType);
13181331
return produceSliceErrorMsg(result, *this, expectedType);
13191332
}
13201333

@@ -1446,12 +1459,20 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
14461459
if (!sourceCastSource && !destCastSource)
14471460
return failure();
14481461

1462+
auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source());
1463+
auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest());
1464+
1465+
auto srcType = src.getType().cast<ShapedType>();
1466+
auto dstType = dst.getType().cast<ShapedType>();
1467+
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.static_offsets(),
1468+
insertSliceOp.static_sizes(),
1469+
insertSliceOp.static_strides()) !=
1470+
SliceVerificationResult::Success)
1471+
return failure();
1472+
14491473
Value replacement = rewriter.create<InsertSliceOp>(
1450-
insertSliceOp.getLoc(),
1451-
(sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
1452-
(destCastSource ? *destCastSource : insertSliceOp.dest()),
1453-
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1454-
insertSliceOp.getMixedStrides());
1474+
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
1475+
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
14551476

14561477
if (replacement.getType() != insertSliceOp.getType()) {
14571478
replacement = rewriter.create<tensor::CastOp>(

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,3 +1231,18 @@ func @splat_fold() -> tensor<4xf32> {
12311231
// CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
12321232
// CHECK-NEXT: return [[T]] : tensor<4xf32>
12331233
}
1234+
1235+
// -----
1236+
1237+
// There was an issue in cast + insert_slice folding generating invalid ir.
1238+
// https://github.com/llvm/llvm-project/issues/53099
1239+
// CHECK-LABEL: func @insert_slice_cast
1240+
func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
1241+
// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
1242+
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
1243+
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
1244+
// CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
1245+
%1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
1246+
// CHECK: return %[[RES]] : tensor<?x?xf32>
1247+
return %1 : tensor<?x?xf32>
1248+
}

0 commit comments

Comments
 (0)