@@ -912,8 +912,8 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
912912 return rewriter.create <vector::InterleaveOp>(loc, low, high);
913913}
914914
915- /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
916- /// ops that take advantage of high-level information to avoid leaving LLVM to
915+ /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
916+ /// that take advantage of high-level information to avoid leaving LLVM to
917917/// scramble with peephole optimizations.
918918static Value rewriteI8ToI4Trunc (PatternRewriter &rewriter, Location loc,
919919 Value srcValue) {
@@ -922,22 +922,39 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
922922 " Expected i8 type" );
923923
924924 // 1. De-interleave low and high i8 elements.
925- auto deinterleaveOp = rewriter.create <vector::DeinterleaveOp>(loc, srcValue);
925+ int64_t vecDimSize = srcVecType.getShape ().back ();
926+ SmallVector<int64_t > deinterleaveLowMaskValues;
927+ SmallVector<int64_t > deinterleaveHighMaskValues;
928+ assert ((vecDimSize % 2 ) == 0 && " Odd number of i4 elements" );
929+ deinterleaveLowMaskValues.reserve (vecDimSize / 2 );
930+ deinterleaveHighMaskValues.reserve (vecDimSize / 2 );
931+ for (int i = 0 , end = vecDimSize; i < end; i += 2 ) {
932+ deinterleaveLowMaskValues.push_back (i);
933+ deinterleaveHighMaskValues.push_back (i + 1 );
934+ }
935+
936+ auto lowShuffleOp = rewriter.create <vector::ShuffleOp>(
937+ loc, srcValue, srcValue,
938+ rewriter.getI64ArrayAttr (deinterleaveLowMaskValues));
939+ auto highShuffleOp = rewriter.create <vector::ShuffleOp>(
940+ loc, srcValue, srcValue,
941+ rewriter.getI64ArrayAttr (deinterleaveHighMaskValues));
926942
927943 // 2. Zero out the upper side of each low i8 element.
928944 constexpr int8_t i8LowBitMask = 0x0F ;
929- VectorType deinterI8VecType = deinterleaveOp.getResultVectorType ();
930945 Value zeroOutMask = rewriter.create <arith::ConstantOp>(
931- loc, DenseElementsAttr::get (deinterI8VecType, i8LowBitMask));
932- Value zeroOutLow = rewriter.create <arith::AndIOp>(
933- loc, deinterleaveOp.getRes1 (), zeroOutMask);
946+ loc,
947+ DenseElementsAttr::get (lowShuffleOp.getResultVectorType (), i8LowBitMask));
948+ Value zeroOutLow =
949+ rewriter.create <arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
934950
935951 // 3. Move high i4 values to upper side of the byte.
936952 constexpr int8_t bitsToShift = 4 ;
953+ VectorType deinterI8VecType = highShuffleOp.getResultVectorType ();
937954 auto shiftValues = rewriter.create <arith::ConstantOp>(
938955 loc, DenseElementsAttr::get (deinterI8VecType, bitsToShift));
939- Value shlHigh = rewriter. create <arith::ShLIOp>(loc, deinterleaveOp. getRes2 (),
940- shiftValues);
956+ Value shlHigh =
957+ rewriter. create <arith::ShLIOp>(loc, highShuffleOp, shiftValues);
941958
942959 // 4. Merge high and low i4 values.
943960 auto mergedHiLowOp = rewriter.create <arith::OrIOp>(loc, zeroOutLow, shlHigh);
@@ -1131,7 +1148,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
11311148 }
11321149};
11331150
1134- /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
1151+ /// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
11351152/// bitwise ops that take advantage of high-level information to avoid leaving
11361153/// LLVM to scramble with peephole optimizations.
11371154///
@@ -1141,11 +1158,13 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
11411158///
11421159/// %cst = arith.constant dense<15> : vector<4xi8>
11431160/// %cst_0 = arith.constant dense<4> : vector<4xi8>
1144- /// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
1145- /// %2 = arith.andi %0, %cst : vector<4xi8>
1146- /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
1147- /// %4 = arith.ori %2, %3 : vector<4xi8>
1148- /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
1161+ /// %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
1162+ /// %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
1163+ /// %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
1164+ /// %3 = arith.andi %1, %cst : vector<4xi8>
1165+ /// %4 = arith.shli %2, %cst_0 : vector<4xi8>
1166+ /// %5 = arith.ori %3, %4 : vector<4xi8>
1167+ /// %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
11491168///
11501169struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
11511170 using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1159,6 +1178,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
11591178 if (!srcVecType || !dstVecType)
11601179 return failure ();
11611180
1181+ // Only 1-D (rank-1) vectors are supported until `vector.deinterleave`
1182+ // is available.
1183+ if (srcVecType.getRank () != 1 )
1184+ return failure ();
1185+
11621186 if (failed (commonConversionPrecondition (rewriter, srcVecType, truncOp)))
11631187 return failure ();
11641188
0 commit comments