Commit a5051a9
Update torch.masked.mean to upcast dtype for bool tensors (pytorch#139999)
When calling `torch.masked.mean(...)` with a boolean tensor, the dtype is inferred to be bool. When the mean is being computed, the sum operator is used. When the sum operator is used with dtype=torch.bool, the result is clamped to True (1) leading to an incorrect mean being calculated.
The below example shows how the incorrect result occurs:
```
a = torch.tensor([True, True])
count = torch.sum(torch.ones(a.shape, dtype=torch.int64)) # 2
total = torch.sum(a, dtype=torch.bool) # True (1)
mean = total / count # 0.5
```
This PR upcasts the dtype used for the sumation to int32 in the case of bool tensors allowing for the correct result to be computed.
Pull Request resolved: pytorch#139999
Approved by: https://github.com/cpuhrsch1 parent 60a5050 commit a5051a9
File tree
3 files changed
+9
-20
lines changed- test/inductor
- torch
- masked
- testing/_internal/opinfo/definitions
3 files changed
+9
-20
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4882 | 4882 | | |
4883 | 4883 | | |
4884 | 4884 | | |
4885 | | - | |
4886 | 4885 | | |
4887 | 4886 | | |
4888 | 4887 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1384 | 1384 | | |
1385 | 1385 | | |
1386 | 1386 | | |
| 1387 | + | |
1387 | 1388 | | |
1388 | 1389 | | |
| 1390 | + | |
| 1391 | + | |
| 1392 | + | |
| 1393 | + | |
| 1394 | + | |
| 1395 | + | |
| 1396 | + | |
1389 | 1397 | | |
1390 | 1398 | | |
1391 | 1399 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
769 | 769 | | |
770 | 770 | | |
771 | 771 | | |
772 | | - | |
| 772 | + | |
773 | 773 | | |
774 | | - | |
775 | | - | |
776 | | - | |
777 | | - | |
778 | | - | |
779 | | - | |
780 | | - | |
781 | | - | |
782 | | - | |
783 | | - | |
784 | | - | |
785 | | - | |
786 | | - | |
787 | | - | |
788 | | - | |
789 | | - | |
790 | | - | |
791 | | - | |
792 | 774 | | |
793 | 775 | | |
794 | 776 | | |
| |||
0 commit comments