@@ -1305,16 +1305,29 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1305
1305
build (b, result, source, dest, offsetValues, sizeValues, strideValues);
1306
1306
}
1307
1307
1308
+ static SliceVerificationResult
1309
+ verifyInsertSliceOp (ShapedType srcType, ShapedType dstType,
1310
+ ArrayAttr staticOffsets, ArrayAttr staticSizes,
1311
+ ArrayAttr staticStrides,
1312
+ ShapedType *expectedType = nullptr ) {
1313
+ // insert_slice is the inverse of extract_slice, use the same type inference.
1314
+ auto expected = ExtractSliceOp::inferRankReducedResultType (
1315
+ srcType.getRank (), dstType.cast <RankedTensorType>(),
1316
+ extractFromI64ArrayAttr (staticOffsets),
1317
+ extractFromI64ArrayAttr (staticSizes),
1318
+ extractFromI64ArrayAttr (staticStrides))
1319
+ .cast <ShapedType>();
1320
+ if (expectedType)
1321
+ *expectedType = expected;
1322
+ return isRankReducedType (expected, srcType);
1323
+ }
1324
+
1308
1325
// / Verifier for InsertSliceOp.
1309
1326
LogicalResult InsertSliceOp::verify () {
1310
- // insert_slice is the inverse of extract_slice, use the same type inference.
1311
- auto expectedType = ExtractSliceOp::inferRankReducedResultType (
1312
- getSourceType ().getRank (), getType (),
1313
- extractFromI64ArrayAttr (static_offsets ()),
1314
- extractFromI64ArrayAttr (static_sizes ()),
1315
- extractFromI64ArrayAttr (static_strides ()));
1327
+ ShapedType expectedType;
1316
1328
auto result =
1317
- isRankReducedType (expectedType.cast <ShapedType>(), getSourceType ());
1329
+ verifyInsertSliceOp (getSourceType (), getType (), static_offsets (),
1330
+ static_sizes (), static_strides (), &expectedType);
1318
1331
return produceSliceErrorMsg (result, *this , expectedType);
1319
1332
}
1320
1333
@@ -1446,12 +1459,20 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
1446
1459
if (!sourceCastSource && !destCastSource)
1447
1460
return failure ();
1448
1461
1462
+ auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source ());
1463
+ auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest ());
1464
+
1465
+ auto srcType = src.getType ().cast <ShapedType>();
1466
+ auto dstType = dst.getType ().cast <ShapedType>();
1467
+ if (verifyInsertSliceOp (srcType, dstType, insertSliceOp.static_offsets (),
1468
+ insertSliceOp.static_sizes (),
1469
+ insertSliceOp.static_strides ()) !=
1470
+ SliceVerificationResult::Success)
1471
+ return failure ();
1472
+
1449
1473
Value replacement = rewriter.create <InsertSliceOp>(
1450
- insertSliceOp.getLoc (),
1451
- (sourceCastSource ? *sourceCastSource : insertSliceOp.source ()),
1452
- (destCastSource ? *destCastSource : insertSliceOp.dest ()),
1453
- insertSliceOp.getMixedOffsets (), insertSliceOp.getMixedSizes (),
1454
- insertSliceOp.getMixedStrides ());
1474
+ insertSliceOp.getLoc (), src, dst, insertSliceOp.getMixedOffsets (),
1475
+ insertSliceOp.getMixedSizes (), insertSliceOp.getMixedStrides ());
1455
1476
1456
1477
if (replacement.getType () != insertSliceOp.getType ()) {
1457
1478
replacement = rewriter.create <tensor::CastOp>(
0 commit comments