@@ -46,20 +46,6 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
4646 TORCH_INTERNAL_ASSERT (false );
4747};
4848
49- at::Tensor
50- mse_loss_backward_batch_rule (
51- const at::Tensor& grad_output,
52- const at::Tensor& self,
53- const at::Tensor& target,
54- int64_t reduction) {
55-
56- const auto result = 2 . * (self - target) * grad_output;
57- if (reduction == Reduction::Mean) {
58- return result / self.numel ();
59- }
60- return result;
61- };
62-
6349static Tensor apply_loss_reduction (const at::Tensor& unreduced, int64_t reduction) {
6450 if (reduction == at::Reduction::Mean) {
6551 return unreduced.mean ();
@@ -296,7 +282,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
296282 m.impl (" nll_loss_backward" , nll_loss_backward_decomposition);
297283 m.impl (" nll_loss2d_backward" , nll_loss_backward_decomposition);
298284 VMAP_SUPPORT (mse_loss, mse_loss_batch_rule);
299- m. impl ( " mse_loss_backward " , mse_loss_backward_batch_rule);
285+ // mse_loss_backwards uses a decomposition for its batch rule
300286 m.impl (" binary_cross_entropy" , binary_cross_entropy_plumbing);
301287 m.impl (" binary_cross_entropy_backward" , binary_cross_entropy_backward_plumbing);
302288}
0 commit comments