
Commit 8bf5bfa

yzhang93 authored and weidel-p committed
[Codegen] Allow pre-padding other dims of a conv except the input channel (iree-org#22296)
The previous PR disabled padding for all conv dimensions when the input channel size is much smaller than the padding size. However, for backward conv with the CHWN layout, when the batch and input channel dimensions are both unaligned, it is still useful to pad the batch dimension. This PR fixes iree-org#22277.

---------

Signed-off-by: yzhang93 <[email protected]>
Signed-off-by: Philipp <[email protected]>
1 parent a407f83 commit 8bf5bfa

File tree

2 files changed (+52, -32 lines)
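Before the per-file diffs, here is a minimal standalone sketch of the decision this commit adds: for a batch-last input layout (e.g. CHWN), only the batch dimension is pre-padded, and pre-padding is skipped when the batch is already aligned. The helper name computeConvPaddingSizes and its plain std::vector/std::optional signature are illustrative assumptions only; the actual logic is the new batch-last branch of getPaddingConvSizes() in ConfigUtils.cpp below, which operates on MLIR builders, affine expressions, and attributes.

#include <cstdint>
#include <optional>
#include <vector>

// Sketch only: mirrors the control flow of the new batch-last branch.
std::optional<std::vector<int64_t>> computeConvPaddingSizes(
    bool isBatchDimLast,      // true for batch-last layouts such as CHWN
    int64_t batchSize,        // loop bound of the last batch dimension
    int64_t batchPaddingSize, // IGEMM padding size mapped onto that dimension
    size_t batchDimIndex,     // position of that dimension among the conv loops
    size_t numConvLoops) {    // total number of conv loops
  if (!isBatchDimLast)
    return std::nullopt; // other layouts go through the input-channel logic
  // If the batch is already aligned, pre-padding the conv buys nothing, so
  // fall back to padding the IGEMM instead (no padding_conv attribute).
  if (batchPaddingSize != 0 && batchSize % batchPaddingSize == 0)
    return std::nullopt;
  // Pad only the batch dimension; every other conv dimension stays at 0.
  std::vector<int64_t> paddingConvSizes(numConvLoops, 0);
  paddingConvSizes[batchDimIndex] = batchPaddingSize;
  return paddingConvSizes;
}

For the @conv_chwn_chwf_unaligned_batch test updated in this commit (batch loop of size 40, workgroup tile 16), computeConvPaddingSizes(true, 40, 16, 3, 7) returns {0, 0, 0, 16, 0, 0, 0}, matching the new padding_conv expectation.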

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 42 additions & 22 deletions
@@ -369,7 +369,7 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
 }
 
 struct ConvToIgemmInfo {
-  bool isInputChannelLast;
+  bool isBatchDimLast;
   bool isSpatialDimLast;
   linalg::ConvolutionDimensions convDims;
   DenseMap<int64_t, AffineExpr> convToIgemmDimMap;
@@ -392,14 +392,28 @@ getPaddingConvSizes(Builder &b, const SmallVector<int64_t> &bounds,
 
   DenseMap<int64_t, AffineExpr> convToIgemmMap =
       convToIgemmInfo->convToIgemmDimMap;
-  // Padding sizes for parallel dimensions are the same as workgroup tile
-  // sizes.
   DenseSet<int64_t> paddedIGEMMDims;
   DenseMap<int64_t, SmallVector<int64_t>> paddedReductionConvDims;
   linalg::ConvolutionDimensions convDims = convToIgemmInfo->convDims;
   SetVector<int64_t> inputChannelDims(convDims.inputChannel.begin(),
                                       convDims.inputChannel.end());
   SmallVector<int64_t> paddingConvSizes(convToIgemmMap.size(), 0);
+
+  // For batch-last layout (e.g., CHWN), only pad the batch dimension to avoid
+  // introducing pad op as the producer of collapse_shape op which may cause
+  // fusion problem.
+  if (convToIgemmInfo->isBatchDimLast) {
+    int64_t lastBatchDim = convDims.batch.back();
+    auto IGEMMDimExpr = cast<AffineDimExpr>(convToIgemmMap[lastBatchDim]);
+    unsigned IGEMMBatchPos = IGEMMDimExpr.getPosition();
+    if (paddingSizes[IGEMMBatchPos] &&
+        bounds[IGEMMBatchPos] % paddingSizes[IGEMMBatchPos] == 0) {
+      return std::nullopt;
+    }
+    paddingConvSizes[lastBatchDim] = paddingSizes[IGEMMBatchPos];
+    return b.getI64ArrayAttr(paddingConvSizes);
+  }
+
   for (auto [convDim, IGEMMExpr] : convToIgemmMap) {
     auto IGEMMDimExpr = cast<AffineDimExpr>(IGEMMExpr);
     unsigned IGEMMPos = IGEMMDimExpr.getPosition();
@@ -415,19 +429,21 @@ getPaddingConvSizes(Builder &b, const SmallVector<int64_t> &bounds,
     // Only pad input channel dims. If we need to pad filter dims, then we
     // would rather just do padding on the GEMM instead.
     if (inputChannelDims.contains(convDim)) {
+      // Multiple input channel dims for a single IGEMMPos is not supported.
+      if (paddedIGEMMDims.contains(IGEMMPos)) {
+        return std::nullopt;
+      }
       int64_t inputChannelSize =
          convToIgemmInfo->inputChannelDimToSize[convDim];
      bool isInputChannelSizeSmall =
          (paddingSizes[IGEMMPos] / inputChannelSize > 2);
-      // The following cases are not supported:
-      // 1) Input channel is not the innermost dimension;
-      // 2) Input channel size is too small compared to padding size;
-      // 3) Multiple input channel dims for a single IGEMMPos.
-      if (!convToIgemmInfo->isInputChannelLast || isInputChannelSizeSmall ||
-          paddedIGEMMDims.contains(IGEMMPos)) {
-        return std::nullopt;
+      // If the input channel dimension is much smaller than the padding size,
+      // skip padding along that dimension while still padding the others.
+      if (isInputChannelSizeSmall) {
+        paddingConvSizes[convDim] = 0;
+      } else {
+        paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
       }
-      paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
       paddedIGEMMDims.insert(IGEMMPos);
     }
     continue;
@@ -766,16 +782,14 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
       kPackFactor = std::get<2>(mmaKind.getMNKShape());
     }
     paddingTileSizes[innerKDim] *= kPackFactor;
+    attrs.emplace_back("padding", b.getI64ArrayAttr(paddingTileSizes));
 
     // Create `padding_conv` attribute when padding convolutions before IGEMM
-    // is possible, otherwise fallback to pad IGEMM.
+    // is possible.
     if (auto attr =
            getPaddingConvSizes(b, bounds, paddingTileSizes, workgroupTileSizes,
                                reductionTileSizes, convToIgemmInfo)) {
-      attrs.emplace_back(StringAttr::get(context, "padding_conv"), *attr);
-    } else {
-      attrs.emplace_back(StringAttr::get(context, "padding"),
-                         b.getI64ArrayAttr(paddingTileSizes));
+      attrs.emplace_back("padding_conv", *attr);
     }
   }
   auto configDict = DictionaryAttr::get(context, attrs);
@@ -812,13 +826,12 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
   auto inputType = llvm::cast<ShapedType>(op->getOperands()[0].getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
   AffineMap inputMap = linalgOp.getIndexingMapsArray()[0];
-  SmallVector<int64_t> inputChannelPos;
   SmallVector<int64_t> inputImagePos;
+  SmallVector<int64_t> batchPos;
   for (auto dim : igemmGenericConvDetails->convDims.inputChannel) {
     for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) {
       if (e.isFunctionOfDim(dim)) {
         convToIgemmInfo.inputChannelDimToSize[dim] = inputShape[idx];
-        inputChannelPos.push_back(idx);
       }
     }
   }
@@ -829,12 +842,19 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
       }
     }
   }
-  llvm::sort(inputChannelPos);
+  for (auto dim : igemmGenericConvDetails->convDims.batch) {
+    for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) {
+      if (e.isFunctionOfDim(dim)) {
+        batchPos.push_back(idx);
+      }
+    }
+  }
   llvm::sort(inputImagePos);
-  convToIgemmInfo.isInputChannelLast =
-      inputChannelPos.back() == inputShape.size() - 1;
+  llvm::sort(batchPos);
+  convToIgemmInfo.isBatchDimLast =
+      !batchPos.empty() && batchPos.back() == inputShape.size() - 1;
   convToIgemmInfo.isSpatialDimLast =
-      inputImagePos.back() == inputShape.size() - 1;
+      !inputImagePos.empty() && inputImagePos.back() == inputShape.size() - 1;
   convToIgemmInfo.convDims = igemmGenericConvDetails->convDims;
   convToIgemmInfo.convToIgemmDimMap =
      igemmGenericConvDetails->convToIgemmDimMap;
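The second behavioral change in this file is that a "too small" input channel no longer disables pre-padding for the whole conv; it only zeroes the padding for that one dimension. A hypothetical helper isolating that rule (the name inputChannelPaddingSize is not IREE API, just a restatement of the ratio test above):

#include <cstdint>

// Pre-padding size for an input channel dimension, given the IGEMM padding
// size mapped onto it. Uses the same test as the patch:
// paddingSizes[IGEMMPos] / inputChannelSize > 2 means "too small to pad".
int64_t inputChannelPaddingSize(int64_t igemmPaddingSize,
                                int64_t inputChannelSize) {
  bool isInputChannelSizeSmall = (igemmPaddingSize / inputChannelSize) > 2;
  return isInputChannelSizeSmall ? 0 : igemmPaddingSize;
}

For the NHWC small-channel test below (input channel size 3 against an IGEMM padding size of 32), this yields 0, so the channel dimension stays unpadded while the parallel dimensions are still pre-padded; that is why the test now expects padding_conv = [1, 4, 32, 64, 0, 0, 0] in addition to padding = [1, 4, 32, 64, 32].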

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir

Lines changed: 10 additions & 10 deletions
@@ -220,7 +220,7 @@ func.func @conv_chwn_chwf_unaligned_batch(%arg0: tensor<16x193x129x40xbf16>, %ar
 // CHECK-SAME: subgroup = [1, 1, 1, 1, 0]
 // CHECK-SAME: workgroup = [16, 1, 1, 16, 0]
 
-// PAD-CONV-GFX942: padding_conv = [16, 1, 1, 16, 0, 0, 0]
+// PAD-CONV-GFX942: padding_conv = [0, 0, 0, 16, 0, 0, 0]
 
 // -----
 
@@ -305,19 +305,19 @@ module {
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5 * 2, d2 + d6 * 2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-func.func @conv_chwn_chwf_no_pad_conv(%arg0: tensor<2x192x128x40xbf16>, %arg1: tensor<2x95x63x40xbf16>, %arg2: tensor<40x3x3x40xf32>) -> tensor<40x3x3x40xf32> {
-  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x192x128x40xbf16>, tensor<2x95x63x40xbf16>) outs(%arg2 : tensor<40x3x3x40xf32>) {
+func.func @conv_chwn_chwf_aligned_batch(%arg0: tensor<2x192x128x48xbf16>, %arg1: tensor<2x95x63x40xbf16>, %arg2: tensor<40x3x3x48xf32>) -> tensor<40x3x3x48xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x192x128x48xbf16>, tensor<2x95x63x40xbf16>) outs(%arg2 : tensor<40x3x3x48xf32>) {
   ^bb0(%in: bf16, %in_0: bf16, %out: f32):
     %1 = arith.extf %in : bf16 to f32
     %2 = arith.extf %in_0 : bf16 to f32
     %3 = arith.mulf %1, %2 : f32
     %4 = arith.addf %out, %3 : f32
     linalg.yield %4 : f32
-  } -> tensor<40x3x3x40xf32>
-  return %0 : tensor<40x3x3x40xf32>
+  } -> tensor<40x3x3x48xf32>
+  return %0 : tensor<40x3x3x48xf32>
 }
 
-// CHECK-LABEL: func.func @conv_chwn_chwf_no_pad_conv
+// CHECK-LABEL: func.func @conv_chwn_chwf_aligned_batch
 // PAD-CONV-GFX942: padding = [16, 1, 1, 16, 16]
 // PAD-CONV-GFX942-NOT: padding_conv
 
@@ -326,7 +326,7 @@
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-func.func @conv_nhwc_small_channel_no_pad_conv(%arg0: tensor<16x26x19x3xf16>, %arg1: tensor<287x3x3x3xf16>, %arg2: tensor<16x24x17x287xf32>) -> tensor<16x24x17x287xf32> {
+func.func @conv_nhwc_small_channel_size(%arg0: tensor<16x26x19x3xf16>, %arg1: tensor<287x3x3x3xf16>, %arg2: tensor<16x24x17x287xf32>) -> tensor<16x24x17x287xf32> {
   %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x19x3xf16>, tensor<287x3x3x3xf16>) outs(%arg2 : tensor<16x24x17x287xf32>) {
   ^bb0(%in: f16, %in_0: f16, %out: f32):
     %1 = arith.extf %in : f16 to f32
@@ -338,6 +338,6 @@ func.func @conv_nhwc_small_channel_no_pad_conv(%arg0: tensor<16x26x19x3xf16>, %a
   return %0 : tensor<16x24x17x287xf32>
 }
 
-// CHECK-LABEL: func.func @conv_nhwc_small_channel_no_pad_conv
-// PAD-CONV-GFX942: padding = [1, 4, 32, 64, 32]
-// PAD-CONV-GFX942-NOT: padding_conv
+// CHECK-LABEL: func.func @conv_nhwc_small_channel_size
+// PAD-CONV-GFX942: padding = [1, 4, 32, 64, 32]
+// PAD-CONV-GFX942: padding_conv = [1, 4, 32, 64, 0, 0, 0]
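As a worked example of what those sizes mean, assuming (as with IREE's other padding configs) that each non-zero entry rounds the corresponding conv loop bound up to the next multiple of that size, the small-channel case works out as follows; roundUpToMultiple is a hypothetical helper for illustration only:

#include <cstdint>
#include <vector>

// Round `value` up to the next multiple of `multiple` (0 means "no padding").
int64_t roundUpToMultiple(int64_t value, int64_t multiple) {
  return multiple == 0 ? value : ((value + multiple - 1) / multiple) * multiple;
}

int main() {
  // Conv loop bounds of @conv_nhwc_small_channel_size:
  // (n, oh, ow, oc, kh, kw, ic) = (16, 24, 17, 287, 3, 3, 3).
  std::vector<int64_t> bounds = {16, 24, 17, 287, 3, 3, 3};
  std::vector<int64_t> paddingConv = {1, 4, 32, 64, 0, 0, 0};
  for (size_t i = 0; i < bounds.size(); ++i)
    bounds[i] = roundUpToMultiple(bounds[i], paddingConv[i]);
  // bounds is now {16, 24, 32, 320, 3, 3, 3}: the unaligned ow and oc loops
  // get padded, while the small input channel (ic = 3) is left untouched.
  return 0;
}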
