Skip to content

Commit 4334375

Browse files
Max191 authored and qedawkins committed
Revert "[mlir][vector] Support n-D vectors in i8 to i4 trunci emulation (llvm#94946)"
This reverts commit 137a745.
1 parent c5bb6d3 commit 4334375

File tree

2 files changed

+45
-36
lines changed

2 files changed

+45
-36
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 39 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -912,8 +912,8 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
912912
return rewriter.create<vector::InterleaveOp>(loc, low, high);
913913
}
914914

915-
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
916-
/// ops that take advantage of high-level information to avoid leaving LLVM to
915+
/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
916+
/// that take advantage of high-level information to avoid leaving LLVM to
917917
/// scramble with peephole optimizations.
918918
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
919919
Value srcValue) {
@@ -922,22 +922,39 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
922922
"Expected i8 type");
923923

924924
// 1. De-interleave low and high i8 elements.
925-
auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
925+
int64_t vecDimSize = srcVecType.getShape().back();
926+
SmallVector<int64_t> deinterleaveLowMaskValues;
927+
SmallVector<int64_t> deinterleaveHighMaskValues;
928+
assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
929+
deinterleaveLowMaskValues.reserve(vecDimSize / 2);
930+
deinterleaveHighMaskValues.reserve(vecDimSize / 2);
931+
for (int i = 0, end = vecDimSize; i < end; i += 2) {
932+
deinterleaveLowMaskValues.push_back(i);
933+
deinterleaveHighMaskValues.push_back(i + 1);
934+
}
935+
936+
auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
937+
loc, srcValue, srcValue,
938+
rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
939+
auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
940+
loc, srcValue, srcValue,
941+
rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
926942

927943
// 2. Zero out the upper side of each low i8 element.
928944
constexpr int8_t i8LowBitMask = 0x0F;
929-
VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
930945
Value zeroOutMask = rewriter.create<arith::ConstantOp>(
931-
loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
932-
Value zeroOutLow = rewriter.create<arith::AndIOp>(
933-
loc, deinterleaveOp.getRes1(), zeroOutMask);
946+
loc,
947+
DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
948+
Value zeroOutLow =
949+
rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
934950

935951
// 3. Move high i4 values to upper side of the byte.
936952
constexpr int8_t bitsToShift = 4;
953+
VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
937954
auto shiftValues = rewriter.create<arith::ConstantOp>(
938955
loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
939-
Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
940-
shiftValues);
956+
Value shlHigh =
957+
rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
941958

942959
// 4. Merge high and low i4 values.
943960
auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
@@ -1131,7 +1148,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
11311148
}
11321149
};
11331150

1134-
/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
1151+
/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
11351152
/// bitwise ops that take advantage of high-level information to avoid leaving
11361153
/// LLVM to scramble with peephole optimizations.
11371154
///
@@ -1141,11 +1158,13 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
11411158
///
11421159
/// %cst = arith.constant dense<15> : vector<4xi8>
11431160
/// %cst_0 = arith.constant dense<4> : vector<4xi8>
1144-
/// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
1145-
/// %2 = arith.andi %0, %cst : vector<4xi8>
1146-
/// %3 = arith.shli %1, %cst_0 : vector<4xi8>
1147-
/// %4 = arith.ori %2, %3 : vector<4xi8>
1148-
/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
1161+
/// %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
1162+
/// %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
1163+
/// %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
1164+
/// %3 = arith.andi %1, %cst : vector<4xi8>
1165+
/// %4 = arith.shli %2, %cst_0 : vector<4xi8>
1166+
/// %5 = arith.ori %3, %4 : vector<4xi8>
1167+
/// %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
11491168
///
11501169
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
11511170
using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1159,6 +1178,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
11591178
if (!srcVecType || !dstVecType)
11601179
return failure();
11611180

1181+
// Only single dim vectors are supported until we have
1182+
// `vector.deinterleave`.
1183+
if (srcVecType.getRank() != 1)
1184+
return failure();
1185+
11621186
if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
11631187
return failure();
11641188

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 6 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -268,7 +268,8 @@ func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
268268
// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
269269
// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
270270
// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
271-
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
271+
// CHECK: %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
272+
// CHECK: %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
272273
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
273274
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
274275
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
@@ -282,7 +283,8 @@ func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
282283
// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
283284
// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
284285
// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
285-
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
286+
// CHECK: %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
287+
// CHECK: %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
286288
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
287289
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
288290
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
@@ -297,34 +299,17 @@ func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
297299
// CHECK-NOT: vector.andi
298300
// CHECK-NOT: vector.shli
299301
// CHECK-NOT: vector.ori
300-
// CHECK: arith.trunci {{.*}} : vector<8x32xi32> to vector<8x32xi8>
301-
// CHECK-NOT: arith.trunci {{.*}} : vector<8x32xi8> to vector<8x32xi4>
302-
// CHECK: vector.deinterleave
302+
// CHECK: arith.trunci
303303
%0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
304304
return %0 : vector<8x32xi4>
305305
}
306306

307-
// CHECK-LABEL: func.func @aligned_trunci_nd(
308-
// CHECK-SAME: %[[IN:.*]]: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
309-
func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
310-
// CHECK: %[[LEFT_SHIFT_BITS:.*]] = arith.constant dense<4> : vector<3x8x16xi8>
311-
// CHECK: %[[I4_MASK:.*]] = arith.constant dense<15> : vector<3x8x16xi8>
312-
// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<3x8x32xi32> to vector<3x8x32xi8>
313-
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<3x8x32xi8> -> vector<3x8x16xi8>
314-
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
315-
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
316-
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
317-
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
318-
%0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
319-
return %0 : vector<3x8x32xi4>
320-
}
321-
322307
// CHECK-LABEL: func.func @i4_transpose(
323308
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
324309
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
325310
// CHECK: %[[EXT:.*]] = vector.interleave
326311
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
327-
// CHECK: vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8>
312+
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
328313
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
329314
return %0 : vector<16x8xi4>
330315
}

0 commit comments

Comments
 (0)