
Commit ba5d45d

shiyang-weng authored and pytorchmergebot committed
Add assertion to align with cuda (pytorch#153233)
Fixes pytorch#153137. Aligns the batch_norm_cpu_out assertion with [batch_norm_cuda_out](https://github.com/pytorch/pytorch/blob/a7ea115494ab7fa5d8fbd260f295a737b946e00b/aten/src/ATen/native/cuda/Normalization.cu#L436).

Pull Request resolved: pytorch#153233
Approved by: https://github.com/malfet
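As a point of reference (not part of this commit), here is a minimal libtorch sketch of how the aligned check is expected to surface on CPU. The `at::batch_norm` entry point and the assumption that it reaches `batch_norm_cpu_out` in training mode are mine, not stated by the diff:

```cpp
// Hypothetical repro (assumes a libtorch build and that at::batch_norm on CPU
// reaches batch_norm_cpu_out in training mode).
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor input = at::randn({2, 3, 4, 4});
  at::Tensor running_mean = at::zeros({3});

  try {
    // running_mean is provided but running_var is not, so inside
    // batch_norm_cpu_out has_running_mean != has_running_var and the new
    // TORCH_CHECK_VALUE should fire.
    at::batch_norm(input, /*weight=*/{}, /*bias=*/{},
                   running_mean, /*running_var=*/{},
                   /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5,
                   /*cudnn_enabled=*/false);
  } catch (const std::exception& e) {
    // Expected to mention: "running_mean and running_var must either both be
    // None or neither be None"
    std::cout << e.what() << std::endl;
  }
  return 0;
}
```

Because the check uses `TORCH_CHECK_VALUE` (which throws `c10::ValueError`), the same mismatch from Python should now raise a `ValueError` on CPU, matching the CUDA and MPS backends, rather than failing in a backend-specific way.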


3 files changed (+13, -4 lines)


aten/src/ATen/native/Normalization.cpp

Lines changed: 5 additions & 0 deletions
@@ -770,6 +770,11 @@ std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
 
 std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cpu_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
     bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
+  const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined());
+  const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined());
+  TORCH_CHECK_VALUE(has_running_mean == has_running_var,
+      "running_mean and running_var must either both be None or neither be None");
+
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;

aten/src/ATen/native/cuda/Normalization.cu

Lines changed: 2 additions & 1 deletion
@@ -435,7 +435,8 @@ void batch_norm_calc_invstd(const Tensor& out_invstd, const Tensor& running_var,
 std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cuda_out(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
   const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined());
   const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined());
-  TORCH_CHECK(has_running_mean == has_running_var);
+  TORCH_CHECK_VALUE(has_running_mean == has_running_var,
+      "running_mean and running_var must either both be None or neither be None");
 
   if (train) {
     batch_norm_mean_var(self, save_mean, save_invstd);

aten/src/ATen/native/mps/operations/Normalization.mm

Lines changed: 6 additions & 3 deletions
@@ -103,7 +103,8 @@ static void get_shapes(MPSShape* input_shape_readonly,
 
   const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined());
   const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined());
-  TORCH_CHECK(has_running_mean == has_running_var);
+  TORCH_CHECK_VALUE(has_running_mean == has_running_var,
+      "running_mean and running_var must either both be None or neither be None");
 
   const bool has_weight = (weight_opt.has_value() && weight_opt->defined());
   const bool has_bias = (bias_opt.has_value() && bias_opt->defined());
@@ -587,10 +588,12 @@ Check if running mean exists (maybe do this check before making graph)
 
   const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined());
   const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined());
-  TORCH_CHECK(has_running_mean == has_running_var);
+  TORCH_CHECK_VALUE(has_running_mean == has_running_var,
+      "running_mean and running_var must either both be None or neither be None");
   const bool has_save_mean = (save_mean_opt.has_value() && save_mean_opt->defined());
   const bool has_save_var = (save_var_opt.has_value() && save_var_opt->defined());
-  TORCH_CHECK(has_save_mean == has_save_var);
+  TORCH_CHECK_VALUE(has_save_mean == has_save_var,
+      "save_mean and save_var must either both be None or neither be None");
 
   const bool has_weight = (weight_opt.has_value() && weight_opt->defined());
 
