Skip to content

Commit 401cd65

Browse files
committed
fix: now query can pickup related_model and self_model id
1 parent 0cd732d commit 401cd65

File tree

3 files changed

+157
-72
lines changed

3 files changed

+157
-72
lines changed

src/strawberry_sqlalchemy_mapper/loader.py

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Union,
1313
)
1414

15-
from sqlalchemy import select, tuple_
15+
from sqlalchemy import select, tuple_, label
1616
from sqlalchemy.engine.base import Connection
1717
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
1818
from sqlalchemy.orm import RelationshipProperty, Session
@@ -45,12 +45,16 @@ def __init__(
4545
"One of bind or async_bind_factory must be set for loader to function properly."
4646
)
4747

48-
async def _scalars_all(self, *args, **kwargs):
48+
async def _scalars_all(self, *args, disabled_optimization_to_secondary_tables=False, **kwargs):
4949
if self._async_bind_factory:
5050
async with self._async_bind_factory() as bind:
51+
if disabled_optimization_to_secondary_tables is True:
52+
return (await bind.execute(*args, **kwargs)).all()
5153
return (await bind.scalars(*args, **kwargs)).all()
5254
else:
5355
assert self._bind is not None
56+
if disabled_optimization_to_secondary_tables is True:
57+
return self._bind.execute(*args, **kwargs).all()
5458
return self._bind.scalars(*args, **kwargs).all()
5559

5660
def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
@@ -72,23 +76,82 @@ async def load_fn(keys: List[Tuple]) -> List[Any]:
7276
else:
7377
# Use another query when relationship uses a secondary table
7478
# *[remote[1] for remote in relationship.local_remote_pairs or []]
79+
self_model = relationship.parent.entity
80+
81+
self_model_key_label = relationship.local_remote_pairs[0][1].key
82+
related_model_key_label = relationship.local_remote_pairs[1][1].key
83+
84+
self_model_key = relationship.local_remote_pairs[0][0].key
7585
# breakpoint()
76-
# remote_to_use = relationship.local_remote_pairs[0][1]
77-
# keys = tuple([item[0] for item in keys])
86+
# Gets the
87+
remote_to_use = relationship.local_remote_pairs[0][1]
88+
query_keys = tuple([item[0] for item in keys])
89+
breakpoint()
7890
query = (
79-
select(related_model)
80-
.join(relationship.secondary, relationship.secondaryjoin)
91+
# select(related_model)
92+
select(
93+
label(self_model_key_label, getattr(
94+
self_model, self_model_key)),
95+
related_model
96+
)
97+
# .join(
98+
# related_model,
99+
# getattr(relationship.secondary.c, related_model_key_label) == getattr(
100+
# related_model, related_model_key)
101+
# )
102+
# .join(
103+
# relationship.secondary,
104+
# getattr(relationship.secondary.c, self_model_key_label) == getattr(
105+
# self_model, self_model_key)
106+
# )
107+
# .join(
108+
# relationship.secondary,
109+
# getattr(relationship.secondary.c, self_model_key_label) == getattr(
110+
# self_model, self_model_key)
111+
# )
112+
.join(
113+
relationship.secondary, # Join the secondary table
114+
getattr(relationship.secondary.c, related_model_key_label) == related_model.id # Match department_id
115+
)
116+
.join(
117+
self_model, # Join the Employee table
118+
getattr(relationship.secondary.c, self_model_key_label) == self_model.id # Match employee_id
119+
)
81120
.filter(
82-
# emote_to_use.in_(keys)
83-
tuple_(
84-
*[remote[1] for remote in relationship.local_remote_pairs or []]
85-
).in_(keys)
121+
remote_to_use.in_(query_keys)
86122
)
87123
)
124+
# query = (
125+
# # select(related_model)
126+
# select(
127+
# related_model,
128+
# label(self_model_key_label, getattr(self_model, self_model_key))
129+
# )
130+
# .join(relationship.secondary, relationship.secondaryjoin)
131+
# .filter(
132+
# remote_to_use.in_(query_keys)
133+
# )
134+
# )
135+
136+
# query = (
137+
# select(related_model)
138+
# .join(relationship.secondary, relationship.secondaryjoin)
139+
# .filter(
140+
# # emote_to_use.in_(keys)
141+
# tuple_(
142+
# *[remote[1] for remote in relationship.local_remote_pairs or []]
143+
# ).in_(keys)
144+
# )
145+
# )
88146

89147
if relationship.order_by:
90148
query = query.order_by(*relationship.order_by)
91-
rows = await self._scalars_all(query)
149+
150+
if relationship.secondary is not None:
151+
# We need get the self_model values too, so we need to remove the slqalchemy optimization that returns only the related_model values, this is needed because we use the keys var to match the related_model and the self_model
152+
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=True)
153+
else:
154+
rows = await self._scalars_all(query)
92155

