
Commit 7dcc77e

eellison authored and pytorchmergebot committed
Turn on new tiling by default (pytorch#154768)
Turning on in fbcode to come. Also updates `max_tiles` to have a default value of None. The existing tiling logic doesn't really handle max_tiles=3 well, but we do in the new tiling logic, so we default to 3 in the new logic and 2 elsewhere unless max_tiles has been explicitly set. TB runners have been very unstable recently (do we need to bump batch size ?) but e.g. for a [recent torchbench](https://hud.pytorch.org/benchmark/torchbench/inductor_with_cudagraphs?dashboard=torchinductor&startTime=Tue,%2027%20May%202025%2015:38:26%20GMT&stopTime=Tue,%2003%20Jun%202025%2015:38:26%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=gh/eellison/803/head&lCommit=8480c220db4eb3c9e2b58d85a698d0a7113a6e37&rBranch=main&rCommit=0cd18ba1ca35d87916723d445c06664615dcae12) inference run we had 15 models with a lower execution time (i.g. green) and 2 models with higher (i.e.. red) I am doing another run and will update here. Dynamic shapes is not yet turned on because there are a lot of fixes to be done in splitting that don't work yet.. See: ``` (Pdb) p expr ((s25*s85)//32) (Pdb) p FloorDiv(expr, expr) ((s25*s85)//(32*(((s25*s85)//32)))) ``` and also - unbacked shape is not multiple of itself. Pull Request resolved: pytorch#154768 Approved by: https://github.com/jansel
1 parent a85ad55 · commit 7dcc77e

File tree

5 files changed: +51 −22 lines changed


test/inductor/test_loop_ordering.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -526,7 +526,7 @@ def f(x):
         "triton.unique_kernel_names": True,
         "loop_ordering_after_fusion": True,
         "triton.max_tiles": 3,
-        "test_configs.global_tiling_analysis": True,
+        "triton.coalesce_tiling_analysis": True,
     }
 )
 @instantiate_parametrized_tests
@@ -798,13 +798,14 @@ def fn(nodes):
         # coalesce twice as many bytes as first dimension
         # if not downcasted
         # if downcasted, should be equal, bc larger dtype size
+        # we also weight writes x 2
         cont_reads = coalesce_analysis.coalesced_by_var[i_vars[1]]
         t_reads = coalesce_analysis.coalesced_by_var[i_vars[0]]

         if not downcast_transposed_v:
-            self.assertEqual(cont_reads, t_reads * 2)
+            self.assertEqual(cont_reads, t_reads * 3)
         else:
-            self.assertEqual(cont_reads, t_reads)
+            self.assertEqual(cont_reads, t_reads * 1.5)

         return nodes

@@ -908,8 +909,7 @@ def forward(permute):
     {
         "triton.unique_kernel_names": True,
         "loop_ordering_after_fusion": True,
-        "test_configs.global_tiling_analysis": True,
-        "triton.max_tiles": 3,
+        "triton.coalesce_tiling_analysis": True,
     }
 )
 @instantiate_parametrized_tests
```
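The new expected ratios follow from the write weighting noted in the added comment: writes now count twice, while dtype sizes behave as before. A back-of-envelope check (weights taken from the test comments; byte counts are illustrative):

```python
read_w, write_w = 1, 2  # writes weighted x2 by the coalescing analysis

# Not downcasted: the contiguous var coalesces one read and one write of equal
# size; the transposed var coalesces only its read.
cont = read_w + write_w  # 3
t = read_w               # 1
assert cont == t * 3     # was t * 2 before write weighting

# Downcasted: the transposed read uses a dtype twice as large, doubling its bytes.
cont = read_w + write_w  # 3
t = read_w * 2           # 2
assert cont == t * 1.5   # was t * 1 before write weighting
```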

test/inductor/test_torchinductor.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -14133,6 +14133,8 @@ def f(x, mask):
         # it does not move the tensor constructor to cuda and keeps it on CPU.
         self.assertFalse("empty_strided_cuda(()" in code)

+    # only uncoalesced without this :)
+    @config.patch("triton.coalesce_tiling_analysis", False)
     @config.patch("triton.use_block_ptr", False)
     def test_evict_last_non_coalesced_loads(self):
         @torch.compile
@@ -14183,6 +14185,7 @@ def f(a, b):
         )

     @config.patch("triton.use_block_ptr", True)
+    @config.patch("triton.coalesce_tiling_analysis", False)
     def test_evict_last_non_coalesced_loads_block_ptr(self):
         @torch.compile
         def f(a, b):
```
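These two tests assert on eviction behavior for uncoalesced loads, so they pin the new analysis off (as the ":)" comment notes, the analysis would otherwise coalesce the loads under test). A sketch of the patch pattern, assuming Inductor's `config.patch` also works as a context manager as it does in other tests:

```python
from torch._inductor import config

# Decorator form, as in the diff: the config value is restored after the test runs.
@config.patch("triton.coalesce_tiling_analysis", False)
def test_keeps_uncoalesced_loads(self):
    ...

# Equivalent context-manager form for scoping a single compile:
with config.patch("triton.coalesce_tiling_analysis", False):
    ...
```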

torch/_inductor/codegen/simd.py

Lines changed: 20 additions & 11 deletions
```diff
@@ -86,6 +86,11 @@
 all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"])


+def get_max_tiles(default: int = 2) -> int:
+    max_tiles = torch._inductor.config.triton.max_tiles
+    return max_tiles if max_tiles is not None else default
+
+
 @dataclasses.dataclass
 class IterationRanges:
     """
@@ -1354,7 +1359,7 @@ def codegen_node(

         nodes: list[scheduler.SchedulerNode] = node.get_nodes()  # type: ignore[assignment]

-        if torch._inductor.config.test_configs.global_tiling_analysis:
+        if torch._inductor.config.triton.coalesce_tiling_analysis:
             coalesce_analysis = analyze_memory_coalescing(node)
         else:
             coalesce_analysis = None
@@ -1993,7 +1998,7 @@ def get_nd_tilings(

         # Flatten leading dimensions, assigning labels to each dim.
         for node_tiling in node_tilings:
-            num_leading_dims = max(0, len(node_tiling) - config.triton.max_tiles)
+            num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2))
             first_trailing_dim = num_leading_dims + 1
             collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim])
             collapsed_splits = (collapsed_leading_dim,) + tuple(
@@ -2165,7 +2170,7 @@ def process_node_vars(
             )
         )

-    if torch._inductor.config.triton.max_tiles == 3 and reduction_numel == 1:
+    if get_max_tiles(default=3) == 3 and reduction_numel == 1:
         for vars_to_use in itertools.combinations(overlapping_iter_vars, 2):
             score_split.append(
                 (
@@ -2187,13 +2192,16 @@ def process_node_vars(

     # add a slight penalty for longer tilings that dont increase score much,
     # and are poor sizes
-    additional_tiling_penalty = 1.025
+    bad_size_additional_tiling_penalty = 1.025
+    good_size_tiling_penalty = 1.005

     def score_mod(t):
         score_factor = 1.0
         for tile_size in t[0].tiling.values():
             if not CandidateTiling.is_good_size(tile_size):
-                score_factor = score_factor / additional_tiling_penalty
+                score_factor = score_factor / bad_size_additional_tiling_penalty
+            else:
+                score_factor = score_factor / good_size_tiling_penalty

         return -t[0].score * score_factor

@@ -2204,7 +2212,7 @@ def score_mod(t):
     ):
         # we always include default reduction numel == 1, dont include
         tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0)
-        if tiling_len > torch._inductor.config.triton.max_tiles:
+        if tiling_len > get_max_tiles(default=3):
             perf_hint_log.info(
                 "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles "
                 "set to %s. Consider increasing",
@@ -2289,16 +2297,17 @@ def get_tiling_and_scores(

         # # TODO: enable by default
         if (
-            torch._inductor.config.test_configs.global_tiling_analysis
+            torch._inductor.config.triton.coalesce_tiling_analysis
             and coalesce_analysis
+            and not config.triton.prefer_nd_tiling
         ):
             return cls.compute_tiling_strategy(
                 node_schedule, numel, reduction_numel, coalesce_analysis
             )

-        if (
-            not is_pointwise and not config.triton.tile_reductions
-        ) or config.triton.max_tiles <= 1:
+        if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles(
+            default=2
+        ) <= 1:
             # Emit a perf hint in case we miss an opportunity to tile a reduction.
             if perf_hint_log.level <= logging.WARNING:
                 for node in EnableReduction.filter(node_schedule):
@@ -2333,7 +2342,7 @@ def get_tiling_and_scores(
             for candidate_tiling, score in candidate_tiles.most_common()
         ]

-        if config.triton.max_tiles >= 3 and is_pointwise:
+        if get_max_tiles(default=2) >= 3 and is_pointwise:
             # Consider adding a third dimension of tiling, but only
             # when a1 is a multiple of b1; otherwise, you have a lot
             # of stragglers which is annoying to generate code for.
```
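To see how the two split penalties interact, here is a self-contained sketch of the `score_mod` logic above with made-up candidates; `is_good_size` is a stand-in (assumed here to mean a power of two, not necessarily `CandidateTiling.is_good_size`'s real rule):

```python
import dataclasses

bad_size_additional_tiling_penalty = 1.025
good_size_tiling_penalty = 1.005

def is_good_size(n: int) -> bool:
    # Stand-in assumption: powers of two count as "good" tile sizes.
    return n > 0 and n & (n - 1) == 0

@dataclasses.dataclass
class Cand:
    tiling: dict
    score: float

def score_mod(t):
    # Mirrors the diff: every tile dim pays a small penalty, bad sizes pay more,
    # so a longer tiling must earn a real score increase to win the sort.
    score_factor = 1.0
    for tile_size in t[0].tiling.values():
        if not is_good_size(tile_size):
            score_factor = score_factor / bad_size_additional_tiling_penalty
        else:
            score_factor = score_factor / good_size_tiling_penalty
    return -t[0].score * score_factor

candidates = [
    (Cand({"x": 1024}, 100.0), None),          # 1-D tiling, good size
    (Cand({"x": 1024, "y": 7}, 100.5), None),  # 2-D, marginally better raw score, bad size
]
best = min(candidates, key=score_mod)
print(best[0].tiling)  # {'x': 1024}: the marginal 2-D split loses after penalties
```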

torch/_inductor/config.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -1115,12 +1115,23 @@ class triton:
     # Always load full blocks (rather than broadcasting inside the block)
     dense_indexing = False

+    # TODO - enable by default
+    coalesce_tiling_analysis: bool = (
+        os.environ.get(
+            "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0"
+        )
+        == "1"
+    )
+
     # limit tiling dimensions
     # - max_tiles=1 disables tiling
-    # - max_tiles=2 is the default
+    # - max_tiles=2
     # - max_tiles=3 is experimental and may have bugs
     # higher values are unsupported
-    max_tiles = 2
+
+    # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise.
+    # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes.
+    max_tiles: Optional[int] = None

     # Prefer higher dimensional tilings. This simplifies indexing expressions, making
     # it easier to identify block pointers.
@@ -1681,9 +1692,6 @@ class test_configs:

     graphsafe_rng_func_ignores_fallback_random = False

-    # TODO - temporary config before enabled by default
-    global_tiling_analysis: bool = False
-

 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
```
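With these defaults the analysis can be steered per environment. A small usage sketch against the fields added above (the env var is read once when the config module is imported, so it must be set before importing torch):

```python
import os

# Opt out before importing torch; the config field reads this at import time.
os.environ["TORCHINDUCTOR_COALESCE_TILING_ANALYSIS"] = "0"

import torch._inductor.config as inductor_config

# ...or toggle programmatically after import:
inductor_config.triton.coalesce_tiling_analysis = False  # fall back to old heuristics
inductor_config.triton.max_tiles = 2  # an explicit value overrides the None default everywhere
```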

torch/_inductor/tiling_utils.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -621,6 +621,9 @@ class VarTiling:

 @dataclasses.dataclass(frozen=True)
 class CoalesceVarAnalysis:
+    # Var -> Memory Score - not strictly the amount of memory
+    # because we multiply writes x2
+    # TODO: separate into dataclass that olds mem, dtype, is_write
     coalesced_by_var: dict[sympy.Expr, int]

     norm_read_writes: FusedNormalizedReadsWrites
@@ -656,7 +659,10 @@ def analyze_memory_coalescing(
     coalesced_by_var: dict[sympy.Symbol, int] = Counter()
     uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

-    for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
+    for is_read, (memory_expr, buf_names) in itertools.chain(
+        ((True, item) for item in reads.items()),
+        ((False, item) for item in writes.items()),
+    ):
         # skip memory deps with indirect vars - todo: better handling
         indirect_expr = bool(
             memory_expr.free_symbols - norm_read_writes.var_ranges.keys()
@@ -676,6 +682,9 @@ def analyze_memory_coalescing(
         if buf := V.graph.try_get_buffer(buf_name):
             byte_multipler += buf.dtype.itemsize

+        # coalesced writes more important
+        byte_multipler *= 1 if is_read else 2
+
         if maybe_coalesced_var:
             coalesced_by_var[maybe_coalesced_var] += size * byte_multipler
         else:
```
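The read/write tagging above is a small generic idiom: chain the two dicts with an origin flag so one loop body can weight writes double. A standalone version with toy stand-ins for the memory expressions:

```python
import itertools
from collections import Counter

reads = {"addr_a": ["buf0"], "addr_b": ["buf1"]}  # toy stand-ins for memory exprs
writes = {"addr_c": ["buf2"]}

scores: Counter = Counter()
for is_read, (memory_expr, buf_names) in itertools.chain(
    ((True, item) for item in reads.items()),
    ((False, item) for item in writes.items()),
):
    byte_multiplier = 4  # e.g. float32 itemsize; the real code sums buffer itemsizes
    byte_multiplier *= 1 if is_read else 2  # coalesced writes count double
    scores[memory_expr] += byte_multiplier

print(scores)  # addr_c (the write) scores 8, each read scores 4
```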
