44
55from mypy .mro import calculate_mro
66from mypy .nodes import (
7- AssignmentStmt , ClassDef , Expression , ImportedName , Lvalue , MypyFile , NameExpr , SymbolNode , TypeInfo ,
8- SymbolTable , SymbolTableNode , Block , GDEF , MDEF , Var )
9- from mypy .plugin import FunctionContext , MethodContext
7+ GDEF , MDEF , AssignmentStmt , Block , CallExpr , ClassDef , Expression , ImportedName , Lvalue , MypyFile , NameExpr ,
8+ SymbolNode , SymbolTable , SymbolTableNode , TypeInfo , Var ,
9+ )
10+ from mypy .plugin import CheckerPluginInterface , FunctionContext , MethodContext
1011from mypy .types import (
11- AnyType , Instance , NoneTyp , Type , TypeOfAny , TypeVarType , UnionType ,
12- TupleType , TypedDictType )
12+ AnyType , Instance , NoneTyp , TupleType , Type , TypedDictType , TypeOfAny , TypeVarType , UnionType ,
13+ )
1314
1415if typing .TYPE_CHECKING :
1516 from mypy .checker import TypeChecker
@@ -216,6 +217,7 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]:
216217
217218
218219def extract_field_getter_type (tp : Type ) -> Optional [Type ]:
220+ """ Extract return type of __get__ of subclass of Field"""
219221 if not isinstance (tp , Instance ):
220222 return None
221223 if tp .type .has_base (FIELD_FULLNAME ):
@@ -226,13 +228,12 @@ def extract_field_getter_type(tp: Type) -> Optional[Type]:
226228 return None
227229
228230
229- def get_django_metadata (model : TypeInfo ) -> Dict [str , typing .Any ]:
230- return model .metadata .setdefault ('django' , {})
231+ def get_django_metadata (model_info : TypeInfo ) -> Dict [str , typing .Any ]:
232+ return model_info .metadata .setdefault ('django' , {})
231233
232234
233235def get_related_field_primary_key_names (base_model : TypeInfo ) -> typing .List [str ]:
234- django_metadata = get_django_metadata (base_model )
235- return django_metadata .setdefault ('related_field_primary_keys' , [])
236+ return get_django_metadata (base_model ).setdefault ('related_field_primary_keys' , [])
236237
237238
238239def get_fields_metadata (model : TypeInfo ) -> Dict [str , typing .Any ]:
@@ -243,6 +244,10 @@ def get_lookups_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
243244 return get_django_metadata (model ).setdefault ('lookups' , {})
244245
245246
247+ def get_related_managers_metadata (model : TypeInfo ) -> Dict [str , typing .Any ]:
248+ return get_django_metadata (model ).setdefault ('related_managers' , {})
249+
250+
246251def extract_explicit_set_type_of_model_primary_key (model : TypeInfo ) -> Optional [Type ]:
247252 """
248253 If field with primary_key=True is set on the model, extract its __set__ type.
@@ -310,7 +315,7 @@ def is_field_nullable(model: TypeInfo, field_name: str) -> bool:
310315 return get_fields_metadata (model ).get (field_name , {}).get ('null' , False )
311316
312317
313- def is_foreign_key (t : Type ) -> bool :
318+ def is_foreign_key_like (t : Type ) -> bool :
314319 if not isinstance (t , Instance ):
315320 return False
316321 return has_any_of_bases (t .type , (FOREIGN_KEY_FULLNAME , ONETOONE_FIELD_FULLNAME ))
@@ -366,13 +371,14 @@ def make_named_tuple(api: 'TypeChecker', fields: 'OrderedDict[str, Type]', name:
366371 return TupleType (list (fields .values ()), fallback = fallback )
367372
368373
369- def make_typeddict (api : 'TypeChecker' , fields : 'OrderedDict[str, Type]' , required_keys : typing .Set [str ]) -> Type :
374+ def make_typeddict (api : CheckerPluginInterface , fields : 'OrderedDict[str, Type]' ,
375+ required_keys : typing .Set [str ]) -> TypedDictType :
370376 object_type = api .named_generic_type ('mypy_extensions._TypedDict' , [])
371377 typed_dict_type = TypedDictType (fields , required_keys = required_keys , fallback = object_type )
372378 return typed_dict_type
373379
374380
375- def make_tuple (api : 'TypeChecker' , fields : typing .List [Type ]) -> Type :
381+ def make_tuple (api : 'TypeChecker' , fields : typing .List [Type ]) -> TupleType :
376382 implicit_any = AnyType (TypeOfAny .special_form )
377383 fallback = api .named_generic_type ('builtins.tuple' , [implicit_any ])
378384 return TupleType (fields , fallback = fallback )
@@ -386,3 +392,52 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is
386392 descriptor_type = make_optional (descriptor_type )
387393 return descriptor_type
388394 return AnyType (TypeOfAny .unannotated )
395+
396+
397+ def iter_over_classdefs (module_file : MypyFile ) -> typing .Iterator [ClassDef ]:
398+ for defn in module_file .defs :
399+ if isinstance (defn , ClassDef ):
400+ yield defn
401+
402+
403+ def iter_call_assignments (klass : ClassDef ) -> typing .Iterator [typing .Tuple [Lvalue , CallExpr ]]:
404+ for lvalue , rvalue in iter_over_assignments (klass ):
405+ if isinstance (rvalue , CallExpr ):
406+ yield lvalue , rvalue
407+
408+
409+ def get_related_manager_type_from_metadata (model_info : TypeInfo , related_manager_name : str ,
410+ api : CheckerPluginInterface ) -> Optional [Instance ]:
411+ related_manager_metadata = get_related_managers_metadata (model_info )
412+ if not related_manager_metadata :
413+ return None
414+
415+ if related_manager_name not in related_manager_metadata :
416+ return None
417+
418+ manager_class_name = related_manager_metadata [related_manager_name ]['manager' ]
419+ of = related_manager_metadata [related_manager_name ]['of' ]
420+ of_types = []
421+ for of_type_name in of :
422+ if of_type_name == 'any' :
423+ of_types .append (AnyType (TypeOfAny .implementation_artifact ))
424+ else :
425+ try :
426+ of_type = api .named_generic_type (of_type_name , [])
427+ except AssertionError :
428+ # Internal error: attempted lookup of unknown name
429+ of_type = AnyType (TypeOfAny .implementation_artifact )
430+
431+ of_types .append (of_type )
432+
433+ return api .named_generic_type (manager_class_name , of_types )
434+
435+
436+ def get_primary_key_field_name (model_info : TypeInfo ) -> Optional [str ]:
437+ for base in model_info .mro :
438+ fields = get_fields_metadata (base )
439+ for field_name , field_props in fields .items ():
440+ is_primary_key = field_props .get ('primary_key' , False )
441+ if is_primary_key :
442+ return field_name
443+ return None
0 commit comments