Skip to content

Commit ca61bcb

Browse files
GleasonKTensorFlow MLIR Team
authored andcommitted
[StableHLO] Add transpose simplification
PiperOrigin-RevId: 820804015
1 parent c64843c commit ca61bcb

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,6 +1810,15 @@ func.func @transpose_is_not_reshape(%arg0: tensor<1x4x5x2xf32>) -> tensor<2x4x1x
18101810
return %0 : tensor<2x4x1x5xf32>
18111811
}
18121812

1813+
// CHECK-LABEL: @transpose_of_transpose
1814+
func.func @transpose_of_transpose(%arg0 : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
1815+
%0 = stablehlo.transpose %arg0, dims = [3,2,1,0] : (tensor<1x2x3x4xf32>) -> tensor<4x3x2x1xf32>
1816+
%1 = stablehlo.transpose %0, dims = [3,2,1,0] : (tensor<4x3x2x1xf32>) -> tensor<1x2x3x4xf32>
1817+
// CHECK-NOT: stablehlo.transpose
1818+
// CHECK: return %arg0
1819+
return %1 : tensor<1x2x3x4xf32>
1820+
}
1821+
18131822
// -----
18141823

18151824
////////

stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,17 @@ struct SortOpSetDimension : public SimplifyOpRewritePattern<SortOp> {
13091309
// TransposeOp
13101310
/////////////////////////////////
13111311

1312+
DenseI64ArrayAttr getMergedTransposePermutation(OpBuilder& b,
1313+
ArrayRef<int64_t> childPerm,
1314+
ArrayRef<int64_t> parentPerm) {
1315+
SmallVector<int64_t> mergedPerm;
1316+
mergedPerm.reserve(parentPerm.size());
1317+
for (int64_t parentIdx : parentPerm) {
1318+
mergedPerm.push_back(childPerm[parentIdx]);
1319+
}
1320+
return b.getDenseI64ArrayAttr(mergedPerm);
1321+
}
1322+
13121323
// Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X)
13131324
struct TransposeIsReshape final : SimplifyOpRewritePattern<TransposeOp> {
13141325
using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def InvertBroadcastDims : NativeCodeCall<"getInvertedBroadcastDimensions($_build
120120

121121
def MergeBroadcastDims : NativeCodeCall<"getMergedBroadcastDimensions($_builder, $0, $1)">;
122122

123+
def MergePermutations : NativeCodeCall<"getMergedTransposePermutation($_builder, $0, $1)">;
124+
123125
def StableHLO_ConvertOpWithShape : NativeCodeCall<
124126
"$_builder.create<stablehlo::ConvertOp>($_loc, $0.getType(), $1)">;
125127

@@ -539,6 +541,12 @@ def TransposeOp_RemoveNoop
539541
: Pat<(StableHLO_TransposeOp $lhs, IotaDims:$dims),
540542
(replaceWithValue $lhs)>;
541543

544+
// Pattern: transpose(transpose(X)) -> transpose(X)
545+
def TransposeOp_TransposeOfTranspose
546+
: Pat<(StableHLO_TransposeOp
547+
(StableHLO_TransposeOp $child, $child_dims), $dims),
548+
(StableHLO_TransposeOp $child, (MergePermutations $child_dims, $dims))>;
549+
542550
////////
543551
// GetTupleElementOp
544552

0 commit comments

Comments
 (0)