Commit 93d7f90

SamGinzburg authored and pytorchmergebot committed
[inductor] getting AOT inductor to treat None args correctly (pytorch#139114)
Differential Revision: [D65102228](https://our.internmc.facebook.com/intern/diff/D65102228)

Pull Request resolved: pytorch#139114
Approved by: https://github.com/aakhundov
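For context (an editorial sketch, not part of the commit): when Inductor drops None arguments from a user-defined Triton kernel call, any positional indices it recorded for the surviving constexpr/equal-to-1 arguments must be shifted down to match the shorter argument list. A minimal standalone illustration, using a hypothetical recompute_indices helper:

    # Hypothetical helper, not code from this commit: re-derive the positions
    # of "special" args (constexpr / equal-to-1) after None args are dropped.
    def recompute_indices(args, special_indices):
        removed = {i for i, a in enumerate(args) if a is None}
        shift = 0
        out = []
        for i, a in enumerate(args):
            if i in removed:
                shift += 1  # every removed slot pushes later indices down by one
                continue
            if i in special_indices:
                out.append(i - shift)
        return out

    # [None, "x", 1]: the constant 1 sits at index 2, but once the None is
    # dropped it sits at index 1, so the recorded index must be adjusted.
    assert recompute_indices([None, "x", 1], {2}) == [1]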
1 parent 8b08559 commit 93d7f90

File tree (4 files changed: +98, -7 lines)

- test/inductor/test_aot_inductor.py
- torch/_inductor/codegen/wrapper.py
- torch/_inductor/ir.py
- torch/_inductor/runtime/triton_heuristics.py

test/inductor/test_aot_inductor.py (51 additions, 0 deletions)

@@ -50,6 +50,7 @@
 
 if HAS_CUDA:
     import triton  # @manual
+    from triton import language as tl
 
     from torch.testing._internal.triton_utils import (
         add_kernel,

@@ -3816,6 +3817,56 @@ def forward(self, x):
 
         self.check_model(Model(), example_inputs)
 
+    def test_none_args_aot_codegen(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        @triton.autotune(
+            configs=[
+                triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
+                triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+            ],
+            key=["n_elements"],
+        )
+        @triton.jit
+        def sin_kernel(
+            in_ptr0,
+            out_ptr,
+            # We want to include an arg known to be 1 at compile time, because
+            # we remove None args from the arg list (which changes the
+            # eq_1/constexpr arg indices) and want to recompute those correctly.
+            EQ_1_ARG,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            if in_ptr0 is not None:
+                x = tl.load(in_ptr0 + offsets, mask=mask)
+            else:
+                x = 0.0
+            output = tl.sin(x) + EQ_1_ARG
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def sin_triton(x, out):
+            n_elements = out.numel()
+            sin_kernel[(n_elements,)](x, out, 1, n_elements)
+            return out
+
+        x = torch.randn(65, device=self.device)
+        out = torch.empty_like(x)
+
+        not_none_inputs = (x, out)
+        none_inputs = (None, out)
+
+        # AOTI compilation specializes on either None or non-None inputs,
+        # so we have to check twice here.
+
+        self.check_model(sin_triton, none_inputs)
+        self.check_model(sin_triton, not_none_inputs)
+
 
 class AOTInductorLoggingTest(LoggingTestCase):
     @make_logging_test(dynamic=logging.DEBUG)
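Assuming a CUDA-enabled build, the new test should be runnable in isolation, e.g.:

    pytest test/inductor/test_aot_inductor.py -k test_none_args_aot_codegen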

torch/_inductor/codegen/wrapper.py (5 additions, 1 deletion)

@@ -844,7 +844,11 @@ def generate_user_defined_triton_kernel(
             for arg in raw_args
         ]
         self.generate_kernel_call(
-            kernel_name, args, grid_fn=grid_fn, arg_types=arg_types, raw_args=raw_args
+            kernel_name,
+            args,
+            grid_fn=grid_fn,
+            arg_types=arg_types,
+            raw_args=raw_args,
         )
 
     def _generate_tma_descriptor_call(self, desc, apply_size_hints=False):

torch/_inductor/ir.py (37 additions, 4 deletions)

@@ -5416,17 +5416,50 @@ def codegen(self, wrapper):
             2. The arg is not tl.constexpr so we have to remove it
         """
         constexpr_indices_set = set(constexpr_indices)
+        REMOVED = object()
         raw_args = [
-            arg
-            for idx, arg in enumerate(raw_args)
+            (idx, arg)
             if (arg is not None) or (arg is None and idx in constexpr_indices_set)
+            else (idx, REMOVED)
+            for idx, arg in enumerate(raw_args)
         ]
+        removed_none_args = [idx for idx, val in raw_args if val == REMOVED]
+        raw_args = [val for idx, val in raw_args if val != REMOVED]
+
+        # We have to compute the constexpr indices for the new, filtered
+        # raw_args, and we also have to adjust equal_to_1.
+        if removed_none_args:
+            eq1_indices_set = set(triton_meta["configs"][0].equal_to_1)
+            constexpr_indices = []
+            equal_to_1 = []
+            index_shift = 0
+            for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
+                # Every time we encounter an idx we removed, shift by one to
+                # account for it. For example, with [None, const X]:
+                # iter 1:
+                #   None was removed, so index_shift=1
+                # iter 2:
+                #   X is const at idx=1, but its adjusted idx is 0, because the None was removed
+                if idx in removed_none_args:
+                    index_shift += 1
+                    continue
+                arg_index = kernel.arg_names.index(kwarg)
+                if arg_index in kernel.constexprs:
+                    constexpr_indices.append(idx - index_shift)
+                if arg_index in eq1_indices_set:
+                    equal_to_1.append(idx - index_shift)
+
+            triton_meta["configs"][0].equal_to_1 = equal_to_1
 
         # Call to kernel
         self.codegen_comment(wrapper)
-
         wrapper.generate_user_defined_triton_kernel(
-            new_name, raw_args, self.grid, configs, triton_meta, constexpr_indices
+            new_name,
+            raw_args,
+            self.grid,
+            configs,
+            triton_meta,
+            constexpr_indices,
         )
 
     def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
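A note on the REMOVED = object() pattern above: a fresh object() is used as the sentinel precisely because None is a legitimate value here (a None constexpr arg is kept), so filtering on None itself would be ambiguous. Since a bare object() defines no custom __eq__, the val == REMOVED comparisons fall back to identity. A tiny self-contained illustration (toy data, not Inductor code):

    REMOVED = object()  # a fresh object() compares equal only to itself
    vals = [None, 1, REMOVED, 2]
    # None survives the filter; only the sentinel slot is dropped
    assert [v for v in vals if v != REMOVED] == [None, 1, 2]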

torch/_inductor/runtime/triton_heuristics.py (5 additions, 2 deletions)

@@ -501,11 +501,14 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool):
            so we use self.fn.constexprs instead.
         3. It isn't in the compile_meta signature
         """
-        none_args = set(compile_meta["constants"].keys())
         known_constants = {
             arg for i, arg in enumerate(self.fn.arg_names) if i in self.fn.constexprs
         }
-        none_args = none_args.difference(known_constants)
+        none_args = {
+            k
+            for k, v in compile_meta["constants"].items()
+            if v is None and k not in known_constants
+        }
         none_args = none_args.difference(set(compile_meta["signature"].keys()))
 
         call_args = [
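Previously, none_args was seeded with every key in compile_meta["constants"], which can also hold non-None specialization constants (such as equal-to-1 args), so those could be misclassified as removable None args; the rewritten comprehension keeps only genuinely-None, non-constexpr entries. A standalone sketch with toy metadata (invented names, not real compile output):

    # Toy stand-ins for Triton compile metadata; names are illustrative only.
    compile_meta = {
        "constants": {"in_ptr0": None, "EQ_1_ARG": 1, "BLOCK_SIZE": 32},
        "signature": {"out_ptr": "*fp32", "n_elements": "i32"},
    }
    known_constants = {"BLOCK_SIZE"}  # args declared tl.constexpr

    none_args = {
        k
        for k, v in compile_meta["constants"].items()
        if v is None and k not in known_constants
    }
    none_args = none_args.difference(set(compile_meta["signature"].keys()))
    assert none_args == {"in_ptr0"}  # EQ_1_ARG (value 1) is no longer swept in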
