Skip to content

Commit 13ea0f2

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][dynamic] Recompilation hint for nn module integer attributes (pytorch#154867)
For program like this ``` class Mod(torch.nn.Module): def __init__(self): super().__init__() self.c = 0 def forward(self, x): self.c += 1 return x * self.c ``` You can check the recompile reasons at https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpzv9z6Q/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 ![image](https://github.com/user-attachments/assets/856a95fd-0533-4abc-a213-1f73ae2cb766) Pull Request resolved: pytorch#154867 Approved by: https://github.com/zou3519
1 parent a14f427 commit 13ea0f2

File tree

3 files changed

+46
-5
lines changed

3 files changed

+46
-5
lines changed

torch/_dynamo/guards.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,7 +1665,7 @@ def metadata_checker(x):
16651665
metadata_checker, get_verbose_code_parts(global_name, guard)
16661666
)
16671667

1668-
def EQUALS_MATCH(self, guard: Guard):
1668+
def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None):
16691669
ref = self.arg_ref(guard)
16701670
val = self.get(guard.name)
16711671
if np:
@@ -1762,9 +1762,14 @@ def EQUALS_MATCH(self, guard: Guard):
17621762
# is immutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the
17631763
# pointer equality check.
17641764
val = deepcopy(val)
1765-
self.get_guard_manager(guard).add_equals_match_guard(
1766-
val, get_verbose_code_parts(code, guard)
1767-
)
1765+
1766+
verbose_code_parts = get_verbose_code_parts(code, guard)
1767+
if recompile_hint:
1768+
verbose_code_parts = [
1769+
f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts
1770+
]
1771+
1772+
self.get_guard_manager(guard).add_equals_match_guard(val, verbose_code_parts)
17681773
self._set_guard_export_info(guard, code)
17691774
return
17701775

torch/_dynamo/source.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import dataclasses
2323
import enum
24+
import functools
2425
from typing import Any, Optional, TYPE_CHECKING, Union
2526

2627
from torch._guards import ChainedSource, GuardSource, Source
@@ -901,6 +902,7 @@ def is_from_source(source: Source, target: Source):
901902
return source == target
902903

903904

905+
@functools.lru_cache
904906
def is_from_unspecialized_nn_module_source(source: Source):
905907
if isinstance(source, UnspecializedNNModuleSource):
906908
return True
@@ -909,6 +911,16 @@ def is_from_unspecialized_nn_module_source(source: Source):
909911
return False
910912

911913

914+
@functools.lru_cache
915+
def is_from_unspecialized_builtin_nn_module_source(source: Source):
916+
if isinstance(source, UnspecializedBuiltinNNModuleSource):
917+
return True
918+
if isinstance(source, ChainedSource):
919+
return is_from_unspecialized_builtin_nn_module_source(source.base)
920+
return False
921+
922+
923+
@functools.lru_cache
912924
def is_from_unspecialized_param_buffer_source(source: Source):
913925
if isinstance(source, UnspecializedParamBufferSource):
914926
return True
@@ -917,6 +929,7 @@ def is_from_unspecialized_param_buffer_source(source: Source):
917929
return False
918930

919931

932+
@functools.lru_cache
920933
def is_from_flatten_script_object_source(source: Source):
921934
if isinstance(source, FlattenScriptObjectSource):
922935
return True
@@ -925,6 +938,7 @@ def is_from_flatten_script_object_source(source: Source):
925938
return False
926939

927940

941+
@functools.lru_cache
928942
def is_from_optimizer_source(source: Source):
929943
if isinstance(source, OptimizerSource):
930944
return True
@@ -935,6 +949,7 @@ def is_from_optimizer_source(source: Source):
935949

936950
# TODO: can probably write a generic "test this on everything in the chain"
937951
# helper
952+
@functools.lru_cache
938953
def is_from_defaults(source: Source):
939954
if isinstance(source, DefaultsSource):
940955
return True

torch/_dynamo/variables/builder.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1807,7 +1807,28 @@ def wrap_literal(self, value):
18071807
# unspecializing int by default, but still
18081808
# specialize for the following conditions
18091809
if is_int_specialization_case(value, self.source):
1810-
self.install_guards(GuardBuilder.CONSTANT_MATCH)
1810+
recompile_hint = None
1811+
if (
1812+
self.source.guard_source().is_unspecialized_builtin_nn_module()
1813+
or self.source.guard_source().is_unspecialized_nn_module()
1814+
):
1815+
# This means that it is an integer from a NN module.
1816+
# Dynamo considers nn module int attributes to be static
1817+
# (a good heursitic). But a user might want to mark the
1818+
# int attribute to be a symint, so track this integer
1819+
# for recompilation later.
1820+
recompile_hint = (
1821+
"torch.compile considers integer attributes of the nn.Module to be static. "
1822+
"If you are observing recompilation, you might want to make this integer dynamic "
1823+
"using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this "
1824+
"integer into a tensor."
1825+
)
1826+
1827+
self.install_guards(
1828+
functools.partial(
1829+
GuardBuilder.EQUALS_MATCH, recompile_hint=recompile_hint
1830+
)
1831+
)
18111832
return ConstantVariable.create(value=value, source=self.source)
18121833

18131834
return self.wrap_symint(value)

0 commit comments

Comments
 (0)