69
69
Mapper ,
70
70
RelationshipProperty ,
71
71
)
72
- from sqlalchemy .orm .state import InstanceState
73
72
from sqlalchemy .sql .type_api import TypeEngine
74
73
from strawberry import relay
75
74
from strawberry .annotation import StrawberryAnnotation
76
75
from strawberry .scalars import JSON as StrawberryJSON
77
76
from strawberry .types import Info
78
77
from strawberry .types .base import WithStrawberryObjectDefinition , get_object_definition
79
- from strawberry .types .field import StrawberryField
80
78
from strawberry .types .lazy_type import LazyType
81
79
from strawberry .types .private import is_private
82
80
97
95
from strawberry_sqlalchemy_mapper .scalars import BigInt
98
96
99
97
if TYPE_CHECKING :
98
+ from sqlalchemy .orm .state import InstanceState
100
99
from sqlalchemy .sql .expression import ColumnElement
100
+ from strawberry .types .field import StrawberryField
101
101
102
102
BaseModelType = TypeVar ("BaseModelType" )
103
103
104
104
SkipTypeSentinelT = NewType ("SkipTypeSentinelT" , object )
105
- SkipTypeSentinel = cast (SkipTypeSentinelT , sentinel .create ("SkipTypeSentinel" ))
106
-
105
+ SkipTypeSentinel = cast ("SkipTypeSentinelT" , sentinel .create ("SkipTypeSentinel" ))
107
106
108
107
#: Set on generated types to the original type handed to the decorator
109
108
_ORIGINAL_TYPE_KEY = "_original_type"
@@ -151,13 +150,11 @@ class StrawberrySQLAlchemyType(Generic[BaseModelType]):
151
150
152
151
@overload
153
152
@classmethod
154
- def from_type (cls , type_ : type , * , strict : Literal [True ]) -> Self :
155
- ...
153
+ def from_type (cls , type_ : type , * , strict : Literal [True ]) -> Self : ...
156
154
157
155
@overload
158
156
@classmethod
159
- def from_type (cls , type_ : type , * , strict : bool = False ) -> Optional [Self ]:
160
- ...
157
+ def from_type (cls , type_ : type , * , strict : bool = False ) -> Optional [Self ]: ...
161
158
162
159
@classmethod
163
160
def from_type (
@@ -468,28 +465,34 @@ def make_connection_wrapper_resolver(
468
465
edge_type = self ._edge_type_for (type_name )
469
466
470
467
async def wrapper (self , info : Info ):
471
- # TODO: Add pagination support to dataloader resolvers
472
- edges = [
473
- edge_type .resolve_edge (
474
- related_object ,
475
- cursor = i ,
476
- )
477
- for i , related_object in enumerate (await resolver (self , info ))
478
- ]
479
- return connection_type (
480
- edges = edges ,
481
- page_info = relay .PageInfo (
482
- has_next_page = False ,
483
- has_previous_page = False ,
484
- start_cursor = edges [0 ].cursor if edges else None ,
485
- end_cursor = edges [- 1 ].cursor if edges else None ,
486
- ),
468
+ return StrawberrySQLAlchemyMapper ._resolve_connection_edges (
469
+ await resolver (self , info ), edge_type , connection_type
487
470
)
488
471
489
472
setattr (wrapper , _IS_GENERATED_RESOLVER_KEY , True )
490
473
491
474
return wrapper
492
475
476
+ @staticmethod
477
+ def _resolve_connection_edges (related_objects , edge_type , connection_type ):
478
+ # TODO: Add pagination support to dataloader resolvers
479
+ edges = [
480
+ edge_type .resolve_edge (
481
+ related_object ,
482
+ cursor = i ,
483
+ )
484
+ for i , related_object in enumerate (related_objects )
485
+ ]
486
+ return connection_type (
487
+ edges = edges ,
488
+ page_info = relay .PageInfo (
489
+ has_next_page = False ,
490
+ has_previous_page = False ,
491
+ start_cursor = edges [0 ].cursor if edges else None ,
492
+ end_cursor = edges [- 1 ].cursor if edges else None ,
493
+ ),
494
+ )
495
+
493
496
def relationship_resolver_for (
494
497
self , relationship : RelationshipProperty
495
498
) -> Callable [..., Awaitable [Any ]]:
@@ -499,7 +502,7 @@ def relationship_resolver_for(
499
502
"""
500
503
501
504
async def resolve (self , info : Info ):
502
- instance_state = cast (InstanceState , inspect (self ))
505
+ instance_state = cast (" InstanceState" , inspect (self ))
503
506
if relationship .key not in instance_state .unloaded :
504
507
related_objects = getattr (self , relationship .key )
505
508
else :
@@ -575,7 +578,15 @@ async def resolve(self, info: Info):
575
578
in_between_objects = await in_between_resolver (self , info )
576
579
if in_between_objects is None :
577
580
if is_multiple :
578
- return connection_type (edges = [])
581
+ return connection_type (
582
+ edges = [],
583
+ page_info = relay .PageInfo (
584
+ has_next_page = False ,
585
+ has_previous_page = False ,
586
+ start_cursor = None ,
587
+ end_cursor = None ,
588
+ ),
589
+ )
579
590
else :
580
591
return None
581
592
if descriptor .value_attr in in_between_mapper .relationships :
@@ -595,7 +606,9 @@ async def resolve(self, info: Info):
595
606
outputs = await end_relationship_resolver (in_between_objects , info )
596
607
if not isinstance (outputs , collections .abc .Iterable ):
597
608
return outputs
598
- return connection_type (edges = [edge_type (node = obj ) for obj in outputs ])
609
+ return StrawberrySQLAlchemyMapper ._resolve_connection_edges (
610
+ outputs , edge_type , connection_type
611
+ )
599
612
else :
600
613
assert descriptor .value_attr in in_between_mapper .columns
601
614
if isinstance (in_between_objects , collections .abc .Iterable ):
@@ -668,7 +681,7 @@ def convert(type_: Any) -> Any:
668
681
type_ .__annotations__ = {
669
682
k : v for k , v in old_annotations .items () if is_private (v )
670
683
}
671
- mapper : Mapper = cast (Mapper , inspect (model ))
684
+ mapper : Mapper = cast (" Mapper" , inspect (model ))
672
685
generated_field_keys = []
673
686
674
687
excluded_keys = getattr (type_ , "__exclude__" , [])
@@ -707,7 +720,7 @@ def convert(type_: Any) -> Any:
707
720
generated_field_keys ,
708
721
)
709
722
sqlalchemy_field = cast (
710
- StrawberryField ,
723
+ " StrawberryField" ,
711
724
field (
712
725
resolver = self .connection_resolver_for (
713
726
relationship ,
@@ -751,7 +764,7 @@ def convert(type_: Any) -> Any:
751
764
type_ , key , strawberry_type , generated_field_keys
752
765
)
753
766
sqlalchemy_field = cast (
754
- StrawberryField ,
767
+ " StrawberryField" ,
755
768
field (
756
769
resolver = self .association_proxy_resolver_for (
757
770
mapper ,
@@ -789,7 +802,8 @@ def convert(type_: Any) -> Any:
789
802
# ignore inherited `is_type_of`
790
803
if "is_type_of" not in type_ .__dict__ :
791
804
type_ .is_type_of = (
792
- lambda obj , info : type (obj ) == model or type (obj ) == type_
805
+ lambda obj , info : type (obj ) == model # noqa: E721
806
+ or type (obj ) == type_ # noqa: E721
793
807
)
794
808
795
809
# Default querying methods for relay
@@ -822,7 +836,7 @@ def convert(type_: Any) -> Any:
822
836
setattr (
823
837
type_ ,
824
838
attr ,
825
- types .MethodType (cast (classmethod , meth ).__func__ , type_ ),
839
+ types .MethodType (cast (" classmethod" , meth ).__func__ , type_ ),
826
840
)
827
841
828
842
# need to make fields that are already in the type
0 commit comments