Skip to content

Commit 5215885

Browse files
ColinPeppler authored and pytorchmergebot committed
re-use FloorDiv for RShift (pytorch#145898)
I encountered this C++ compilation error. ``` 579 | int64_t var_6 = (static_cast<int64_t>(std::floor((1.0/2.0)*u0)) | static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0))))) | std::floor((1.0/16.0)*(static_cast<int64_t>(std::floor((1.0/2.0)*u0)) | static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0)))))); | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ^ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | | | | int64_t {aka long int} double ``` Then, I figured out where this std::floor came from with the help of Bob's guard provenance tool. It comes from RShift which is used in `triton.next_power_of_2`. --- Before, we used `std::floor` ``` int64_t var_6 = ( static_cast<int64_t>(std::floor((1.0/2.0)*u0)) | static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0))))) | std::floor((1.0/16.0)*(static_cast<int64_t>(std::floor((1.0/2.0)*u0)) # no cast to int here. | static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0)))))); ``` Now, we use `c10::div_floor_integer` instead ``` int64_t var_6 = ( (c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(2L))) | (c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(8L)))) | (c10::div_floor_integer(static_cast<int64_t>((c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(2L))) | (c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(8L)))), static_cast<int64_t>(16L))); ``` Pull Request resolved: pytorch#145898 Approved by: https://github.com/desertfire, https://github.com/bobrenjc93 ghstack dependencies: pytorch#145802
1 parent 3df961d commit 5215885

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,6 +2176,33 @@ def forward(self, x):
21762176
)
21772177
self.check_model(model, example_inputs)
21782178

2179+
def test_triton_next_power_of_2(self):
    """Regression test: compile a model whose user-defined Triton kernel launch
    depends on ``triton.next_power_of_2`` of a data-dependent value.

    ``next_power_of_2`` lowers to RShift ops in the symbolic shape expressions;
    this exercises the C++ codegen path that must emit integer floor-division
    (``c10::div_floor_integer``) rather than ``std::floor`` on doubles.
    """
    # Kernel launch requires a real GPU backend.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Model(torch.nn.Module):
        def forward(self, a, b, lengths):
            n_elements = a.numel()
            out = torch.empty_like(a)
            # max_len is data-dependent (unbacked symint after .max()/int()).
            max_len = int(lengths.max())
            scaling_factor = triton.next_power_of_2(max_len)
            add_kernel_with_scaling[(n_elements,)](
                a,
                b,
                out,
                n_elements,
                scaling_factor,
                BLOCK_SIZE=16,
            )
            return out

    example_inputs = (
        torch.randn(2, device=self.device),
        torch.randn(2, device=self.device),
        torch.arange(end=4, device=self.device),
    )
    self.check_model(Model(), example_inputs)
2205+
21792206
@common_utils.parametrize("grid_type", [1, 2, 3])
21802207
@common_utils.parametrize("num_dims", [1, 2])
21812208
@common_utils.parametrize("dynamic", [False, True])

torch/utils/_sympy/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ class RShift(sympy.Function):
578578
def eval(cls, base, shift):
579579
if shift < 0:
580580
raise ValueError("negative shift count")
581-
return base // 2**shift
581+
return FloorDiv(base, 2**shift)
582582

583583

584584
class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]

0 commit comments

Comments
 (0)