Commit a2b0b26

bdhirsh authored and pytorchmergebot committed
inductor codecache: include private inductor configs in cache key (pytorch#153672)
Fixes pytorch/torchtitan#1185

It looks like inductor's logic for including inductor configs in the cache key skips configs with a leading underscore by default. This came up in torchtitan: there is an asyncTP pipelining pass in inductor gated by a private config, and by not caching on that config we were attempting to use asyncTP when we shouldn't have been.

I'm not sure how worried we should be about the blast radius of this change. On the one hand, (1) it technically fixes any silent correctness issues in the cache around any other private inductor configs (it looks like there are a few). On the other hand, (2) there is some risk that we are now including some "harmless" configs in the key, which may increase false negatives. I do see that there is an explicit list of "configs we want to ignore for caching" (`_save_config_ignore`), so my hope is that all harmless configs are already encapsulated there.

Pull Request resolved: pytorch#153672
Approved by: https://github.com/oulgen
1 parent 5264f8c commit a2b0b26
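
For intuition, here is a minimal self-contained sketch of the behavior the commit message describes. This is a toy model, not inductor's actual hashing code: `toy_cache_key` and the `max_autotune` entry are illustrative stand-ins; `_micro_pipeline_tp` is the real private flag used in the new test below.

import hashlib
import pickle


def toy_cache_key(configs: dict, ignore_private: bool = True) -> str:
    # Stand-in for the portable-config portion of inductor's cache key:
    # with ignore_private=True, underscore-prefixed ("private") entries are
    # dropped before hashing, so two runs that differ only in a private flag
    # produce the same key.
    kept = {
        k: v
        for k, v in sorted(configs.items())
        if not (ignore_private and k.startswith("_"))
    }
    return hashlib.sha256(pickle.dumps(kept)).hexdigest()


base = {"max_autotune": False, "_micro_pipeline_tp": False}
async_tp = {"max_autotune": False, "_micro_pipeline_tp": True}

# Old behavior: the private flag is skipped, so the keys collide.
assert toy_cache_key(base) == toy_cache_key(async_tp)
# Behavior after this commit: private configs participate, so the keys differ.
assert toy_cache_key(base, ignore_private=False) != toy_cache_key(async_tp, ignore_private=False)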

File tree

5 files changed: +167 -4 lines changed

test/inductor/test_codecache.py

Lines changed: 117 additions & 0 deletions

@@ -1,12 +1,14 @@
 # Owner(s): ["module: inductor"]
 import functools
+import logging
 import os
 import pickle
 import shutil
 import subprocess
 import sys
 import tempfile
 import unittest
+from contextlib import contextmanager
 from typing import Optional, Union
 from typing_extensions import override
 from unittest import mock
@@ -68,6 +70,36 @@
 torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
 
 
+class LogCaptureHandler(logging.Handler):
+    def __init__(self, level):
+        super().__init__(level)
+        self.records = []
+
+    def emit(self, record):
+        self.records.append(record)
+
+
+@contextmanager
+def capture_logs(log_name, log_level):
+    try:
+        logger = logging.getLogger(log_name)
+        old_level = logger.level
+        handler = logging.Handler()
+        logger.setLevel(log_level)
+        log_records = []
+
+        def emit(record):
+            log_records.append(record)
+
+        handler.emit = emit
+        logger.addHandler(handler)
+
+        yield log_records
+    finally:
+        logger.removeHandler(handler)
+        logger.setLevel(old_level)
+
+
 class MyModelConv2d(torch.nn.Module):
     def __init__(self, dim=512):
         super().__init__()
@@ -2147,6 +2179,91 @@ def test_hash_config_changes(self):
             pickler.dumps(details3),
         )
 
+    def test_hash_private_config_changes(self):
+        """
+        Test that private config settings affect hashes.
+        """
+        with config.patch({"_micro_pipeline_tp": False}):
+            details1 = FxGraphHashDetails(None, [], {}, [])
+            details2 = FxGraphHashDetails(None, [], {}, [])
+
+        with config.patch({"_micro_pipeline_tp": True}):
+            details3 = FxGraphHashDetails(None, [], {}, [])
+
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
+
+        self.assertEqual(
+            pickler.dumps(details1),
+            pickler.dumps(details2),
+        )
+        self.assertNotEqual(
+            pickler.dumps(details1),
+            pickler.dumps(details3),
+        )
+
+    def test_non_serializable_custom_passes_causes_cache_miss(self):
+        class Mod(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 4))
+
+            def forward(self, x):
+                return x @ self.param
+
+        mod1 = Mod()
+        mod_compiled = torch.compile(mod1)
+        with torch.no_grad():
+            x = torch.rand(4, 4)
+            # miss
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            # hit
+            torch._dynamo.reset()
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            torch._dynamo.reset()
+            counters.clear()
+
+            # hit
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            with config.patch({"_fuse_ddp_communication_passes": ["new_pass_foo_bar"]}):
+                # miss (private config changed)
+                torch._dynamo.reset()
+                mod_compiled(x)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            torch._dynamo.reset()
+            counters.clear()
+
+            with capture_logs(
+                "torch._inductor.codecache", logging.INFO
+            ) as logs, config.patch(
+                {"_fuse_ddp_communication_passes": [lambda *args: None]}
+            ):
+                # bypass (custom pass is not serializable)
+                mod_compiled(x)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+                counters.clear()
+                # assert that our bypass is explicit
+                self.assertTrue(
+                    any(
+                        x.getMessage()
+                        == "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'"
                        for x in logs
+                    )
+                )
+
     def test_hash_custom_passes(self):
         """
         Test CustomGraphPass usage.

torch/_inductor/codecache.py

Lines changed: 39 additions & 1 deletion

@@ -882,14 +882,44 @@ def __init__(
         # Also hash on various system info (including the triton compiler version).
         self.torch_version = torch_key()
         self.system_info = CacheBase.get_system()
-        self.inductor_config = config.save_config_portable()
+        self.inductor_config = config.save_config_portable(ignore_private_configs=False)
         # Custom post grad passes should provide an ID to hash.
         self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
            config.post_grad_custom_pre_pass
         )
         self.post_grad_custom_post_pass = self._get_custom_pass_detail(
            config.post_grad_custom_post_pass
         )
+        self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
+            config._pre_fusion_custom_pass
+        )
+        self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe(
+            config._fuse_ddp_communication_passes
+        )
+
+    # This is mainly added to handle these two inductor configs, which are (unfortunately)
+    # sometimes cache safe:
+    # - _pre_fusion_custom_pass
+    # - _fuse_ddp_communication_passes
+    # Their types can be found in `torch/_inductor/config.py`, but:
+    # - if they are string names, we can cache them safely (one is by default)
+    # - if any of them are set to custom callables, we will need to cache miss
+    # Future work is for someone to find any places where these functions are used
+    # and force them to be of type CustomGraphPass, so we can guarantee serialization.
+    def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
+        if not custom_pass:
+            return None
+        if isinstance(custom_pass, list):
+            return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass]
+        if isinstance(custom_pass, str):
+            return custom_pass
+        if isinstance(custom_pass, CustomGraphPass):
+            return custom_pass.uuid()
+        if callable(custom_pass):
+            # Returning None is safe here because we raise an explicit bypass error
+            # later if we detect these passes are set to callables
+            return None
+        raise AssertionError(f"unknown config type: {str(type(custom_pass))}")
 
     def _get_custom_pass_detail(
         self, custom_pass: CustomGraphPassType
@@ -1367,6 +1397,14 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
     for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
         if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
             raise BypassFxGraphCache("Unsupported post grad custom pass")
+    # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes
+    # and ensure they are not passing us raw callables
+    if config._pre_fusion_custom_pass is not None:
+        if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass):
+            raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass")
+    for p in config._fuse_ddp_communication_passes:
+        if callable(p) and not isinstance(p, CustomGraphPass):
+            raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass")
 
     # Freezing can embed constants that wouldn't be static across runs.
     if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
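
To illustrate the intent of `_get_custom_pass_detail_unsafe` and the new bypass check, here is a simplified self-contained sketch. These are toy stand-ins, not the actual torch._inductor classes, and `some_pass_name` is a made-up pass name.

from typing import Any, Optional


class BypassCache(Exception):
    """Toy stand-in for inductor's BypassFxGraphCache."""


class ToyCustomGraphPass:
    """Toy stand-in for CustomGraphPass: serializable via a stable uuid()."""

    def uuid(self) -> str:
        return "my-pass-v1"


def pass_detail_unsafe(custom_pass: Any) -> Optional[Any]:
    # Mirrors the dispatch above: strings and uuid-carrying passes become
    # hashable details; raw callables become None and are rejected separately.
    if not custom_pass:
        return None
    if isinstance(custom_pass, list):
        return [pass_detail_unsafe(x) for x in custom_pass]
    if isinstance(custom_pass, str):
        return custom_pass
    if isinstance(custom_pass, ToyCustomGraphPass):
        return custom_pass.uuid()
    if callable(custom_pass):
        return None
    raise AssertionError(f"unknown config type: {type(custom_pass)}")


def check_can_cache(fuse_ddp_passes: list) -> None:
    # Mirrors the new _check_can_cache logic: a raw callable forces a bypass.
    for p in fuse_ddp_passes:
        if callable(p) and not isinstance(p, ToyCustomGraphPass):
            raise BypassCache("Unsupported _fuse_ddp_communication_pass")


print(pass_detail_unsafe(["some_pass_name", ToyCustomGraphPass()]))
# -> ['some_pass_name', 'my-pass-v1']

try:
    check_can_cache([lambda graph: None])
except BypassCache as e:
    print("bypassed:", e)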

torch/_inductor/config.py

Lines changed: 4 additions & 0 deletions

@@ -1667,6 +1667,8 @@ class trace:
     "aot_inductor.dump_aoti_minifier",
     "post_grad_custom_pre_pass",
     "post_grad_custom_post_pass",
+    "_fuse_ddp_communication_passes",
+    "_pre_fusion_custom_pass",
 ]
 
 _cache_config_ignore_prefix: list[str] = [
@@ -1680,6 +1682,8 @@ class trace:
     # see CustomGraphPass; these are handled specially
     "post_grad_custom_post_pass",
     "post_grad_custom_pre_pass",
+    "_fuse_ddp_communication_passes",
+    "_pre_fusion_custom_pass",
     # tests assume that changes here don't invalidate cache
     "always_complex_memory_overlap_TESTING_ONLY",
 ]

torch/utils/_config_module.py

Lines changed: 6 additions & 2 deletions

@@ -508,9 +508,13 @@ def save_config(self) -> bytes:
             protocol=2,
         )
 
-    def save_config_portable(self) -> dict[str, Any]:
+    def save_config_portable(
+        self, *, ignore_private_configs: bool = True
+    ) -> dict[str, Any]:
         """Convert config to portable format"""
-        prefixes = ["_"]
+        prefixes = []
+        if ignore_private_configs:
+            prefixes.append("_")
         prefixes.extend(getattr(self, "_cache_config_ignore_prefix", []))
         return self._get_dict(ignored_prefixes=prefixes)
 
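
As a usage sketch, assuming a torch build that includes this change: the new keyword only matters for the cache-key path, and existing callers of the default form are untouched.

from torch._inductor import config as inductor_config

# Default behavior: private ("_"-prefixed) options and anything in the
# explicit ignore lists are filtered out, as before this change.
portable = inductor_config.save_config_portable()

# Cache-key path (what FxGraphHashDetails now uses): private options are kept,
# so flags like _micro_pipeline_tp can influence the FX graph cache key.
full = inductor_config.save_config_portable(ignore_private_configs=False)

print(len(full) - len(portable), "additional (private) options now hashed")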

torch/utils/_config_typing.pyi

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ Note that the import should happen before the call to install_config_module(), o
 assert TYPE_CHECKING, "Do not use at runtime"
 
 def save_config() -> bytes: ...
-def save_config_portable() -> dict[str, Any]: ...
+def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...
 def codegen_config() -> str: ...
 def get_hash() -> bytes: ...
 def to_dict() -> dict[str, Any]: ...
