@@ -1125,54 +1125,57 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
11251125 }
11261126
11271127 if (numGroups == 1 && inputZp) {
1128- // The quantized version uses a different channel ordering so we need to
1129- // permute the tensors in order to use the existing path. We should
1130- // eventually directly support this channel ordering.
1131- llvm::SmallVector<int64_t > inPerms, weightPerms;
1132- inPerms.push_back (0 ); // N stays at the front for input.
1133- // Then we expect the spatial dimensions
1134- for (size_t i = 0 ; i < numSpatialDims; ++i) {
1135- inPerms.push_back (i + 2 );
1136- weightPerms.push_back (i + 2 );
1137- }
1138- inPerms.push_back (1 );
1139- weightPerms.append ({1 , 0 });
1140-
1141- paddedInput = transposeValue (op.getLoc (), paddedInput, inPerms, rewriter);
1142- weight = transposeValue (op.getLoc (), weight, weightPerms, rewriter);
1143- outputTensor =
1144- transposeValue (op.getLoc (), outputTensor, inPerms, rewriter);
1145-
11461128 switch (numSpatialDims) {
11471129 case 2 :
11481130 conv = rewriter
1149- .create <linalg::Conv2DNhwcHwcfQOp >(
1131+ .create <linalg::Conv2DNchwFchwQOp >(
11501132 loc, outputTensor.getType (),
11511133 ValueRange{paddedInput, weight, inputZp, weightZp},
11521134 outputTensor, stridesAttr, dilationAttr)
11531135 .getResult (0 );
11541136 break ;
1155- case 3 :
1137+ case 3 : {
1138+ // The quantized version uses a different channel ordering so we need to
1139+ // permute the tensors in order to use the existing path. We should
1140+ // eventually directly support this channel ordering.
1141+ llvm::SmallVector<int64_t > inPerms, weightPerms;
1142+ inPerms.push_back (0 ); // N stays at the front for input.
1143+ // Then we expect the spatial dimensions
1144+ for (size_t i = 0 ; i < numSpatialDims; ++i) {
1145+ inPerms.push_back (i + 2 );
1146+ weightPerms.push_back (i + 2 );
1147+ }
1148+ inPerms.push_back (1 );
1149+ weightPerms.append ({1 , 0 });
1150+
1151+ paddedInput =
1152+ transposeValue (op.getLoc (), paddedInput, inPerms, rewriter);
1153+ weight = transposeValue (op.getLoc (), weight, weightPerms, rewriter);
1154+ outputTensor =
1155+ transposeValue (op.getLoc (), outputTensor, inPerms, rewriter);
1156+
11561157 conv = rewriter
11571158 .create <linalg::Conv3DNdhwcDhwcfQOp>(
11581159 loc, outputTensor.getType (),
11591160 ValueRange{paddedInput, weight, inputZp, weightZp},
11601161 outputTensor, stridesAttr, dilationAttr)
11611162 .getResult (0 );
1163+
1164+ llvm::SmallVector<int64_t > outPerms;
1165+ outPerms.push_back (0 );
1166+ outPerms.push_back (inPerms.size () - 1 );
1167+ for (size_t i = 0 ; i < numSpatialDims; ++i) {
1168+ outPerms.push_back (i + 1 );
1169+ }
1170+ conv = transposeValue (op.getLoc (), conv, outPerms, rewriter);
1171+
11621172 break ;
1173+ }
11631174 default :
11641175 return rewriter.notifyMatchFailure (
11651176 op, " unimplemented: only 1D, 2D, and 3D convolution supported" );
11661177 };
11671178
1168- llvm::SmallVector<int64_t > outPerms;
1169- outPerms.push_back (0 );
1170- outPerms.push_back (inPerms.size () - 1 );
1171- for (size_t i = 0 ; i < numSpatialDims; ++i) {
1172- outPerms.push_back (i + 1 );
1173- }
1174- conv = transposeValue (op.getLoc (), conv, outPerms, rewriter);
1175-
11761179 Type newResultType = getTypeConverter ()->convertType (op.getType ());
11771180 if (accumulatorDType != resultDTy) {
11781181 Type resultElementType =
0 commit comments