@@ -154,7 +154,8 @@ def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...
154
154
155
155
@overload
156
156
@classmethod
157
- def from_type (cls , type_ : type , * , strict : bool = False ) -> Optional [Self ]: ...
157
+ def from_type (cls , type_ : type , * ,
158
+ strict : bool = False ) -> Optional [Self ]: ...
158
159
159
160
@classmethod
160
161
def from_type (
@@ -165,7 +166,8 @@ def from_type(
165
166
) -> Optional [Self ]:
166
167
definition = getattr (type_ , cls .TYPE_KEY_NAME , None )
167
168
if strict and definition is None :
168
- raise TypeError (f"{ type_ !r} does not have a StrawberrySQLAlchemyType in it" )
169
+ raise TypeError (
170
+ f"{ type_ !r} does not have a StrawberrySQLAlchemyType in it" )
169
171
return definition
170
172
171
173
@@ -228,8 +230,10 @@ class StrawberrySQLAlchemyMapper(Generic[BaseModelType]):
228
230
229
231
def __init__ (
230
232
self ,
231
- model_to_type_name : Optional [Callable [[Type [BaseModelType ]], str ]] = None ,
232
- model_to_interface_name : Optional [Callable [[Type [BaseModelType ]], str ]] = None ,
233
+ model_to_type_name : Optional [Callable [[
234
+ Type [BaseModelType ]], str ]] = None ,
235
+ model_to_interface_name : Optional [Callable [[
236
+ Type [BaseModelType ]], str ]] = None ,
233
237
extra_sqlalchemy_type_to_strawberry_type_map : Optional [
234
238
Mapping [Type [TypeEngine ], Type [Any ]]
235
239
] = None ,
@@ -295,7 +299,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
295
299
"""
296
300
edge_name = f"{ type_name } Edge"
297
301
if edge_name not in self .edge_types :
298
- lazy_type = StrawberrySQLAlchemyLazy (type_name = type_name , mapper = self )
302
+ lazy_type = StrawberrySQLAlchemyLazy (
303
+ type_name = type_name , mapper = self )
299
304
self .edge_types [edge_name ] = edge_type = strawberry .type (
300
305
dataclasses .make_dataclass (
301
306
edge_name ,
@@ -314,14 +319,16 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
314
319
connection_name = f"{ type_name } Connection"
315
320
if connection_name not in self .connection_types :
316
321
edge_type = self ._edge_type_for (type_name )
317
- lazy_type = StrawberrySQLAlchemyLazy (type_name = type_name , mapper = self )
322
+ lazy_type = StrawberrySQLAlchemyLazy (
323
+ type_name = type_name , mapper = self )
318
324
self .connection_types [connection_name ] = connection_type = strawberry .type (
319
325
dataclasses .make_dataclass (
320
326
connection_name ,
321
327
[
322
328
("edges" , List [edge_type ]), # type: ignore[valid-type]
323
329
],
324
- bases = (relay .ListConnection [lazy_type ],), # type: ignore[valid-type]
330
+ # type: ignore[valid-type]
331
+ bases = (relay .ListConnection [lazy_type ],),
325
332
)
326
333
)
327
334
setattr (connection_type , _GENERATED_FIELD_KEYS_KEY , ["edges" ])
@@ -387,7 +394,7 @@ def _convert_relationship_to_strawberry_type(
387
394
if relationship .uselist :
388
395
# Use list if excluding relay pagination
389
396
if use_list :
390
- return List [ForwardRef (type_name )] # type: ignore
397
+ return List [ForwardRef (type_name )] # type: ignore
391
398
392
399
return self ._connection_type_for (type_name )
393
400
else :
@@ -451,7 +458,8 @@ def _get_association_proxy_annotation(
451
458
strawberry_type .__forward_arg__
452
459
)
453
460
else :
454
- strawberry_type = self ._connection_type_for (strawberry_type .__name__ )
461
+ strawberry_type = self ._connection_type_for (
462
+ strawberry_type .__name__ )
455
463
return strawberry_type
456
464
457
465
def make_connection_wrapper_resolver (
@@ -500,13 +508,31 @@ async def resolve(self, info: Info):
500
508
if relationship .key not in instance_state .unloaded :
501
509
related_objects = getattr (self , relationship .key )
502
510
else :
503
- relationship_key = tuple (
504
- [
505
- getattr (self , local .key )
506
- for local , _ in relationship .local_remote_pairs or []
507
- if local .key
508
- ]
509
- )
511
+ if relationship .secondary is None :
512
+ relationship_key = tuple (
513
+ [
514
+ getattr (self , local .key )
515
+ for local , _ in relationship .local_remote_pairs or []
516
+ if local .key
517
+ ]
518
+ )
519
+ else :
520
+ # If has a secondary table, gets only the first id since the other id cannot be get without a query
521
+ # breakpoint()
522
+ # local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[
523
+ # 0][0]
524
+ # relationship_key = tuple(
525
+ # [getattr(self, local_remote_pairs_secondary_table_local.key),]
526
+ # )
527
+ relationship_key = tuple (
528
+ [
529
+ getattr (self , local .key )
530
+ for local , _ in relationship .local_remote_pairs or []
531
+ if local .key
532
+ ]
533
+ )
534
+ # breakpoint()
535
+
510
536
if any (item is None for item in relationship_key ):
511
537
if relationship .uselist :
512
538
return []
@@ -516,6 +542,7 @@ async def resolve(self, info: Info):
516
542
loader = info .context ["sqlalchemy_loader" ]
517
543
else :
518
544
loader = info .context .sqlalchemy_loader
545
+ # breakpoint()
519
546
related_objects = await loader .loader_for (relationship ).load (
520
547
relationship_key
521
548
)
@@ -536,7 +563,8 @@ def connection_resolver_for(
536
563
if relationship .uselist and not use_list :
537
564
return self .make_connection_wrapper_resolver (
538
565
relationship_resolver ,
539
- self .model_to_type_or_interface_name (relationship .entity .entity ), # type: ignore[arg-type]
566
+ self .model_to_type_or_interface_name (
567
+ relationship .entity .entity ), # type: ignore[arg-type]
540
568
)
541
569
else :
542
570
return relationship_resolver
@@ -554,13 +582,15 @@ def association_proxy_resolver_for(
554
582
Return an async field resolver for the given association proxy.
555
583
"""
556
584
in_between_relationship = mapper .relationships [descriptor .target_collection ]
557
- in_between_resolver = self .relationship_resolver_for (in_between_relationship )
585
+ in_between_resolver = self .relationship_resolver_for (
586
+ in_between_relationship )
558
587
in_between_mapper : Mapper = mapper .relationships [ # type: ignore[assignment]
559
588
descriptor .target_collection
560
589
].entity
561
590
assert descriptor .value_attr in in_between_mapper .relationships
562
591
end_relationship = in_between_mapper .relationships [descriptor .value_attr ]
563
- end_relationship_resolver = self .relationship_resolver_for (end_relationship )
592
+ end_relationship_resolver = self .relationship_resolver_for (
593
+ end_relationship )
564
594
end_type_name = self .model_to_type_or_interface_name (
565
595
end_relationship .entity .entity # type: ignore[arg-type]
566
596
)
@@ -587,7 +617,8 @@ async def resolve(self, info: Info):
587
617
if outputs and isinstance (outputs [0 ], list ):
588
618
outputs = list (chain .from_iterable (outputs ))
589
619
else :
590
- outputs = [output for output in outputs if output is not None ]
620
+ outputs = [
621
+ output for output in outputs if output is not None ]
591
622
else :
592
623
outputs = await end_relationship_resolver (in_between_objects , info )
593
624
if not isinstance (outputs , collections .abc .Iterable ):
@@ -683,7 +714,8 @@ def convert(type_: Any) -> Any:
683
714
setattr (type_ , key , field (resolver = val ))
684
715
generated_field_keys .append (key )
685
716
686
- self ._handle_columns (mapper , type_ , excluded_keys , generated_field_keys )
717
+ self ._handle_columns (
718
+ mapper , type_ , excluded_keys , generated_field_keys )
687
719
relationship : RelationshipProperty
688
720
for key , relationship in mapper .relationships .items ():
689
721
if (
@@ -805,7 +837,8 @@ def convert(type_: Any) -> Any:
805
837
setattr (
806
838
type_ ,
807
839
attr ,
808
- types .MethodType (func , type_ ), # type: ignore[arg-type]
840
+ # type: ignore[arg-type]
841
+ types .MethodType (func , type_ ),
809
842
)
810
843
811
844
# Adjust types that inherit from other types/interfaces that implement Node
@@ -818,7 +851,8 @@ def convert(type_: Any) -> Any:
818
851
setattr (
819
852
type_ ,
820
853
attr ,
821
- types .MethodType (cast (classmethod , meth ).__func__ , type_ ),
854
+ types .MethodType (
855
+ cast (classmethod , meth ).__func__ , type_ ),
822
856
)
823
857
824
858
# need to make fields that are already in the type
@@ -846,7 +880,8 @@ def convert(type_: Any) -> Any:
846
880
model = model ,
847
881
),
848
882
)
849
- setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY , generated_field_keys )
883
+ setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY ,
884
+ generated_field_keys )
850
885
setattr (mapped_type , _ORIGINAL_TYPE_KEY , type_ )
851
886
return mapped_type
852
887
@@ -886,14 +921,16 @@ def _fix_annotation_namespaces(self) -> None:
886
921
self .edge_types .values (),
887
922
self .connection_types .values (),
888
923
):
889
- strawberry_definition = get_object_definition (mapped_type , strict = True )
924
+ strawberry_definition = get_object_definition (
925
+ mapped_type , strict = True )
890
926
for f in strawberry_definition .fields :
891
927
if f .name in getattr (mapped_type , _GENERATED_FIELD_KEYS_KEY ):
892
928
namespace = {}
893
929
if hasattr (mapped_type , _ORIGINAL_TYPE_KEY ):
894
930
namespace .update (
895
931
sys .modules [
896
- getattr (mapped_type , _ORIGINAL_TYPE_KEY ).__module__
932
+ getattr (mapped_type ,
933
+ _ORIGINAL_TYPE_KEY ).__module__
897
934
].__dict__
898
935
)
899
936
namespace .update (self .mapped_types )
@@ -924,7 +961,8 @@ def _map_unmapped_relationships(self) -> None:
924
961
if type_name not in self .mapped_interfaces :
925
962
unmapped_interface_models .add (model )
926
963
for model in unmapped_models :
927
- self .type (model )(type (self .model_to_type_name (model ), (object ,), {}))
964
+ self .type (model )(
965
+ type (self .model_to_type_name (model ), (object ,), {}))
928
966
for model in unmapped_interface_models :
929
967
self .interface (model )(
930
968
type (self .model_to_interface_name (model ), (object ,), {})
0 commit comments