Skip to content

Commit f4e9aeb

Browse files
Revert "Update torch.masked.mean to upcast dtype for bool tensors (pytorch#139999)"
This reverts commit 0742b23. Reverted pytorch#139999 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it has a landrace and fails a test in trunk ([comment](pytorch#139999 (comment)))
1 parent 168c2cb commit f4e9aeb

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

torch/masked/_ops.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,16 +1384,8 @@ def mean(
13841384
{reduction_args}
13851385
13861386
{reduction_example}"""
1387-
dtype_source = "Optional"
13881387
if dtype is None:
13891388
dtype = input.dtype
1390-
dtype_source = "Input"
1391-
1392-
if not (dtype.is_floating_point or dtype.is_complex):
1393-
raise ValueError(
1394-
f"mean(): Could not infer output dtype. {dtype_source} dtype must be either "
1395-
f"a floating point or complex dtype. Got: {dtype}"
1396-
)
13971389
if input.layout == torch.strided:
13981390
if mask is None:
13991391
# TODO: compute count analytically

torch/testing/_internal/opinfo/definitions/_masked.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,8 +769,26 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
769769
supports_forward_ad=True,
770770
supports_fwgrad_bwgrad=True,
771771
promotes_int_to_float=True,
772-
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
772+
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
773773
skips=(
774+
DecorateInfo(
775+
unittest.expectedFailure,
776+
"TestReductions",
777+
"test_ref_duplicate_values",
778+
dtypes=(torch.bool,),
779+
),
780+
DecorateInfo(
781+
unittest.expectedFailure,
782+
"TestReductions",
783+
"test_reference_masked",
784+
dtypes=(torch.bool,),
785+
),
786+
DecorateInfo(
787+
unittest.expectedFailure,
788+
"TestReductions",
789+
"test_ref_small_input",
790+
dtypes=(torch.bool,),
791+
),
774792
DecorateInfo(
775793
unittest.expectedFailure,
776794
"TestNormalizeOperators",

0 commit comments

Comments
 (0)