Skip to content

Commit de06188

Browse files
committed
first fix versoin, working only if the items has the same id
1 parent 2133fd7 commit de06188

File tree

3 files changed

+520
-50
lines changed

3 files changed

+520
-50
lines changed

src/strawberry_sqlalchemy_mapper/loader.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,57 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
6363
related_model = relationship.entity.entity
6464

6565
async def load_fn(keys: List[Tuple]) -> List[Any]:
66-
query = select(related_model).filter(
67-
tuple_(
68-
*[remote for _, remote in relationship.local_remote_pairs or []]
69-
).in_(keys)
70-
)
66+
if relationship.secondary is None:
67+
query = select(related_model).filter(
68+
tuple_(
69+
*[remote for _, remote in relationship.local_remote_pairs or []]
70+
).in_(keys)
71+
)
72+
else:
73+
# Use another query when relationship uses a secondary table
74+
# *[remote[1] for remote in relationship.local_remote_pairs or []]
75+
# breakpoint()
76+
# remote_to_use = relationship.local_remote_pairs[0][1]
77+
# keys = tuple([item[0] for item in keys])
78+
query = (
79+
select(related_model)
80+
.join(relationship.secondary, relationship.secondaryjoin)
81+
.filter(
82+
# emote_to_use.in_(keys)
83+
tuple_(
84+
*[remote[1] for remote in relationship.local_remote_pairs or []]
85+
).in_(keys)
86+
)
87+
)
88+
7189
if relationship.order_by:
7290
query = query.order_by(*relationship.order_by)
7391
rows = await self._scalars_all(query)
7492

7593
def group_by_remote_key(row: Any) -> Tuple:
76-
return tuple(
77-
[
78-
getattr(row, remote.key)
79-
for _, remote in relationship.local_remote_pairs or []
80-
if remote.key
81-
]
82-
)
94+
if relationship.secondary is None:
95+
return tuple(
96+
[
97+
getattr(row, remote.key)
98+
for _, remote in relationship.local_remote_pairs or []
99+
if remote.key
100+
]
101+
)
102+
else:
103+
# Use another query when relationship uses a secondary table
104+
# breakpoint()
105+
related_model_table = relationship.entity.entity.__table__
106+
# breakpoint()
107+
return tuple(
108+
[
109+
getattr(row, remote[0].key)
110+
for remote in relationship.local_remote_pairs or []
111+
if remote[0].key is not None and remote[0].table == related_model_table
112+
]
113+
)
83114

84115
grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
116+
# breakpoint()
85117
for row in rows:
86118
grouped_keys[group_by_remote_key(row)].append(row)
87119
if relationship.uselist:

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...
154154

155155
@overload
156156
@classmethod
157-
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...
157+
def from_type(cls, type_: type, *,
158+
strict: bool = False) -> Optional[Self]: ...
158159

