Skip to content

Commit ee62177

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo] Be consistent with storing func source for UserMethodVariable (pytorch#159696)
Pull Request resolved: pytorch#159696 Approved by: https://github.com/jansel ghstack dependencies: pytorch#159534
1 parent 64cbaa8 commit ee62177

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

torch/_dynamo/codegen.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .variables.functions import (
4343
ContextlibContextManagerLocalGeneratorObjectVariable,
4444
LocalGeneratorObjectVariable,
45+
UserMethodVariable,
4546
)
4647
from .variables.nn_module import NNModuleVariable
4748
from .variables.tensor import (
@@ -250,7 +251,10 @@ def __call__(
250251
value.source is not None
251252
and allow_cache
252253
and not (
253-
value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
254+
value.is_realized()
255+
and isinstance(
256+
value, (LocalGeneratorObjectVariable, UserMethodVariable)
257+
)
254258
)
255259
):
256260
# There's a corner case for export: for instance, if the computation

torch/_dynamo/variables/functions.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,13 +1122,26 @@ def inspect_parameter_names(self):
11221122
return super().inspect_parameter_names()[1:]
11231123

11241124
def var_getattr(self, tx: "InstructionTranslator", name: str):
1125-
source = self.source and AttrSource(self.source, name)
1125+
if name == "__func__":
1126+
# self.source points to the source of the function object and not
1127+
# the method object
1128+
return VariableTracker.build(tx, self.fn, self.source)
11261129
if name == "__self__":
11271130
return self.obj
1128-
if name == "__func__":
1129-
return VariableTracker.build(tx, self.fn, source)
11301131
return super().var_getattr(tx, name)
11311132

1133+
def reconstruct(self, codegen):
1134+
if not self.obj.source or not self.source:
1135+
raise NotImplementedError
1136+
1137+
def get_bound_method():
1138+
codegen(self.source)
1139+
codegen.extend_output(codegen.create_load_attrs("__get__"))
1140+
1141+
codegen.add_push_null(get_bound_method)
1142+
codegen(self.obj.source)
1143+
codegen.extend_output(create_call_function(1, False))
1144+
11321145

11331146
class WrappedUserMethodVariable(UserMethodVariable):
11341147
def __init__(self, wrapped, context, **kwargs) -> None:

torch/_dynamo/variables/user_defined.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1380,7 +1380,9 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13801380
self.value.__class__, name, NO_SUCH_SUBOBJ
13811381
)
13821382
is_accessible_from_type_mro = (
1383-
subobj_from_class is subobj and self.cls_source is not None
1383+
subobj_from_class is subobj
1384+
and self.cls_source is not None
1385+
and self.source is not None
13841386
)
13851387

13861388
if isinstance(subobj, property):
@@ -1412,6 +1414,11 @@ def var_getattr(self, tx: "InstructionTranslator", name):
14121414
func = subobj.__get__(self.value)
14131415
return VariableTracker.build(tx, func, source)
14141416
elif isinstance(subobj, classmethod):
1417+
if is_accessible_from_type_mro:
1418+
# Accessing from __dict__ does not resolve the descriptor, it
1419+
# returns a classmethod object, so access the __func__
1420+
# attribute to get to the actual function.
1421+
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
14151422
return variables.UserMethodVariable(
14161423
subobj.__func__, self.var_getattr(tx, "__class__"), source=source
14171424
)
@@ -1461,6 +1468,9 @@ def var_getattr(self, tx: "InstructionTranslator", name):
14611468
isinstance(subobj, types.MethodType)
14621469
and isinstance(self.value, torch.nn.Module)
14631470
):
1471+
if is_accessible_from_type_mro:
1472+
source = self.get_source_by_walking_mro(name)
1473+
14641474
# Since we get subobj via self._getattr_static, which may not trigger dynamic lookup.
14651475
# Static lookup can't tell us it's a method or function correctly,
14661476
# so we trigger dynamic lookup here to get the correct type.

0 commit comments

Comments
 (0)