@@ -1310,7 +1310,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
13101310 // The coordinates should be in shape of <? x rank>
13111311 unsigned expCOORank = stt.getLvlRank () - cooStartLvl;
13121312 if (cooTp.getRank () != 2 || expCOORank != cooTp.getShape ().back ()) {
1313- op->emitError (" input/output trailing COO level-ranks don't match" );
1313+ return op->emitError (" input/output trailing COO level-ranks don't match" );
13141314 }
13151315 }
13161316
@@ -1350,7 +1350,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
13501350}
13511351
13521352LogicalResult AssembleOp::verify () {
1353- const auto valuesTp = getRankedTensorType ( getValues ());
1353+ RankedTensorType valuesTp = getValues (). getType ( );
13541354 const auto lvlsTp = getLevels ().getTypes ();
13551355 const auto resTp = getSparseTensorType (getResult ());
13561356 return verifyPackUnPack (*this , true , resTp, valuesTp, lvlsTp);
@@ -1364,34 +1364,31 @@ LogicalResult DisassembleOp::verify() {
13641364 if (ot.getType () != rt.getType ())
13651365 return emitError (" output levels and return levels type mismatch" );
13661366
1367- const auto valuesTp = getRankedTensorType ( getRetValues ());
1367+ RankedTensorType valuesTp = getRetValues (). getType ( );
13681368 const auto lvlsTp = getRetLevels ().getTypes ();
13691369 const auto srcTp = getSparseTensorType (getTensor ());
13701370 return verifyPackUnPack (*this , false , srcTp, valuesTp, lvlsTp);
13711371}
13721372
13731373LogicalResult ConvertOp::verify () {
1374- if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource ().getType ())) {
1375- if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest ().getType ())) {
1376- if (tp1.getRank () != tp2.getRank ())
1377- return emitError (" unexpected conversion mismatch in rank" );
1378- auto dstEnc =
1379- llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding ());
1380- if (dstEnc && dstEnc.isSlice ())
1381- return emitError (" cannot convert to a sparse tensor slice" );
1382-
1383- auto shape1 = tp1.getShape ();
1384- auto shape2 = tp2.getShape ();
1385- // Accept size matches between the source and the destination type
1386- // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1387- // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1388- for (Dimension d = 0 , dimRank = tp1.getRank (); d < dimRank; d++)
1389- if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic )
1390- return emitError (" unexpected conversion mismatch in dimension " ) << d;
1391- return success ();
1392- }
1393- }
1394- return emitError (" unexpected type in convert" );
1374+ RankedTensorType tp1 = getSource ().getType ();
1375+ RankedTensorType tp2 = getDest ().getType ();
1376+ if (tp1.getRank () != tp2.getRank ())
1377+ return emitError (" unexpected conversion mismatch in rank" );
1378+ auto dstEnc =
1379+ llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding ());
1380+ if (dstEnc && dstEnc.isSlice ())
1381+ return emitError (" cannot convert to a sparse tensor slice" );
1382+
1383+ auto shape1 = tp1.getShape ();
1384+ auto shape2 = tp2.getShape ();
1385+ // Accept size matches between the source and the destination type
1386+ // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1387+ // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1388+ for (Dimension d = 0 , dimRank = tp1.getRank (); d < dimRank; d++)
1389+ if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic )
1390+ return emitError (" unexpected conversion mismatch in dimension " ) << d;
1391+ return success ();
13951392}
13961393
13971394OpFoldResult ConvertOp::fold (FoldAdaptor adaptor) {
@@ -1495,7 +1492,8 @@ LogicalResult LvlOp::verify() {
14951492 if (std::optional<uint64_t > lvl = getConstantLvlIndex ()) {
14961493 auto stt = getSparseTensorType (getSource ());
14971494 if (static_cast <uint64_t >(lvl.value ()) >= stt.getLvlRank ())
1498- emitError (" Level index exceeds the rank of the input sparse tensor" );
1495+ return emitError (
1496+ " Level index exceeds the rank of the input sparse tensor" );
14991497 }
15001498 return success ();
15011499}
@@ -1697,14 +1695,14 @@ LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
16971695}
16981696
16991697LogicalResult ToSliceOffsetOp::verify () {
1700- auto rank = getRankedTensorType ( getSlice ()).getRank ();
1698+ auto rank = getSlice (). getType ( ).getRank ();
17011699 if (rank <= getDim ().getSExtValue () || getDim ().getSExtValue () < 0 )
17021700 return emitError (" requested dimension out of bound" );
17031701 return success ();
17041702}
17051703
17061704LogicalResult ToSliceStrideOp::verify () {
1707- auto rank = getRankedTensorType ( getSlice ()).getRank ();
1705+ auto rank = getSlice (). getType ( ).getRank ();
17081706 if (rank <= getDim ().getSExtValue () || getDim ().getSExtValue () < 0 )
17091707 return emitError (" requested dimension out of bound" );
17101708 return success ();
@@ -1986,15 +1984,16 @@ LogicalResult ForeachOp::verify() {
19861984 const auto iTp = IndexType::get (getContext ());
19871985 for (Dimension d = 0 ; d < dimRank; d++)
19881986 if (args[d].getType () != iTp)
1989- emitError (
1987+ return emitError (
19901988 llvm::formatv (" Expecting Index type for argument at index {0}" , d));
19911989
19921990 const auto elemTp = t.getElementType ();
19931991 const auto valueTp = args[dimRank].getType ();
19941992 if (elemTp != valueTp)
1995- emitError (llvm::formatv (" Unmatched element type between input tensor and "
1996- " block argument, expected:{0}, got: {1}" ,
1997- elemTp, valueTp));
1993+ return emitError (
1994+ llvm::formatv (" Unmatched element type between input tensor and "
1995+ " block argument, expected:{0}, got: {1}" ,
1996+ elemTp, valueTp));
19981997 return success ();
19991998}
20001999
@@ -2011,15 +2010,15 @@ LogicalResult ReorderCOOOp::verify() {
20112010 SparseTensorType dstStt = getSparseTensorType (getResultCoo ());
20122011
20132012 if (!srcStt.isCOOType () || !dstStt.isCOOType ())
2014- emitError (" Expected COO sparse tensors only" );
2013+ return emitError (" Expected COO sparse tensors only" );
20152014
20162015 if (!srcStt.hasSameDimToLvl (dstStt))
2017- emitError (" Unmatched dim2lvl map between input and result COO" );
2016+ return emitError (" Unmatched dim2lvl map between input and result COO" );
20182017
20192018 if (srcStt.getPosType () != dstStt.getPosType () ||
20202019 srcStt.getCrdType () != dstStt.getCrdType () ||
20212020 srcStt.getElementType () != dstStt.getElementType ())
2022- emitError (" Unmatched storage format between input and result COO" );
2021+ return emitError (" Unmatched storage format between input and result COO" );
20232022
20242023 return success ();
20252024}
@@ -2044,10 +2043,11 @@ LogicalResult SortOp::verify() {
20442043 AffineMap xPerm = getPermMap ();
20452044 uint64_t nx = xPerm.getNumDims ();
20462045 if (nx < 1 )
2047- emitError (llvm::formatv (" Expected rank(perm_map) > 1, got {0}" , nx));
2046+ return emitError (llvm::formatv (" Expected rank(perm_map) > 1, got {0}" , nx));
20482047
20492048 if (!xPerm.isPermutation ())
2050- emitError (llvm::formatv (" Expected a permutation map, got {0}" , xPerm));
2049+ return emitError (
2050+ llvm::formatv (" Expected a permutation map, got {0}" , xPerm));
20512051
20522052 // We can't check the size of the buffers when n or buffer dimensions aren't
20532053 // compile-time constants.
@@ -2056,19 +2056,24 @@ LogicalResult SortOp::verify() {
20562056 return success ();
20572057
20582058 // Verify dimensions.
2059- const auto checkDim = [&](Value v, Size minSize, const char *message) {
2059+ const auto checkDim = [&](Value v, Size minSize,
2060+ const char *message) -> LogicalResult {
20602061 const Size sh = getMemRefType (v).getShape ()[0 ];
20612062 if (!ShapedType::isDynamic (sh) && sh < minSize)
2062- emitError (llvm::formatv (" {0} got {1} < {2}" , message, sh, minSize));
2063+ return emitError (
2064+ llvm::formatv (" {0} got {1} < {2}" , message, sh, minSize));
2065+ return success ();
20632066 };
20642067 uint64_t n = cn.value ();
20652068 uint64_t ny = 0 ;
20662069 if (auto nyAttr = getNyAttr ())
20672070 ny = nyAttr.getInt ();
2068- checkDim (getXy (), n * (nx + ny),
2069- " Expected dimension(xy) >= n * (rank(perm_map) + ny)" );
2071+ if (failed (checkDim (getXy (), n * (nx + ny),
2072+ " Expected dimension(xy) >= n * (rank(perm_map) + ny)" )))
2073+ return failure ();
20702074 for (Value opnd : getYs ())
2071- checkDim (opnd, n, " Expected dimension(y) >= n" );
2075+ if (failed (checkDim (opnd, n, " Expected dimension(y) >= n" )))
2076+ return failure ();
20722077
20732078 return success ();
20742079}
@@ -2101,8 +2106,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
21012106 }
21022107
21032108 if (lvlHi <= lvlLo)
2104- parser.emitError (parser.getNameLoc (),
2105- " expect larger level upper bound than lower bound" );
2109+ return parser.emitError (parser.getNameLoc (),
2110+ " expect larger level upper bound than lower bound" );
21062111
21072112 return success ();
21082113}
0 commit comments