159160
@classmethod
160161
def from_type(
@@ -165,7 +166,8 @@ def from_type(
165166
) -> Optional[Self]:
166167
definition = getattr(type_, cls.TYPE_KEY_NAME, None)
167168
if strict and definition is None:
168-
raise TypeError(f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
169+
raise TypeError(
170+
f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
169171
return definition
170172

171173

@@ -228,8 +230,10 @@ class StrawberrySQLAlchemyMapper(Generic[BaseModelType]):
228230

229231
def __init__(
230232
self,
231-
model_to_type_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
232-
model_to_interface_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
233+
model_to_type_name: Optional[Callable[[
234+
Type[BaseModelType]], str]] = None,
235+
model_to_interface_name: Optional[Callable[[
236+
Type[BaseModelType]], str]] = None,
233237
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
234238
Mapping[Type[TypeEngine], Type[Any]]
235239
] = None,
@@ -295,7 +299,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
295299
"""
296300
edge_name = f"{type_name}Edge"
297301
if edge_name not in self.edge_types:
298-
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
302+
lazy_type = StrawberrySQLAlchemyLazy(
303+
type_name=type_name, mapper=self)
299304
self.edge_types[edge_name] = edge_type = strawberry.type(
300305
dataclasses.make_dataclass(
301306
edge_name,
@@ -314,14 +319,16 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
314319
connection_name = f"{type_name}Connection"
315320
if connection_name not in self.connection_types:
316321
edge_type = self._edge_type_for(type_name)
317-
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
322+
lazy_type = StrawberrySQLAlchemyLazy(
323+
type_name=type_name, mapper=self)
318324
self.connection_types[connection_name] = connection_type = strawberry.type(
319325
dataclasses.make_dataclass(
320326
connection_name,
321327
[
322328
("edges", List[edge_type]), # type: ignore[valid-type]
323329
],
324-
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
330+
# type: ignore[valid-type]
331+
bases=(relay.ListConnection[lazy_type],),
325332
)
326333
)
327334
setattr(connection_type, _GENERATED_FIELD_KEYS_KEY, ["edges"])
@@ -387,7 +394,7 @@ def _convert_relationship_to_strawberry_type(
387394
if relationship.uselist:
388395
# Use list if excluding relay pagination
389396
if use_list:
390-
return List[ForwardRef(type_name)] # type: ignore
397+
return List[ForwardRef(type_name)] # type: ignore
391398

392399
return self._connection_type_for(type_name)
393400
else:
@@ -451,7 +458,8 @@ def _get_association_proxy_annotation(
451458
strawberry_type.__forward_arg__
452459
)
453460
else:
454-
strawberry_type = self._connection_type_for(strawberry_type.__name__)
461+
strawberry_type = self._connection_type_for(
462+
strawberry_type.__name__)
455463
return strawberry_type
456464

457465
def make_connection_wrapper_resolver(
@@ -500,13 +508,31 @@ async def resolve(self, info: Info):
500508
if relationship.key not in instance_state.unloaded:
501509
related_objects = getattr(self, relationship.key)
502510
else:
503-
relationship_key = tuple(
504-
[
505-
getattr(self, local.key)
506-
for local, _ in relationship.local_remote_pairs or []
507-
if local.key
508-
]
509-
)
511+
if relationship.secondary is None:
512+
relationship_key = tuple(
513+
[
514+
getattr(self, local.key)
515+
for local, _ in relationship.local_remote_pairs or []
516+
if local.key
517+
]
518+
)
519+
else:
520+
# If has a secondary table, gets only the first id since the other id cannot be get without a query
521+
# breakpoint()
522+
# local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[
523+
# 0][0]
524+
# relationship_key = tuple(
525+
# [getattr(self, local_remote_pairs_secondary_table_local.key),]
526+
# )
527+
relationship_key = tuple(
528+
[
529+
getattr(self, local.key)
530+
for local, _ in relationship.local_remote_pairs or []
531+
if local.key
532+
]
533+
)
534+
# breakpoint()
535+
510536
if any(item is None for item in relationship_key):
511537
if relationship.uselist:
512538
return []
@@ -516,6 +542,7 @@ async def resolve(self, info: Info):
516542
loader = info.context["sqlalchemy_loader"]
517543
else:
518544
loader = info.context.sqlalchemy_loader
545+
# breakpoint()
519546
related_objects = await loader.loader_for(relationship).load(
520547
relationship_key
521548
)
@@ -536,7 +563,8 @@ def connection_resolver_for(
536563
if relationship.uselist and not use_list:
537564
return self.make_connection_wrapper_resolver(
538565
relationship_resolver,
539-
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
566+
self.model_to_type_or_interface_name(
567+
relationship.entity.entity), # type: ignore[arg-type]
540568
)
541569
else:
542570
return relationship_resolver
@@ -554,13 +582,15 @@ def association_proxy_resolver_for(
554582
Return an async field resolver for the given association proxy.
555583
"""
556584
in_between_relationship = mapper.relationships[descriptor.target_collection]
557-
in_between_resolver = self.relationship_resolver_for(in_between_relationship)
585+
in_between_resolver = self.relationship_resolver_for(
586+
in_between_relationship)
558587
in_between_mapper: Mapper = mapper.relationships[ # type: ignore[assignment]
559588
descriptor.target_collection
560589
].entity
561590
assert descriptor.value_attr in in_between_mapper.relationships
562591
end_relationship = in_between_mapper.relationships[descriptor.value_attr]
563-
end_relationship_resolver = self.relationship_resolver_for(end_relationship)
592+
end_relationship_resolver = self.relationship_resolver_for(
593+
end_relationship)
564594
end_type_name = self.model_to_type_or_interface_name(
565595
end_relationship.entity.entity # type: ignore[arg-type]
566596
)
@@ -587,7 +617,8 @@ async def resolve(self, info: Info):
587617
if outputs and isinstance(outputs[0], list):
588618
outputs = list(chain.from_iterable(outputs))
589619
else:
590-
outputs = [output for output in outputs if output is not None]
620+
outputs = [
621+
output for output in outputs if output is not None]
591622
else:
592623
outputs = await end_relationship_resolver(in_between_objects, info)
593624
if not isinstance(outputs, collections.abc.Iterable):
@@ -683,7 +714,8 @@ def convert(type_: Any) -> Any:
683714
setattr(type_, key, field(resolver=val))
684715
generated_field_keys.append(key)
685716

686-
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
717+
self._handle_columns(
718+
mapper, type_, excluded_keys, generated_field_keys)
687719
relationship: RelationshipProperty
688720
for key, relationship in mapper.relationships.items():
689721
if (
@@ -805,7 +837,8 @@ def convert(type_: Any) -> Any:
805837
setattr(
806838
type_,
807839
attr,
808-
types.MethodType(func, type_), # type: ignore[arg-type]
840+
# type: ignore[arg-type]
841+
types.MethodType(func, type_),
809842
)
810843

811844
# Adjust types that inherit from other types/interfaces that implement Node
@@ -818,7 +851,8 @@ def convert(type_: Any) -> Any:
818851
setattr(
819852
type_,
820853
attr,
821-
types.MethodType(cast(classmethod, meth).__func__, type_),
854+
types.MethodType(
855+
cast(classmethod, meth).__func__, type_),
822856
)
823857

824858
# need to make fields that are already in the type
@@ -846,7 +880,8 @@ def convert(type_: Any) -> Any:
846880
model=model,
847881
),
848882
)
849-
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
883+
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY,
884+
generated_field_keys)
850885
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
851886
return mapped_type
852887

@@ -886,14 +921,16 @@ def _fix_annotation_namespaces(self) -> None:
886921
self.edge_types.values(),
887922
self.connection_types.values(),
888923
):
889-
strawberry_definition = get_object_definition(mapped_type, strict=True)
924+
strawberry_definition = get_object_definition(
925+
mapped_type, strict=True)
890926
for f in strawberry_definition.fields:
891927
if f.name in getattr(mapped_type, _GENERATED_FIELD_KEYS_KEY):
892928
namespace = {}
893929
if hasattr(mapped_type, _ORIGINAL_TYPE_KEY):
894930
namespace.update(
895931
sys.modules[
896-
getattr(mapped_type, _ORIGINAL_TYPE_KEY).__module__
932+
getattr(mapped_type,
933+
_ORIGINAL_TYPE_KEY).__module__
897934
].__dict__
898935
)
899936
namespace.update(self.mapped_types)
@@ -924,7 +961,8 @@ def _map_unmapped_relationships(self) -> None:
924961
if type_name not in self.mapped_interfaces:
925962
unmapped_interface_models.add(model)
926963
for model in unmapped_models:
927-
self.type(model)(type(self.model_to_type_name(model), (object,), {}))
964+
self.type(model)(
965+
type(self.model_to_type_name(model), (object,), {}))
928966
for model in unmapped_interface_models:
929967
self.interface(model)(
930968
type(self.model_to_interface_name(model), (object,), {})

0 commit comments

Comments
 (0)