Commit dd41a39
[MPS] Fix unary/binary ops for 2**32+ elem tensors (pytorch#155183)
By using `TensorIterator::with_32bit_indexing()` primitive
Add `bind_tensors` helper function that correctly sets up MPS tensors originating from TensorIterator
TODO: Add comments to bind_tensors as well asunit test, based on
```
python -c "import torch;print((torch.rand(1, 1024, 1024, dtype=torch.bfloat16, device='mps') + torch.rand(5000, 1, 1, dtype=torch.bfloat16, device='mps')).sin())"
```
Fixes pytorch#154828
Pull Request resolved: pytorch#155183
Approved by: https://github.com/cyyever, https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: pytorch#155150, pytorch#155178, pytorch#1551841 parent 05dd638 commit dd41a39
File tree
3 files changed
+50
-4
lines changed- aten/src/ATen/native/mps
- test
3 files changed
+50
-4
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
154 | 154 | | |
155 | 155 | | |
156 | 156 | | |
| 157 | + | |
157 | 158 | | |
158 | 159 | | |
159 | 160 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
971 | 971 | | |
972 | 972 | | |
973 | 973 | | |
| 974 | + | |
| 975 | + | |
| 976 | + | |
| 977 | + | |
| 978 | + | |
| 979 | + | |
| 980 | + | |
| 981 | + | |
| 982 | + | |
| 983 | + | |
| 984 | + | |
| 985 | + | |
| 986 | + | |
| 987 | + | |
| 988 | + | |
| 989 | + | |
| 990 | + | |
974 | 991 | | |
975 | 992 | | |
976 | 993 | | |
977 | | - | |
| 994 | + | |
| 995 | + | |
| 996 | + | |
| 997 | + | |
| 998 | + | |
| 999 | + | |
| 1000 | + | |
| 1001 | + | |
978 | 1002 | | |
979 | 1003 | | |
980 | 1004 | | |
| |||
997 | 1021 | | |
998 | 1022 | | |
999 | 1023 | | |
1000 | | - | |
| 1024 | + | |
1001 | 1025 | | |
1002 | 1026 | | |
1003 | 1027 | | |
| |||
1022 | 1046 | | |
1023 | 1047 | | |
1024 | 1048 | | |
1025 | | - | |
1026 | 1049 | | |
1027 | 1050 | | |
1028 | 1051 | | |
1029 | 1052 | | |
1030 | 1053 | | |
1031 | 1054 | | |
| 1055 | + | |
| 1056 | + | |
| 1057 | + | |
| 1058 | + | |
| 1059 | + | |
| 1060 | + | |
| 1061 | + | |
| 1062 | + | |
1032 | 1063 | | |
1033 | 1064 | | |
1034 | 1065 | | |
| |||
1062 | 1093 | | |
1063 | 1094 | | |
1064 | 1095 | | |
1065 | | - | |
| 1096 | + | |
1066 | 1097 | | |
1067 | 1098 | | |
1068 | 1099 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
7955 | 7955 | | |
7956 | 7956 | | |
7957 | 7957 | | |
| 7958 | + | |
| 7959 | + | |
| 7960 | + | |
| 7961 | + | |
| 7962 | + | |
| 7963 | + | |
| 7964 | + | |
| 7965 | + | |
| 7966 | + | |
| 7967 | + | |
| 7968 | + | |
| 7969 | + | |
| 7970 | + | |
| 7971 | + | |
7958 | 7972 | | |
7959 | 7973 | | |
7960 | 7974 | | |
| |||
0 commit comments