6060 AttrSource ,
6161 CallFunctionNoArgsSource ,
6262 DataclassFieldsSource ,
63+ DictGetItemSource ,
6364 GetItemSource ,
6465 RandomValueSource ,
66+ TypeDictSource ,
67+ TypeMROSource ,
6568 TypeSource ,
6669 UnspecializedParamBufferSource ,
6770)
@@ -135,6 +138,14 @@ def is_forbidden_context_manager(ctx):
135138 return ctx in f_ctxs
136139
137140
141+ def is_cython_function (obj ):
142+ return (
143+ callable (obj )
144+ and hasattr (type (obj ), "__name__" )
145+ and type (obj ).__name__ == "cython_function_or_method"
146+ )
147+
148+
138149class UserDefinedVariable (VariableTracker ):
139150 value : object
140151
@@ -998,11 +1009,9 @@ def call_method(
9981009
9991010 # check for methods implemented in C++
10001011 if isinstance (method , types .FunctionType ):
1001- source = (
1002- None
1003- if self .source is None
1004- else AttrSource (AttrSource (self .source , "__class__" ), name )
1005- )
1012+ source = None
1013+ if self .source :
1014+ source = self .get_source_by_walking_mro (name )
10061015 # TODO(jansel): add a guard to check for monkey patching?
10071016 from ..mutation_guard import unpatched_nn_module_init
10081017
@@ -1224,12 +1233,40 @@ def get_source_by_walking_mro(self, name):
12241233
12251234 for idx , klass in enumerate (type (self .value ).__mro__ ):
12261235 if name in klass .__dict__ :
1227- mro_source = AttrSource (self .cls_source , "__mro__" )
1228- klass_source = GetItemSource (mro_source , idx )
1229- dict_source = AttrSource (klass_source , "__dict__" )
1230- # TODO(anijain2305) - This is a mapping proxy object. Ideally we
1231- # should use DictGetItemSource here.
1232- return GetItemSource (dict_source , name )
1236+ if idx != 0 :
1237+ mro_source = TypeMROSource (self .cls_source )
1238+ klass_source = GetItemSource (mro_source , idx )
1239+ else :
1240+ klass_source = self .cls_source
1241+ dict_source = TypeDictSource (klass_source )
1242+ out_source = DictGetItemSource (dict_source , name )
1243+
1244+ for absent_idx in range (1 , idx ):
1245+ # Insert a guard that the name is not present in the mro hierarchy
1246+ mro_source = TypeMROSource (self .cls_source )
1247+ klass_source = GetItemSource (mro_source , absent_idx )
1248+ dict_source = TypeDictSource (klass_source )
1249+ install_guard (
1250+ dict_source .make_guard (
1251+ functools .partial (
1252+ GuardBuilder .DICT_CONTAINS , key = name , invert = True
1253+ )
1254+ )
1255+ )
1256+ # Insert a guard that the name is not present in the object __dict__
1257+ if (
1258+ self .source
1259+ and hasattr (self .value , "__dict__" )
1260+ and name not in self .value .__dict__
1261+ ):
1262+ install_guard (
1263+ self .source .make_guard (
1264+ functools .partial (
1265+ GuardBuilder .NOT_PRESENT_IN_GENERIC_DICT , attr = name
1266+ )
1267+ )
1268+ )
1269+ return out_source
12331270
12341271 unimplemented_v2 (
12351272 gb_type = "could not find name in object's mro" ,
@@ -1339,10 +1376,17 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13391376 if subobj is torch .nn .Module .__init__ :
13401377 subobj = unpatched_nn_module_init
13411378
1379+ subobj_from_class = inspect .getattr_static (
1380+ self .value .__class__ , name , NO_SUCH_SUBOBJ
1381+ )
1382+ is_accessible_from_type_mro = (
1383+ subobj_from_class is subobj and self .cls_source is not None
1384+ )
1385+
13421386 if isinstance (subobj , property ):
13431387 if self .source :
13441388 # Read the class attribute to reach the property
1345- source = AttrSource ( AttrSource ( self .source , "__class__" ), name )
1389+ source = self .get_source_by_walking_mro ( name )
13461390 # Get the getter function
13471391 source = AttrSource (source , "fget" )
13481392 return variables .UserMethodVariable (
@@ -1360,6 +1404,11 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13601404 # Safe because `staticmethod.__get__` basically won't trigger user
13611405 # code and just returns the underlying `__func__`:
13621406 # https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100
1407+ if is_accessible_from_type_mro :
1408+ # Accessing from __dict__ does not resolve the descriptor, it
1409+ # returns a staticmethod object, so access the __func__
1410+ # attribute to get to the actual function.
1411+ source = AttrSource (self .get_source_by_walking_mro (name ), "__func__" )
13631412 func = subobj .__get__ (self .value )
13641413 return VariableTracker .build (tx , func , source )
13651414 elif isinstance (subobj , classmethod ):
@@ -1485,10 +1534,17 @@ def var_getattr(self, tx: "InstructionTranslator", name):
14851534 source = self ._wrap_source (source )
14861535
14871536 if subobj is not NO_SUCH_SUBOBJ :
1488- if is_wrapper_or_member_descriptor (subobj ):
1537+ if (
1538+ is_wrapper_or_member_descriptor (subobj )
1539+ or torch ._C ._dynamo .utils .is_instancemethod (subobj )
1540+ or is_cython_function (subobj )
1541+ ):
14891542 options = {"source" : source }
14901543 return variables .GetAttrVariable (self , name , ** options )
14911544 if source :
1545+ if is_accessible_from_type_mro :
1546+ source = self .get_source_by_walking_mro (name )
1547+
14921548 return variables .LazyVariableTracker .create (subobj , source )
14931549 else :
14941550 # Check if the subobj is accessible from the class itself. If the class source is known, we can create a
0 commit comments