Skip to content

Commit a5051a9

Browse files
GeorgeWigley authored and pytorchmergebot committed
Update torch.masked.mean to upcast dtype for bool tensors (pytorch#139999)
When calling `torch.masked.mean(...)` with a boolean tensor, the dtype is inferred to be bool. When the mean is being computed, the sum operator is used. When the sum operator is used with dtype=torch.bool, the result is clamped to True (1) leading to an incorrect mean being calculated. The below example shows how the incorrect result occurs: ``` a = torch.tensor([True, True]) count = torch.sum(torch.ones(a.shape, dtype=torch.int64)) # 2 total = torch.sum(a, dtype=torch.bool) # True (1) mean = total / count # 0.5 ``` This PR upcasts the dtype used for the sumation to int32 in the case of bool tensors allowing for the correct result to be computed. Pull Request resolved: pytorch#139999 Approved by: https://github.com/cpuhrsch
1 parent 60a5050 commit a5051a9

File tree

3 files changed

+9
-20
lines changed

3 files changed

+9
-20
lines changed

test/inductor/test_cpu_repro.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4882,7 +4882,6 @@ def fn(x):
48824882
@requires_vectorization
48834883
def test_bool_reduction_vec(self):
48844884
for op in (
4885-
torch.masked.mean,
48864885
torch.any,
48874886
torch.min,
48884887
torch.max,

torch/masked/_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,8 +1384,16 @@ def mean(
13841384
{reduction_args}
13851385
13861386
{reduction_example}"""
1387+
dtype_source = "Optional"
13871388
if dtype is None:
13881389
dtype = input.dtype
1390+
dtype_source = "Input"
1391+
1392+
if not (dtype.is_floating_point or dtype.is_complex):
1393+
raise ValueError(
1394+
f"mean(): Could not infer output dtype. {dtype_source} dtype must be either "
1395+
f"a floating point or complex dtype. Got: {dtype}"
1396+
)
13891397
if input.layout == torch.strided:
13901398
if mask is None:
13911399
# TODO: compute count analytically

torch/testing/_internal/opinfo/definitions/_masked.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -769,26 +769,8 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
769769
supports_forward_ad=True,
770770
supports_fwgrad_bwgrad=True,
771771
promotes_int_to_float=True,
772-
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
772+
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
773773
skips=(
774-
DecorateInfo(
775-
unittest.expectedFailure,
776-
"TestReductions",
777-
"test_ref_duplicate_values",
778-
dtypes=(torch.bool,),
779-
),
780-
DecorateInfo(
781-
unittest.expectedFailure,
782-
"TestReductions",
783-
"test_reference_masked",
784-
dtypes=(torch.bool,),
785-
),
786-
DecorateInfo(
787-
unittest.expectedFailure,
788-
"TestReductions",
789-
"test_ref_small_input",
790-
dtypes=(torch.bool,),
791-
),
792774
DecorateInfo(
793775
unittest.expectedFailure,
794776
"TestNormalizeOperators",

0 commit comments

Comments
 (0)