
Commit 8f76e08

yzhang93 authored and weidel-p committed
[DispatchCreation] Add split reduction for weight backward convs (iree-org#22275)
Weight backward convolutions have a special CHWN layout, where the filter sizes (corresponding to output image sizes in forward convolutions) are typically large, while the output spatial dimensions are small. This makes the split reduction strategy particularly effective. This PR adds support for splitting these convs along the input channel dimension.

Some experimental thresholds are applied to filter out cases that won't benefit from split reduction. In particular:

- When the batch and output channel sizes are large, the workload tends to be distributed across many workgroups, so split reduction has little to no effect.
- When the input spatial sizes are small while the batch and output channel sizes are relatively larger (medium size), split reduction often has no effect or even degrades performance.

---------

Signed-off-by: yzhang93 <[email protected]>
Signed-off-by: Philipp <[email protected]>
1 parent db4f657

4 files changed: +209 -6 lines changed
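For intuition about the thresholds described in the commit message, the following is a minimal standalone C++ sketch of the heuristic. It is not code from the patch: weightBackwardSplitSizes and its flattened inputs are a hypothetical simplification, while the threshold constants (512/128/32) and the ceil-division tile-size computation mirror the pass below.

// Standalone sketch of the split-reduction heuristic (hypothetical helper,
// not code from the patch).
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

std::optional<std::vector<int64_t>> weightBackwardSplitSizes(
    std::vector<int64_t> reductionSizes, // sizes of the reduction loops
    std::size_t cDim,                    // index of the input channel loop
    int64_t minInputSpatialSize,         // smallest input spatial extent
    int64_t outputChannelSize, int64_t batchSize) {
  // Empirical thresholds from the patch.
  const int64_t largeDimSize = 512;
  const int64_t mediumDimSize = 128;
  const int64_t smallDimSize = 32;

  // Check 1: a large batch and output channel already fill many workgroups.
  if (outputChannelSize >= largeDimSize && batchSize >= largeDimSize)
    return std::nullopt;

  // Check 2: small input spatial sizes with medium batch/output channel.
  if (minInputSpatialSize < smallDimSize &&
      outputChannelSize > mediumDimSize && batchSize > mediumDimSize)
    return std::nullopt;

  // Only the input channel dimension is split; a size-1 channel cannot be.
  if (reductionSizes[cDim] == 1)
    return std::nullopt;
  reductionSizes[cDim] = static_cast<int64_t>(
      std::ceil(float(reductionSizes[cDim]) / largeDimSize));
  return reductionSizes;
}

int main() {
  // Shapes of the first lit test below: reduction sizes [16, 225, 225] with
  // the input channel at index 0, input spatial 227x227, output channel 64,
  // batch 16. ceil(16 / 512) == 1, so the result is [1, 225, 225], matching
  // the iree_linalg_ext.split_reduction attribute the test expects.
  auto sizes = weightBackwardSplitSizes({16, 225, 225}, /*cDim=*/0,
                                        /*minInputSpatialSize=*/227,
                                        /*outputChannelSize=*/64,
                                        /*batchSize=*/16);
  return sizes ? 0 : 1;
}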

compiler/src/iree/compiler/DispatchCreation/SetSplitReductionSizes.cpp

Lines changed: 135 additions & 6 deletions
@@ -4,10 +4,9 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
-
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "llvm/Support/DebugLog.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"

@@ -94,12 +93,17 @@ struct SetSplitReductionSizesPass final
       return;
     }
 
-    std::optional<SmallVector<int64_t>> tileSizes =
-        getOuterReductionSizes(tilingOp);
-    if (!tileSizes) {
+    // --- Case 1: Outer reduction ---
+    if (auto tileSizes = getOuterReductionSizes(tilingOp)) {
+      IREE::LinalgExt::setSplitReductionAttribute(tilingOp, *tileSizes);
+      return;
+    }
+
+    // --- Case 2: Generic weight backward convolution ---
+    if (auto tileSizes = getWeightBackwardReductionSizes(tilingOp)) {
+      IREE::LinalgExt::setSplitReductionAttribute(tilingOp, *tileSizes);
       return;
     }
-    IREE::LinalgExt::setSplitReductionAttribute(tilingOp, tileSizes.value());
   });
 }

@@ -143,6 +147,131 @@ struct SetSplitReductionSizesPass final
     }
     return tileSizes;
   }
+
+  /// Determines split reduction sizes for weight backward convolutions.
+  /// These convolutions have a special CHWN layout, where the filter sizes
+  /// (corresponding to output image sizes in forward convolutions) are
+  /// typically large, while the output spatial dimensions are small. This
+  /// makes the split reduction strategy particularly effective. Currently,
+  /// splitting is only applied along the input channel dimension.
+  std::optional<SmallVector<int64_t>>
+  getWeightBackwardReductionSizes(PartialReductionOpInterface op) const {
+    // First check if the input op is a convolution with CHWN layout.
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
+    if (!linalgOp || !linalg::isaConvolutionOpInterface(linalgOp)) {
+      LDBG() << "skipping op; not convolution";
+      return std::nullopt;
+    }
+
+    FailureOr<mlir::linalg::ConvolutionDimensions> convDims =
+        mlir::linalg::inferConvolutionDims(linalgOp);
+    if (failed(convDims)) {
+      LDBG() << "skipping op; failed to infer convolution dims";
+      return std::nullopt;
+    }
+
+    if (convDims->inputChannel.empty() || convDims->outputChannel.empty() ||
+        convDims->batch.empty() || convDims->filterLoop.empty()) {
+      LDBG() << "skipping op; missing convolution dimensions";
+      return std::nullopt;
+    }
+
+    OpOperand *input = linalgOp.getDpsInputOperand(0);
+    OpOperand *filter = linalgOp.getDpsInputOperand(1);
+    OpOperand *output = linalgOp.getDpsInitOperand(0);
+
+    Value inputVal = input->get();
+    Value filterVal = filter->get();
+    Value outputVal = output->get();
+
+    ArrayRef<int64_t> inputShape =
+        llvm::cast<ShapedType>(inputVal.getType()).getShape();
+    ArrayRef<int64_t> filterShape =
+        llvm::cast<ShapedType>(filterVal.getType()).getShape();
+    ArrayRef<int64_t> outputShape =
+        llvm::cast<ShapedType>(outputVal.getType()).getShape();
+
+    if (ShapedType::isDynamicShape(inputShape) ||
+        ShapedType::isDynamicShape(filterShape) ||
+        ShapedType::isDynamicShape(outputShape)) {
+      LDBG() << "skipping op; has dynamic shape";
+      return std::nullopt;
+    }
+
+    AffineMap inputMap = linalgOp.getMatchingIndexingMap(input);
+    AffineMap filterMap = linalgOp.getMatchingIndexingMap(filter);
+    AffineMap outputMap = linalgOp.getMatchingIndexingMap(output);
+
+    std::optional<int64_t> batchLastDim = outputMap.getResultPosition(
+        getAffineDimExpr(convDims->batch.back(), outputMap.getContext()));
+    if (!batchLastDim || batchLastDim.value() != outputShape.size() - 1) {
+      LDBG() << "skipping op; not batch last layout";
+      return std::nullopt;
+    }
+
+    std::optional<int64_t> inputChannelDim = filterMap.getResultPosition(
+        getAffineDimExpr(convDims->inputChannel[0], filterMap.getContext()));
+    std::optional<int64_t> filterDim = filterMap.getResultPosition(
+        getAffineDimExpr(convDims->filterLoop[0], filterMap.getContext()));
+    if (!inputChannelDim || !filterDim ||
+        inputChannelDim.value() > filterDim.value()) {
+      LDBG() << "skipping op; not channel first layout";
+      return std::nullopt;
+    }
+
+    std::optional<int64_t> outputChannelDim = outputMap.getResultPosition(
+        getAffineDimExpr(convDims->outputChannel[0], outputMap.getContext()));
+    if (!outputChannelDim) {
+      LDBG() << "skipping op; has no output channel dim";
+      return std::nullopt;
+    }
+
+    std::optional<SmallVector<int64_t>> maybeSizes =
+        getReductionDimSizes(op.getOperation());
+    if (!maybeSizes) {
+      LDBG() << "skipping op; failed to get reduction sizes";
+      return std::nullopt;
+    }
+
+    // The constants below are determined based on empirical data.
+    const int64_t largeDimSize = 512;
+    const int64_t mediumDimSize = 128;
+    const int64_t smallDimSize = 32;
+
+    // When the batch and output channel sizes are large, the workload tends
+    // to be distributed across many workgroups, making split reduction have
+    // little to no effect.
+    int64_t outputChannelSize = outputShape[outputChannelDim.value()];
+    int64_t batchSize = outputShape[batchLastDim.value()];
+    if (outputChannelSize >= largeDimSize && batchSize >= largeDimSize) {
+      LDBG() << "skipping op; large output channel or batch size";
+      return std::nullopt;
+    }
+
+    // When the input spatial sizes are small while the batch and output
+    // channel sizes are relatively larger, split reduction often has no
+    // effect or even degrades performance.
+    for (auto dim : convDims->filterLoop) {
+      for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) {
+        if (e.isFunctionOfDim(dim) && inputShape[idx] < smallDimSize &&
+            outputChannelSize > mediumDimSize && batchSize > mediumDimSize) {
+          LDBG() << "skipping op; small input spatial size";
+          return std::nullopt;
+        }
+      }
+    }
+
+    // Only split along the input channel dimension.
+    // TODO(vivian): split more reduction dimensions if needed.
+    int64_t cDim = inputChannelDim.value();
+    SmallVector<int64_t> tileSizes = std::move(*maybeSizes);
+    if (tileSizes[cDim] == 1) {
+      LDBG() << "skipping op; input channel size equals to 1";
+      return std::nullopt;
+    }
+    tileSizes[cDim] = std::ceil(float(tileSizes[cDim]) / largeDimSize);
+    return tileSizes;
+  }
 };
 } // namespace
 } // namespace mlir::iree_compiler::DispatchCreation
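The layout checks above boil down to comparing result positions in the indexing maps. Below is a minimal sketch, assuming an MLIR development build (the main() driver and assert are illustrative only, not part of the patch), of how AffineMap::getResultPosition distinguishes the channel-first filter layout used by the new tests.

// Minimal sketch of the "channel first" check (assumes an MLIR build; the
// main() driver and assert are illustrative, not code from the patch).
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>
#include <optional>

int main() {
  mlir::MLIRContext ctx;
  auto d = [&](unsigned i) { return mlir::getAffineDimExpr(i, &ctx); };
  // Filter map of the CHWN tests below: (d0, ..., d6) -> (d4, d5, d6, d0),
  // where d4 is the input channel and d5/d6 are the filter loops.
  mlir::AffineMap filterMap = mlir::AffineMap::get(
      /*dimCount=*/7, /*symbolCount=*/0, {d(4), d(5), d(6), d(0)}, &ctx);
  std::optional<unsigned> cPos = filterMap.getResultPosition(d(4)); // 0
  std::optional<unsigned> fPos = filterMap.getResultPosition(d(5)); // 1
  // Channel-first layout: the channel result precedes the filter-loop result.
  assert(cPos && fPos && *cPos < *fPos);
  return 0;
}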

compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ iree_lit_test_suite(
         "set_encoding_padding.mlir",
         "set_encoding_pipeline.mlir",
         "set_split_reduction_sizes.mlir",
+        "set_split_reduction_sizes_conv.mlir",
         "sink_reshapes.mlir",
         "split_reduction.mlir",
         "tensor_pad_to_tensor_insert_slice.mlir",

compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ iree_lit_test_suite(
     "set_encoding_padding.mlir"
     "set_encoding_pipeline.mlir"
     "set_split_reduction_sizes.mlir"
+    "set_split_reduction_sizes_conv.mlir"
     "sink_reshapes.mlir"
     "split_reduction.mlir"
     "tensor_pad_to_tensor_insert_slice.mlir"
compiler/src/iree/compiler/DispatchCreation/test/set_split_reduction_sizes_conv.mlir

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+// RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-split-reduction-sizes))" --split-input-file > %t
+// RUN: FileCheck %s < %t
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5, d2 + d6, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+util.func public @conv_2d_chwn_chwf(%arg0: tensor<16x227x227x16xf32>, %arg1: tensor<16x225x225x64xf32>, %arg2: tensor<64x3x3x16xf32>) -> tensor<64x3x3x16xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x227x227x16xf32>, tensor<16x225x225x64xf32>) outs(%arg2 : tensor<64x3x3x16xf32>) {
+  ^bb0(%in: f32, %in_3: f32, %out: f32):
+    %12 = arith.mulf %in, %in_3 : f32
+    %13 = arith.addf %out, %12 : f32
+    linalg.yield %13 : f32
+  } -> tensor<64x3x3x16xf32>
+  util.return %0 : tensor<64x3x3x16xf32>
+}
+
+// CHECK-LABEL: @conv_2d_chwn_chwf
+// CHECK: iree_linalg_ext.split_reduction = [1 : index, 225 : index, 225 : index]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+util.func public @no_split_conv_2d_nhwc_fhwc(%arg0: tensor<16x227x227x16xf32>, %arg1: tensor<64x3x3x16xf32>, %arg2: tensor<16x225x225x64xf32>) -> tensor<16x225x225x64xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x227x227x16xf32>, tensor<64x3x3x16xf32>) outs(%arg2 : tensor<16x225x225x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %3 = arith.mulf %in, %in_0 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<16x225x225x64xf32>
+  util.return %0 : tensor<16x225x225x64xf32>
+}
+
+// CHECK-LABEL: @no_split_conv_2d_nhwc_fhwc
+// CHECK-NOT: iree_linalg_ext.split_reduction
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5, d2 + d6, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+util.func public @no_split_large_N_F_sizes(%arg0: tensor<16x98x50x1024xf32>, %arg1: tensor<16x96x48x1024xf32>, %arg2: tensor<1024x3x3x1024xf32>) -> tensor<1024x3x3x1024xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x98x50x1024xf32>, tensor<16x96x48x1024xf32>) outs(%arg2 : tensor<1024x3x3x1024xf32>) {
+  ^bb0(%in: f32, %in_3: f32, %out: f32):
+    %12 = arith.mulf %in, %in_3 : f32
+    %13 = arith.addf %out, %12 : f32
+    linalg.yield %13 : f32
+  } -> tensor<1024x3x3x1024xf32>
+  util.return %0 : tensor<1024x3x3x1024xf32>
+}
+
+// CHECK-LABEL: @no_split_large_N_F_sizes
+// CHECK-NOT: iree_linalg_ext.split_reduction
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5, d2 + d6, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+util.func public @no_split_small_H_W_sizes(%arg0: tensor<16x26x18x288xf32>, %arg1: tensor<16x24x16x288xf32>, %arg2: tensor<288x3x3x288xf32>) -> tensor<288x3x3x288xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x18x288xf32>, tensor<16x24x16x288xf32>) outs(%arg2 : tensor<288x3x3x288xf32>) {
+  ^bb0(%in: f32, %in_3: f32, %out: f32):
+    %12 = arith.mulf %in, %in_3 : f32
+    %13 = arith.addf %out, %12 : f32
+    linalg.yield %13 : f32
+  } -> tensor<288x3x3x288xf32>
+  util.return %0 : tensor<288x3x3x288xf32>
+}
+
+// CHECK-LABEL: @no_split_small_H_W_sizes
+// CHECK-NOT: iree_linalg_ext.split_reduction
