@@ -174,12 +174,14 @@ static void _mkldnn_convolution_out (
174174 IntArrayRef padding,
175175 int64_t groups,
176176 bool is_channels_last,
177- const ideep::attr_t & op_attr) {
177+ const ideep::attr_t & op_attr,
178+ const ideep::prop_kind aprop_kind = ideep::prop_kind::forward) {
178179 auto memory_format = mkldnn_convolution_memory_format (input_t .ndimension (), is_channels_last);
179180 auto input = input_t .is_mkldnn () ? input_t : input_t .contiguous (memory_format);
180181 auto weight = weight_t .is_mkldnn () ? weight_t : weight_t .contiguous (memory_format);
181182 const ideep::tensor x = itensor_from_tensor (input, /* from_const_data_ptr*/ true );
182183 const ideep::tensor w = itensor_from_tensor (weight, /* from_const_data_ptr*/ true );
184+ auto algo = ideep::algorithm::convolution_direct;
183185 if (bias.defined ()) {
184186 const ideep::tensor b = itensor_from_tensor (bias, /* from_const_data_ptr*/ true );
185187 ideep::convolution_forward::compute_v3 (
@@ -194,7 +196,9 @@ static void _mkldnn_convolution_out (
194196 {padding.begin (), padding.end ()},
195197 groups,
196198 is_channels_last,
197- op_attr);
199+ op_attr,
200+ algo,
201+ aprop_kind);
198202 } else {
199203 ideep::convolution_forward::compute_v3 (
200204 x,
@@ -207,7 +211,9 @@ static void _mkldnn_convolution_out (
207211 {padding.begin (), padding.end ()},
208212 groups,
209213 is_channels_last,
210- op_attr);
214+ op_attr,
215+ algo,
216+ aprop_kind);
211217 }
212218}
213219
@@ -223,7 +229,8 @@ static Tensor _mkldnn_convolution(
223229 std::string_view attr = " none" ,
224230 torch::List<std::optional<at::Scalar>> scalars =
225231 torch::List<std::optional<at::Scalar>>(),
226- std::optional<std::string_view> algorithm = std::nullopt) {
232+ std::optional<std::string_view> algorithm = std::nullopt,
233+ const ideep::prop_kind aprop_kind = ideep::prop_kind::forward) {
227234 ideep::attr_t op_attr = ideep::attr_t ();
228235 if (attr != " none" ) {
229236 auto it = fusion_unary_attr_map ().find (attr);
@@ -265,7 +272,8 @@ static Tensor _mkldnn_convolution(
265272 padding_expanded,
266273 groups,
267274 use_channels_last,
268- op_attr);
275+ op_attr,
276+ aprop_kind);
269277
270278 if (input_t .is_mkldnn ()) {
271279 return MKLDNNTensor (y, input_t .options ());
@@ -310,6 +318,14 @@ Tensor mkldnn_convolution_pointwise(
310318 c10::impl::ExcludeDispatchKeyGuard edkg (c10::autograd_dispatch_keyset);
311319 bool use_channels_last =
312320 weight_t .is_mkldnn () || mkldnn_conv_use_channels_last (input_t , weight_t );
321+ auto aprop_kind = ideep::prop_kind::forward;
322+ bool maybe_backward = GradMode::is_enabled () &&
323+ (input_t .requires_grad () || weight_t .requires_grad () ||
324+ (bias_opt.has_value () && bias_opt->defined () &&
325+ bias_opt->requires_grad ()));
326+ if (!maybe_backward) {
327+ aprop_kind = ideep::prop_kind::forward_inference;
328+ }
313329 return _mkldnn_convolution (
314330 input_t ,
315331 weight_t ,
@@ -321,7 +337,8 @@ Tensor mkldnn_convolution_pointwise(
321337 use_channels_last,
322338 attr,
323339 scalars,
324- algorithm);
340+ algorithm,
341+ aprop_kind);
325342}
326343
327344
0 commit comments