Skip to content

Commit 64cbaa8

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][guards] Make class members go through obj.__class__.__dict__ (pytorch#159534)
Pull Request resolved: pytorch#159534 Approved by: https://github.com/jansel
1 parent 4516c59 commit 64cbaa8

File tree

9 files changed

+300
-27
lines changed

9 files changed

+300
-27
lines changed

test/dynamo/test_guard_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -880,8 +880,9 @@ def hook(guard_wrapper, f_locals, builder):
880880
counter += 1
881881

882882
class Bar:
883-
x = 4
884-
y = torch.randn(4)
883+
def __init__(self):
884+
self.x = 4
885+
self.y = torch.randn(4)
885886

886887
bar = Bar()
887888

test/dynamo/test_skip_guard_eval_unsafe.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ def fn(x, y):
5454

5555
def test_post_recompile(self):
5656
class Foo:
57-
a = 4
58-
b = 5
57+
def __init__(self):
58+
self.a = 4
59+
self.b = 5
5960

6061
foo = Foo()
6162

torch/_C/_dynamo/guards.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ class DictGuardManager(GuardManager):
139139
class GuardAccessor: ...
140140
class DictGetItemGuardAccessor(GuardAccessor): ...
141141
class GetGenericDictGuardAccessor(GuardAccessor): ...
142+
class TypeDictGuardAccessor(GuardAccessor): ...
143+
class TypeMROGuardAccessor(GuardAccessor): ...
142144

143145
def install_object_aliasing_guard(
144146
guard_managers: list[GuardManager],

torch/_dynamo/guards.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@
134134
TorchFunctionModeStackSource,
135135
TorchSource,
136136
TupleIteratorGetItemSource,
137+
TypeDictSource,
138+
TypeMROSource,
137139
TypeSource,
138140
UnspecializedBuiltinNNModuleSource,
139141
UnspecializedNNModuleSource,
@@ -864,6 +866,9 @@ def __init__(
864866
self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
865867
"pytorch/compiler:guard_nn_modules"
866868
)
869+
self.already_guarded_not_present_in_generic_dict: OrderedSet[
870+
tuple[str, str]
871+
] = OrderedSet()
867872

868873
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
869874
dict_mgr = self.get_guard_manager(guard)
@@ -1211,6 +1216,20 @@ def get_guard_manager_from_source(self, source):
12111216
example_value=example_value,
12121217
guard_manager_enum=guard_manager_enum,
12131218
)
1219+
elif istype(source, TypeDictSource):
1220+
assert base_guard_manager # to make mypy happy
1221+
out = base_guard_manager.type_dict_manager(
1222+
source=source_name,
1223+
example_value=example_value,
1224+
guard_manager_enum=guard_manager_enum,
1225+
)
1226+
elif istype(source, TypeMROSource):
1227+
assert base_guard_manager # to make mypy happy
1228+
out = base_guard_manager.type_mro_manager(
1229+
source=source_name,
1230+
example_value=example_value,
1231+
guard_manager_enum=guard_manager_enum,
1232+
)
12141233
elif istype(
12151234
source,
12161235
(
@@ -1656,10 +1675,12 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
16561675
assert attr is not None
16571676
ref = self.arg_ref(guard)
16581677
val = self.get(guard.name)
1659-
assert isinstance(val, torch.nn.Module)
16601678

16611679
base_manager = self.get_guard_manager(guard)
16621680

1681+
if (ref, attr) in self.already_guarded_not_present_in_generic_dict:
1682+
return
1683+
16631684
mod_dict_source = f"{guard.name}.__dict__"
16641685
mod_generic_dict_manager = base_manager.get_generic_dict_manager(
16651686
source=mod_dict_source,
@@ -1671,6 +1692,7 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
16711692
mod_generic_dict_manager.add_dict_contains_guard(
16721693
False, attr, get_verbose_code_parts(code, guard)
16731694
)
1695+
self.already_guarded_not_present_in_generic_dict.add((ref, attr))
16741696

16751697
def TYPE_MATCH(self, guard: Guard) -> None:
16761698
# ___check_type_id is same as `id(type(x)) == y`

torch/_dynamo/source.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,38 @@ def name(self) -> str:
266266
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
267267

268268

269+
# Represents obj.__dict__ where obj is a type object
270+
@dataclasses.dataclass(frozen=True)
271+
class TypeDictSource(ChainedSource):
272+
def reconstruct(self, codegen: "PyCodegen") -> None:
273+
codegen(self.base)
274+
codegen.extend_output(codegen.create_load_attrs("__dict__"))
275+
276+
def guard_source(self) -> GuardSource:
277+
return self.base.guard_source()
278+
279+
def name(self) -> str:
280+
# type(ob).__dict__ can return a proxy of the dict. But in the C++
281+
# guard accessor, we are use type->tp_dict which is a dict. So,
282+
# forcefully pass a dict object to ensure that the GuardManager
283+
# registers that its working on a dict object.
284+
return f"dict({self.base.name()}.__dict__)"
285+
286+
287+
# Represents obj.__mro__ where object is type object
288+
@dataclasses.dataclass(frozen=True)
289+
class TypeMROSource(ChainedSource):
290+
def reconstruct(self, codegen: "PyCodegen") -> None:
291+
codegen(self.base)
292+
codegen.extend_output(codegen.create_load_attrs("__mro__"))
293+
294+
def guard_source(self) -> GuardSource:
295+
return self.base.guard_source()
296+
297+
def name(self) -> str:
298+
return f"{self.base.name()}.__mro__"
299+
300+
269301
@dataclasses.dataclass(frozen=True)
270302
class LocalCellSource(Source):
271303
"""

torch/_dynamo/variables/misc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
AttrSource,
4343
GenericAttrSource,
4444
GetItemSource,
45+
TypeMROSource,
4546
TypeSource,
4647
WeakRefCallSource,
4748
)
@@ -134,9 +135,7 @@ def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
134135
# Equivalent of something like type(L['self']).__mro__[1].attr_name
135136
if type_to_use_source:
136137
source = AttrSource(
137-
GetItemSource(
138-
AttrSource(type_to_use_source, "__mro__"), index
139-
),
138+
GetItemSource(TypeMROSource(type_to_use_source), index),
140139
name,
141140
)
142141
return resolved_getattr, source
@@ -247,7 +246,7 @@ def call_method(
247246
# different from type(self) with polymorphism.
248247
cls_source = None
249248
if self.objvar.source:
250-
cls_source = AttrSource(self.objvar.source, "__class__")
249+
cls_source = TypeSource(self.objvar.source)
251250
cls_variable = VariableTracker.build(
252251
tx, self.objvar.value_type, cls_source
253252
)

torch/_dynamo/variables/nn_module.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ def call_function(
989989
fn = self.value_type.forward
990990

991991
if self.source:
992-
source = AttrSource(AttrSource(self.source, "__class__"), name)
992+
source = self.get_source_by_walking_mro(name)
993993
else:
994994
source = None
995995

@@ -1017,7 +1017,7 @@ def call_method(
10171017
if name in ["_call_impl", "_wrapped_call_impl"]:
10181018
fn = getattr(self.value_type, name)
10191019
if self.source:
1020-
source = AttrSource(AttrSource(self.source, "__class__"), name)
1020+
source = self.get_source_by_walking_mro(name)
10211021
else:
10221022
source = None
10231023

@@ -1032,9 +1032,7 @@ def call_method(
10321032
method = None
10331033

10341034
if isinstance(method, staticmethod):
1035-
source = AttrSource(
1036-
AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
1037-
)
1035+
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
10381036
return tx.inline_user_function_return(
10391037
variables.UserFunctionVariable(method.__func__, source=source),
10401038
args,

torch/_dynamo/variables/user_defined.py

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,11 @@
6060
AttrSource,
6161
CallFunctionNoArgsSource,
6262
DataclassFieldsSource,
63+
DictGetItemSource,
6364
GetItemSource,
6465
RandomValueSource,
66+
TypeDictSource,
67+
TypeMROSource,
6568
TypeSource,
6669
UnspecializedParamBufferSource,
6770
)
@@ -135,6 +138,14 @@ def is_forbidden_context_manager(ctx):
135138
return ctx in f_ctxs
136139

137140

141+
def is_cython_function(obj):
142+
return (
143+
callable(obj)
144+
and hasattr(type(obj), "__name__")
145+
and type(obj).__name__ == "cython_function_or_method"
146+
)
147+
148+
138149
class UserDefinedVariable(VariableTracker):
139150
value: object
140151

@@ -998,11 +1009,9 @@ def call_method(
9981009

9991010
# check for methods implemented in C++
10001011
if isinstance(method, types.FunctionType):
1001-
source = (
1002-
None
1003-
if self.source is None
1004-
else AttrSource(AttrSource(self.source, "__class__"), name)
1005-
)
1012+
source = None
1013+
if self.source:
1014+
source = self.get_source_by_walking_mro(name)
10061015
# TODO(jansel): add a guard to check for monkey patching?
10071016
from ..mutation_guard import unpatched_nn_module_init
10081017

@@ -1224,12 +1233,40 @@ def get_source_by_walking_mro(self, name):
12241233

12251234
for idx, klass in enumerate(type(self.value).__mro__):
12261235
if name in klass.__dict__:
1227-
mro_source = AttrSource(self.cls_source, "__mro__")
1228-
klass_source = GetItemSource(mro_source, idx)
1229-
dict_source = AttrSource(klass_source, "__dict__")
1230-
# TODO(anijain2305) - This is a mapping proxy object. Ideally we
1231-
# should use DictGetItemSource here.
1232-
return GetItemSource(dict_source, name)
1236+
if idx != 0:
1237+
mro_source = TypeMROSource(self.cls_source)
1238+
klass_source = GetItemSource(mro_source, idx)
1239+
else:
1240+
klass_source = self.cls_source
1241+
dict_source = TypeDictSource(klass_source)
1242+
out_source = DictGetItemSource(dict_source, name)
1243+
1244+
for absent_idx in range(1, idx):
1245+
# Insert a guard that the name is not present in the mro hierarchy
1246+
mro_source = TypeMROSource(self.cls_source)
1247+
klass_source = GetItemSource(mro_source, absent_idx)
1248+
dict_source = TypeDictSource(klass_source)
1249+
install_guard(
1250+
dict_source.make_guard(
1251+
functools.partial(
1252+
GuardBuilder.DICT_CONTAINS, key=name, invert=True
1253+
)
1254+
)
1255+
)
1256+
# Insert a guard that the name is not present in the object __dict__
1257+
if (
1258+
self.source
1259+
and hasattr(self.value, "__dict__")
1260+
and name not in self.value.__dict__
1261+
):
1262+
install_guard(
1263+
self.source.make_guard(
1264+
functools.partial(
1265+
GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr=name
1266+
)
1267+
)
1268+
)
1269+
return out_source
12331270

12341271
unimplemented_v2(
12351272
gb_type="could not find name in object's mro",
@@ -1339,10 +1376,17 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13391376
if subobj is torch.nn.Module.__init__:
13401377
subobj = unpatched_nn_module_init
13411378

1379+
subobj_from_class = inspect.getattr_static(
1380+
self.value.__class__, name, NO_SUCH_SUBOBJ
1381+
)
1382+
is_accessible_from_type_mro = (
1383+
subobj_from_class is subobj and self.cls_source is not None
1384+
)
1385+
13421386
if isinstance(subobj, property):
13431387
if self.source:
13441388
# Read the class attribute to reach the property
1345-
source = AttrSource(AttrSource(self.source, "__class__"), name)
1389+
source = self.get_source_by_walking_mro(name)
13461390
# Get the getter function
13471391
source = AttrSource(source, "fget")
13481392
return variables.UserMethodVariable(
@@ -1360,6 +1404,11 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13601404
# Safe because `staticmethod.__get__` basically won't trigger user
13611405
# code and just returns the underlying `__func__`:
13621406
# https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100
1407+
if is_accessible_from_type_mro:
1408+
# Accessing from __dict__ does not resolve the descriptor, it
1409+
# returns a staticmethod object, so access the __func__
1410+
# attribute to get to the actual function.
1411+
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
13631412
func = subobj.__get__(self.value)
13641413
return VariableTracker.build(tx, func, source)
13651414
elif isinstance(subobj, classmethod):
@@ -1485,10 +1534,17 @@ def var_getattr(self, tx: "InstructionTranslator", name):
14851534
source = self._wrap_source(source)
14861535

14871536
if subobj is not NO_SUCH_SUBOBJ:
1488-
if is_wrapper_or_member_descriptor(subobj):
1537+
if (
1538+
is_wrapper_or_member_descriptor(subobj)
1539+
or torch._C._dynamo.utils.is_instancemethod(subobj)
1540+
or is_cython_function(subobj)
1541+
):
14891542
options = {"source": source}
14901543
return variables.GetAttrVariable(self, name, **options)
14911544
if source:
1545+
if is_accessible_from_type_mro:
1546+
source = self.get_source_by_walking_mro(name)
1547+
14921548
return variables.LazyVariableTracker.create(subobj, source)
14931549
else:
14941550
# Check if the subobj is accessible from the class itself. If the class source is known, we can create a

0 commit comments

Comments
 (0)