Commit 5215885
re-use FloorDiv for RShift (pytorch#145898)
I encountered this C++ compilation error.
```
579 | int64_t var_6 = (static_cast<int64_t>(std::floor((1.0/2.0)*u0)) | static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0))))) | std::floor((1.0/16.0)*(static_cast<int64_t>(std::floor((1.0/2.0)*u0)) | static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0))))));
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ^ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| | |
| int64_t {aka long int} double
```
Then, I figured out where this std::floor came from with the help of Bob's guard provenance tool. It comes from RShift which is used in `triton.next_power_of_2`.
---
Before, we used `std::floor`
```
int64_t var_6 = (
static_cast<int64_t>(std::floor((1.0/2.0)*u0)) |
static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0)))))
| std::floor((1.0/16.0)*(static_cast<int64_t>(std::floor((1.0/2.0)*u0)) # no cast to int here.
| static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0))))));
```
Now, we use `c10::div_floor_integer` instead
```
int64_t var_6 = (
(c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(2L))) |
(c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(8L)))) |
(c10::div_floor_integer(static_cast<int64_t>((c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(2L)))
| (c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(8L)))), static_cast<int64_t>(16L)));
```
Pull Request resolved: pytorch#145898
Approved by: https://github.com/desertfire, https://github.com/bobrenjc93
ghstack dependencies: pytorch#1458021 parent 3df961d commit 5215885
File tree
2 files changed
+28
-1
lines changed- test/inductor
- torch/utils/_sympy
2 files changed
+28
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2176 | 2176 | | |
2177 | 2177 | | |
2178 | 2178 | | |
| 2179 | + | |
| 2180 | + | |
| 2181 | + | |
| 2182 | + | |
| 2183 | + | |
| 2184 | + | |
| 2185 | + | |
| 2186 | + | |
| 2187 | + | |
| 2188 | + | |
| 2189 | + | |
| 2190 | + | |
| 2191 | + | |
| 2192 | + | |
| 2193 | + | |
| 2194 | + | |
| 2195 | + | |
| 2196 | + | |
| 2197 | + | |
| 2198 | + | |
| 2199 | + | |
| 2200 | + | |
| 2201 | + | |
| 2202 | + | |
| 2203 | + | |
| 2204 | + | |
| 2205 | + | |
2179 | 2206 | | |
2180 | 2207 | | |
2181 | 2208 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
578 | 578 | | |
579 | 579 | | |
580 | 580 | | |
581 | | - | |
| 581 | + | |
582 | 582 | | |
583 | 583 | | |
584 | 584 | | |
| |||
0 commit comments