93156
def group_by_remote_key(row: Any) -> Tuple:
94157
if relationship.secondary is None:
@@ -104,6 +167,24 @@ def group_by_remote_key(row: Any) -> Tuple:
104167
# breakpoint()
105168
related_model_table = relationship.entity.entity.__table__
106169
# breakpoint()
170+
# return tuple(
171+
# [
172+
# getattr(row, remote[0].key)
173+
# for remote in relationship.local_remote_pairs or []
174+
# if remote[0].key is not None and remote[0].table == related_model_table
175+
# ]
176+
# )
177+
result = []
178+
for remote in relationship.local_remote_pairs or []:
179+
if remote[0].key is not None and relationship.local_remote_pairs[1][0].table == related_model_table:
180+
result.extend(
181+
[
182+
183+
getattr(row, remote[0].key)
184+
185+
]
186+
)
187+
breakpoint()
107188
return tuple(
108189
[
109190
getattr(row, remote[0].key)
@@ -113,11 +194,11 @@ def group_by_remote_key(row: Any) -> Tuple:
113194
)
114195

115196
grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
116-
# breakpoint()
197+
breakpoint()
117198
for row in rows:
118199
grouped_keys[group_by_remote_key(row)].append(row)
119200

120-
# breakpoint()
201+
breakpoint()
121202
if relationship.uselist:
122203
return [grouped_keys[key] for key in keys]
123204
else:

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -519,18 +519,20 @@ async def resolve(self, info: Info):
519519
else:
520520
# If has a secondary table, gets only the first id since the other id cannot be get without a query
521521
# breakpoint()
522-
# local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[
523-
# 0][0]
524-
# relationship_key = tuple(
525-
# [getattr(self, local_remote_pairs_secondary_table_local.key),]
526-
# )
522+
local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[0][0]
527523
relationship_key = tuple(
528524
[
529-
getattr(self, local.key)
530-
for local, _ in relationship.local_remote_pairs or []
531-
if local.key
525+
getattr(self, local_remote_pairs_secondary_table_local.key),
532526
]
533527
)
528+
529+
# relationship_key = tuple(
530+
# [
531+
# getattr(self, local.key)
532+
# for local, _ in relationship.local_remote_pairs or []
533+
# if local.key
534+
# ]
535+
# )
534536
# breakpoint()
535537

536538
if any(item is None for item in relationship_key):

tests/relay/test_connection.py

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ class Department(base):
799799
return Employee, Department
800800

801801

802+
@pytest.mark.asyncio
802803
async def test_query_with_secondary_table(
803804
secondary_tables,
804805
base,
@@ -869,7 +870,6 @@ class Query:
869870
)
870871
})
871872
assert result.errors is None
872-
# breakpoint()
873873
assert result.data == {
874874
'employees': {
875875
'edges': [
@@ -920,6 +920,7 @@ class Query:
920920
}
921921

922922

923+
@pytest.mark.asyncio
923924
async def test_query_with_secondary_table_without_list_connection(
924925
secondary_tables,
925926
base,
@@ -992,12 +993,12 @@ async def employees(self) -> List[Employee]:
992993
assert result.errors is None
993994
# breakpoint()
994995
assert result.data == {
995-
'employees': [
996-
{
997-
'id': 1,
998-
'name': 'John',
999-
'role': 'Developer',
1000-
'department': {
996+
'employees': [
997+
{
998+
'id': 1,
999+
'name': 'John',
1000+
'role': 'Developer',
1001+
'department': {
10011002
'edges': [
10021003
{
10031004
'node': {
@@ -1011,28 +1012,29 @@ async def employees(self) -> List[Employee]:
10111012
}
10121013
}
10131014
]
1014-
}
1015-
},
1016-
{
1017-
'id': 2,
1018-
'name': 'Bill',
1019-
'role': 'Doctor',
1020-
'department': {
1015+
}
1016+
},
1017+
{
1018+
'id': 2,
1019+
'name': 'Bill',
1020+
'role': 'Doctor',
1021+
'department': {
10211022
'edges': []
1022-
}
1023-
},
1024-
{
1025-
'id': 3,
1026-
'name': 'Maria',
1027-
'role': 'Teacher',
1028-
'department': {
1023+
}
1024+
},
1025+
{
1026+
'id': 3,
1027+
'name': 'Maria',
1028+
'role': 'Teacher',
1029+
'department': {
10291030
'edges': []
1030-
}
10311031
}
1032-
]
1033-
}
1032+
}
1033+
]
1034+
}
10341035

10351036

1037+
@pytest.mark.asyncio
10361038
async def test_query_with_secondary_table_with_values_with_different_ids(
10371039
secondary_tables,
10381040
base,
@@ -1107,14 +1109,14 @@ async def employees(self) -> List[Employee]:
11071109
)
11081110
})
11091111
assert result.errors is None
1110-
# breakpoint()
1112+
breakpoint()
11111113
assert result.data == {
1112-
'employees': [
1113-
{
1114-
'id': 1,
1115-
'name': 'John',
1116-
'role': 'Developer',
1117-
'department': {
1114+
'employees': [
1115+
{
1116+
'id': 1,
1117+
'name': 'John',
1118+
'role': 'Developer',
1119+
'department': {
11181120
'edges': [
11191121
{
11201122
'node': {
@@ -1128,31 +1130,31 @@ async def employees(self) -> List[Employee]:
11281130
}
11291131
}
11301132
]
1131-
}
1132-
},
1133-
{
1134-
'id': 2,
1135-
'name': 'Bill',
1136-
'role': 'Doctor',
1137-
'department': {
1133+
}
1134+
},
1135+
{
1136+
'id': 2,
1137+
'name': 'Bill',
1138+
'role': 'Doctor',
1139+
'department': {
11381140
'edges': []
1139-
}
1140-
},
1141-
{
1142-
'id': 3,
1143-
'name': 'Maria',
1144-
'role': 'Teacher',
1145-
'department': {
1141+
}
1142+
},
1143+
{
1144+
'id': 3,
1145+
'name': 'Maria',
1146+
'role': 'Teacher',
1147+
'department': {
11461148
'edges': []
1147-
}
11481149
}
1149-
]
1150-
}
1151-
1152-
1150+
}
1151+
]
1152+
}
11531153

11541154

11551155
# TODO
11561156
# test with different ids
1157+
# test with foreinkey different than id
11571158
# add a test on Loader to see
1158-
# Add test with query by secondary id
1159+
# Add test with query by secondary id]
1160+
# try syncronous

0 commit comments

Comments
 (0)