@@ -369,7 +369,7 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
 }
 
 struct ConvToIgemmInfo {
-  bool isInputChannelLast;
+  bool isBatchDimLast;
   bool isSpatialDimLast;
   linalg::ConvolutionDimensions convDims;
   DenseMap<int64_t, AffineExpr> convToIgemmDimMap;
@@ -392,14 +392,28 @@ getPaddingConvSizes(Builder &b, const SmallVector<int64_t> &bounds,
 
   DenseMap<int64_t, AffineExpr> convToIgemmMap =
       convToIgemmInfo->convToIgemmDimMap;
-  // Padding sizes for parallel dimensions are the same as workgroup tile
-  // sizes.
   DenseSet<int64_t> paddedIGEMMDims;
   DenseMap<int64_t, SmallVector<int64_t>> paddedReductionConvDims;
   linalg::ConvolutionDimensions convDims = convToIgemmInfo->convDims;
   SetVector<int64_t> inputChannelDims(convDims.inputChannel.begin(),
                                       convDims.inputChannel.end());
   SmallVector<int64_t> paddingConvSizes(convToIgemmMap.size(), 0);
+
+  // For a batch-last layout (e.g., CHWN), only pad the batch dimension to
+  // avoid introducing a pad op as the producer of a collapse_shape op,
+  // which may cause fusion problems.
+  if (convToIgemmInfo->isBatchDimLast) {
+    int64_t lastBatchDim = convDims.batch.back();
+    auto IGEMMDimExpr = cast<AffineDimExpr>(convToIgemmMap[lastBatchDim]);
+    unsigned IGEMMBatchPos = IGEMMDimExpr.getPosition();
+    if (paddingSizes[IGEMMBatchPos] &&
+        bounds[IGEMMBatchPos] % paddingSizes[IGEMMBatchPos] == 0) {
+      return std::nullopt;
+    }
+    paddingConvSizes[lastBatchDim] = paddingSizes[IGEMMBatchPos];
+    return b.getI64ArrayAttr(paddingConvSizes);
+  }
+
   for (auto [convDim, IGEMMExpr] : convToIgemmMap) {
     auto IGEMMDimExpr = cast<AffineDimExpr>(IGEMMExpr);
     unsigned IGEMMPos = IGEMMDimExpr.getPosition();
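
As a worked example of the batch-last branch added above (the sizes are hypothetical, not taken from this patch): for a CHWN input the batch dimension N is innermost, so isBatchDimLast is true. If the IGEMM bound that N maps to is 128 and the requested padding size is 64, then 128 % 64 == 0, the shape is already aligned, and the helper returns std::nullopt. If the bound is 100 instead, 100 % 64 != 0, so the batch entry of paddingConvSizes is set to 64 while every other conv dimension stays 0.
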
@@ -415,19 +429,21 @@ getPaddingConvSizes(Builder &b, const SmallVector<int64_t> &bounds,
     // Only pad input channel dims. If we need to pad filter dims, then we
     // would rather just do padding on the GEMM instead.
     if (inputChannelDims.contains(convDim)) {
+      // Multiple input channel dims for a single IGEMMPos are not supported.
+      if (paddedIGEMMDims.contains(IGEMMPos)) {
+        return std::nullopt;
+      }
       int64_t inputChannelSize =
           convToIgemmInfo->inputChannelDimToSize[convDim];
       bool isInputChannelSizeSmall =
           (paddingSizes[IGEMMPos] / inputChannelSize > 2);
-      // The following cases are not supported:
-      // 1) Input channel is not the innermost dimension;
-      // 2) Input channel size is too small compared to padding size;
-      // 3) Multiple input channel dims for a single IGEMMPos.
-      if (!convToIgemmInfo->isInputChannelLast || isInputChannelSizeSmall ||
-          paddedIGEMMDims.contains(IGEMMPos)) {
-        return std::nullopt;
+      // If the input channel dimension is much smaller than the padding size,
+      // skip padding along that dimension while still padding the others.
+      if (isInputChannelSizeSmall) {
+        paddingConvSizes[convDim] = 0;
+      } else {
+        paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
       }
-      paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
       paddedIGEMMDims.insert(IGEMMPos);
     }
     continue;
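
To make the smallness heuristic above concrete (hypothetical sizes, not from this patch): with paddingSizes[IGEMMPos] = 64 and an input channel size of 16, the integer ratio 64 / 16 = 4 exceeds 2, so that conv dimension keeps a padding size of 0; with an input channel size of 48, 64 / 48 = 1 does not exceed 2, so the dimension is padded to 64. Either way the remaining dimensions are still considered for padding, whereas the old code returned std::nullopt for the whole convolution.
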
@@ -766,16 +782,14 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
       kPackFactor = std::get<2>(mmaKind.getMNKShape());
     }
     paddingTileSizes[innerKDim] *= kPackFactor;
+    attrs.emplace_back("padding", b.getI64ArrayAttr(paddingTileSizes));
 
     // Create `padding_conv` attribute when padding convolutions before IGEMM
-    // is possible, otherwise fallback to pad IGEMM.
+    // is possible.
     if (auto attr =
             getPaddingConvSizes(b, bounds, paddingTileSizes, workgroupTileSizes,
                                 reductionTileSizes, convToIgemmInfo)) {
-      attrs.emplace_back(StringAttr::get(context, "padding_conv"), *attr);
-    } else {
-      attrs.emplace_back(StringAttr::get(context, "padding"),
-                         b.getI64ArrayAttr(paddingTileSizes));
+      attrs.emplace_back("padding_conv", *attr);
     }
   }
   auto configDict = DictionaryAttr::get(context, attrs);
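
With this hunk the IGEMM `padding` attribute is always attached, and `padding_conv` is added on top of it whenever getPaddingConvSizes succeeds; previously exactly one of the two was emitted, with `padding` serving as the fallback. As a purely illustrative sketch (values invented, not from this patch), a batch-last conv could now carry both entries in its lowering config, e.g. padding = [1, 64, 64, 32] over the IGEMM iteration space plus a padding_conv array over the original convolution dimensions whose only non-zero entry is the batch dimension's 64.
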
@@ -812,13 +826,12 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
   auto inputType = llvm::cast<ShapedType>(op->getOperands()[0].getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
   AffineMap inputMap = linalgOp.getIndexingMapsArray()[0];
-  SmallVector<int64_t> inputChannelPos;
   SmallVector<int64_t> inputImagePos;
+  SmallVector<int64_t> batchPos;
   for (auto dim : igemmGenericConvDetails->convDims.inputChannel) {
     for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) {
       if (e.isFunctionOfDim(dim)) {
         convToIgemmInfo.inputChannelDimToSize[dim] = inputShape[idx];
-        inputChannelPos.push_back(idx);
       }
     }
   }
@@ -829,12 +842,19 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
       }
     }
   }
-  llvm::sort(inputChannelPos);
+  for (auto dim : igemmGenericConvDetails->convDims.batch) {
+    for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) {
+      if (e.isFunctionOfDim(dim)) {
+        batchPos.push_back(idx);
+      }
+    }
+  }
   llvm::sort(inputImagePos);
-  convToIgemmInfo.isInputChannelLast =
-      inputChannelPos.back() == inputShape.size() - 1;
+  llvm::sort(batchPos);
+  convToIgemmInfo.isBatchDimLast =
+      !batchPos.empty() && batchPos.back() == inputShape.size() - 1;
   convToIgemmInfo.isSpatialDimLast =
-      inputImagePos.back() == inputShape.size() - 1;
+      !inputImagePos.empty() && inputImagePos.back() == inputShape.size() - 1;
   convToIgemmInfo.convDims = igemmGenericConvDetails->convDims;
   convToIgemmInfo.convToIgemmDimMap =
       igemmGenericConvDetails->convToIgemmDimMap;