1
+ import ast
1
2
import asyncio
2
3
import collections .abc
3
4
import dataclasses
7
8
from decimal import Decimal
8
9
from itertools import chain
9
10
from typing import (
11
+ TYPE_CHECKING ,
10
12
Any ,
11
13
Awaitable ,
12
14
Callable ,
62
64
from sqlalchemy .orm .state import InstanceState
63
65
from sqlalchemy .sql .type_api import TypeEngine
64
66
from strawberry .annotation import StrawberryAnnotation
67
+ from strawberry .field import StrawberryField
65
68
from strawberry .types import Info
66
69
67
70
from strawberry_sqlalchemy_mapper .exc import (
72
75
UnsupportedDescriptorType ,
73
76
)
74
77
78
+ if TYPE_CHECKING :
79
+ from sqlalchemy .orm .sql .elements import ColumnElement
80
+
75
81
BaseModelType = TypeVar ("BaseModelType" )
76
82
77
- SkipTypeSentinelT = NewType ("SkipType " , object )
83
+ SkipTypeSentinelT = NewType ("SkipTypeSentinelT " , object )
78
84
SkipTypeSentinel = cast (SkipTypeSentinelT , sentinel .create ("SkipTypeSentinel" ))
79
85
80
86
@@ -204,7 +210,7 @@ def _is_model_polymorphic(self, model: Type[BaseModelType]) -> bool:
204
210
"""
205
211
Whether a model is part of a polymorphic hierarchy
206
212
"""
207
- return inspect (model ).polymorphic_on is not None
213
+ return inspect (model ).polymorphic_on is not None # type: ignore[union-attr]
208
214
209
215
def _edge_type_for (self , type_name : str ) -> Type [Any ]:
210
216
"""
@@ -249,7 +255,7 @@ def _get_polymorphic_base_model(
249
255
"""
250
256
Given a model, return the base of its inheritance tree (which may be itself).
251
257
"""
252
- return inspect (model ).base_mapper .entity
258
+ return inspect (model ).base_mapper .entity # type: ignore[union-attr]
253
259
254
260
def _convert_column_to_strawberry_type (
255
261
self , column : Column
@@ -280,7 +286,7 @@ def _convert_column_to_strawberry_type(
280
286
if type_annotation is SkipTypeSentinel :
281
287
return type_annotation
282
288
if column .nullable :
283
- type_annotation = Optional [type_annotation ] # type: ignore
289
+ type_annotation = Optional [type_annotation ]
284
290
assert type_annotation is not None
285
291
return type_annotation
286
292
@@ -291,7 +297,7 @@ def _convert_relationship_to_strawberry_type(
291
297
Given a SQLAlchemy relationship, return the type annotation for the field in the
292
298
corresponding strawberry type.
293
299
"""
294
- relationship_model : Type [BaseModelType ] = relationship .entity .entity
300
+ relationship_model : Type [BaseModelType ] = relationship .entity .entity # type: ignore[assignment]
295
301
type_name = self .model_to_type_or_interface_name (relationship_model )
296
302
if self .model_is_interface (relationship_model ):
297
303
self ._related_interface_models .add (relationship_model )
@@ -315,8 +321,8 @@ def _get_relationship_is_optional(self, relationship: RelationshipProperty) -> b
315
321
else :
316
322
assert relationship .direction == MANYTOONE
317
323
# this model is the one with the FK
318
- for local_col , _ in relationship . local_remote_pairs :
319
- local_col : Column
324
+ local_col : ColumnElement
325
+ for local_col , _ in relationship . local_remote_pairs or []:
320
326
if local_col .nullable :
321
327
return True
322
328
return False
@@ -339,7 +345,7 @@ def _get_association_proxy_annotation(
339
345
are of the form (relationship, relationship).
340
346
"""
341
347
is_multiple = mapper .relationships [descriptor .target_collection ].uselist
342
- in_between_mapper : Mapper = mapper .relationships [
348
+ in_between_mapper : Mapper = mapper .relationships [ # type: ignore[assignment]
343
349
descriptor .target_collection
344
350
].entity
345
351
if descriptor .value_attr in in_between_mapper .relationships :
@@ -352,17 +358,13 @@ def _get_association_proxy_annotation(
352
358
raise UnsupportedAssociationProxyTarget (key )
353
359
if strawberry_type is SkipTypeSentinel :
354
360
return strawberry_type
355
- if is_multiple and not self ._is_connection_type (
356
- cast (Union [Type [Any ], ForwardRef ], strawberry_type )
357
- ):
361
+ if is_multiple and not self ._is_connection_type (strawberry_type ):
358
362
if isinstance (strawberry_type , ForwardRef ):
359
363
strawberry_type = self ._connection_type_for (
360
364
strawberry_type .__forward_arg__
361
365
)
362
366
else :
363
- strawberry_type = self ._connection_type_for (
364
- cast (Type [Any ], strawberry_type ).__name__
365
- )
367
+ strawberry_type = self ._connection_type_for (strawberry_type .__name__ )
366
368
return strawberry_type
367
369
368
370
def make_connection_wrapper_resolver (
@@ -405,7 +407,8 @@ async def resolve(self, info: Info):
405
407
relationship_key = tuple (
406
408
[
407
409
getattr (self , local .key )
408
- for local , _ in relationship .local_remote_pairs
410
+ for local , _ in relationship .local_remote_pairs or []
411
+ if local .key
409
412
]
410
413
)
411
414
if any (item is None for item in relationship_key ):
@@ -437,7 +440,7 @@ def connection_resolver_for(
437
440
if relationship .uselist :
438
441
return self .make_connection_wrapper_resolver (
439
442
relationship_resolver ,
440
- self .model_to_type_or_interface_name (relationship .entity .entity ),
443
+ self .model_to_type_or_interface_name (relationship .entity .entity ), # type: ignore[arg-type]
441
444
)
442
445
else :
443
446
return relationship_resolver
@@ -456,14 +459,14 @@ def association_proxy_resolver_for(
456
459
"""
457
460
in_between_relationship = mapper .relationships [descriptor .target_collection ]
458
461
in_between_resolver = self .relationship_resolver_for (in_between_relationship )
459
- in_between_mapper : Mapper = mapper .relationships [
462
+ in_between_mapper : Mapper = mapper .relationships [ # type: ignore[assignment]
460
463
descriptor .target_collection
461
464
].entity
462
465
assert descriptor .value_attr in in_between_mapper .relationships
463
466
end_relationship = in_between_mapper .relationships [descriptor .value_attr ]
464
467
end_relationship_resolver = self .relationship_resolver_for (end_relationship )
465
468
end_type_name = self .model_to_type_or_interface_name (
466
- end_relationship .entity .entity
469
+ end_relationship .entity .entity # type: ignore[arg-type]
467
470
)
468
471
connection_type = self ._connection_type_for (end_type_name )
469
472
edge_type = self ._edge_type_for (end_type_name )
@@ -517,8 +520,8 @@ def _handle_columns(
517
520
"""
518
521
Add annotations for the columns of the given mapper.
519
522
"""
523
+ column : Column
520
524
for key , column in mapper .columns .items ():
521
- column : Column
522
525
if (
523
526
key in excluded_keys
524
527
or key in type_ .__annotations__
@@ -563,7 +566,7 @@ class Employee:
563
566
def convert (type_ : Any ) -> Any :
564
567
old_annotations = getattr (type_ , "__annotations__" , {})
565
568
type_ .__annotations__ = {}
566
- mapper : Mapper = inspect (model )
569
+ mapper : Mapper = cast ( Mapper , inspect (model ) )
567
570
generated_field_keys = []
568
571
569
572
excluded_keys = getattr (type_ , "__exclude__" , [])
@@ -582,8 +585,8 @@ def convert(type_: Any) -> Any:
582
585
generated_field_keys .append (key )
583
586
584
587
self ._handle_columns (mapper , type_ , excluded_keys , generated_field_keys )
588
+ relationship : RelationshipProperty
585
589
for key , relationship in mapper .relationships .items ():
586
- relationship : RelationshipProperty
587
590
if (
588
591
key in excluded_keys
589
592
or key in type_ .__annotations__
@@ -599,8 +602,11 @@ def convert(type_: Any) -> Any:
599
602
strawberry_type ,
600
603
generated_field_keys ,
601
604
)
602
- field = strawberry .field (
603
- resolver = self .connection_resolver_for (relationship )
605
+ field = cast (
606
+ StrawberryField ,
607
+ strawberry .field (
608
+ resolver = self .connection_resolver_for (relationship )
609
+ ),
604
610
)
605
611
assert not field .init
606
612
setattr (
@@ -618,8 +624,8 @@ def convert(type_: Any) -> Any:
618
624
if key in mapper .columns or key in mapper .relationships :
619
625
continue
620
626
if key in model .__annotations__ :
621
- annotation = eval (model .__annotations__ [key ])
622
- for (
627
+ annotation = ast . literal_eval (model .__annotations__ [key ])
628
+ for ( # type: ignore[assignment]
623
629
sqlalchemy_type ,
624
630
strawberry_type ,
625
631
) in self .sqlalchemy_type_to_strawberry_type_map .items ():
@@ -629,18 +635,21 @@ def convert(type_: Any) -> Any:
629
635
)
630
636
break
631
637
elif isinstance (descriptor , AssociationProxy ):
632
- strawberry_type = self ._get_association_proxy_annotation (
638
+ strawberry_type = self ._get_association_proxy_annotation ( # type: ignore[assignment]
633
639
mapper , key , descriptor
634
640
)
635
641
if strawberry_type is SkipTypeSentinel :
636
642
continue
637
643
self ._add_annotation (
638
644
type_ , key , strawberry_type , generated_field_keys
639
645
)
640
- field = strawberry .field (
641
- resolver = self .association_proxy_resolver_for (
642
- mapper , descriptor , strawberry_type
643
- )
646
+ field = cast (
647
+ StrawberryField ,
648
+ strawberry .field (
649
+ resolver = self .association_proxy_resolver_for (
650
+ mapper , descriptor , strawberry_type # type: ignore[arg-type]
651
+ )
652
+ ),
644
653
)
645
654
assert not field .init
646
655
setattr (type_ , key , field )
@@ -656,7 +665,7 @@ def convert(type_: Any) -> Any:
656
665
if "typing" in annotation :
657
666
# Try to evaluate from existing typing imports
658
667
annotation = annotation [7 :]
659
- annotation = eval (annotation )
668
+ annotation = ast . literal_eval (annotation )
660
669
except NameError :
661
670
raise UnsupportedDescriptorType (key )
662
671
self ._add_annotation (
0 commit comments