Skip to content

Commit 21463c7

Browse files
authored
Support relations in .only() (#1923)
* Make expand_field_expression reusable * .only() supports relations * Remove redundant check * Prohibit using .values() and .values_list() with .only() * Add tests * Add a comment * Update CHANGELOG * CR feedback
1 parent 2aa49ad commit 21463c7

File tree

7 files changed

+429
-96
lines changed

7 files changed

+429
-96
lines changed

CHANGELOG.rst

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,23 @@ Changelog
66

77
.. rst-class:: emphasize-children
88

9-
0.24
9+
0.25
1010
====
1111

12-
0.24.3 (unreleased)
12+
0.25.0 (unreleased)
1313
------
1414
Changed
1515
^^^^^^^
1616
- Skip database selection if the router is not configured to improve performance (#1915)
17+
- `.values()`, `.values_list()` and `.only()` cannot be used together (#1923)
18+
19+
Added
20+
^^^^^
21+
- `.only` supports selecting related fields, e.g. `.only("related__field")` (#1923)
22+
23+
24+
0.24
25+
====
1726

1827
0.24.2
1928
------

tests/test_aggregation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212
from tortoise.contrib import test
1313
from tortoise.contrib.test.condition import In
14-
from tortoise.exceptions import ConfigurationError
14+
from tortoise.exceptions import FieldError
1515
from tortoise.expressions import F, Q
1616
from tortoise.functions import Avg, Coalesce, Concat, Count, Lower, Max, Min, Sum, Trim
1717

@@ -60,7 +60,7 @@ async def test_aggregation(self):
6060
event_with_annotation.tournament_id,
6161
)
6262

63-
with self.assertRaisesRegex(ConfigurationError, "name__id not resolvable"):
63+
with self.assertRaisesRegex(FieldError, "name__id not resolvable"):
6464
await Event.all().annotate(tournament_test_id=Sum("name__id")).first()
6565

6666
async def test_nested_aggregation_in_annotation(self):

tests/test_only.py

Lines changed: 230 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from tests.testmodels import SourceFields, StraightFields
1+
from tests.testmodels import DoubleFK, Event, SourceFields, StraightFields, Tournament
22
from tortoise.contrib import test
3-
from tortoise.exceptions import IncompleteInstanceError
3+
from tortoise.functions import Count
4+
from tortoise.exceptions import FieldError, IncompleteInstanceError
45

56

67
class TestOnlyStraight(test.TestCase):
@@ -64,3 +65,230 @@ async def asyncSetUp(self) -> None:
6465
await super().asyncSetUp()
6566
self.model = SourceFields # type: ignore
6667
self.instance = await self.model.create(chars="Test")
68+
69+
70+
class TestOnlyRecursive(test.TestCase):
71+
async def test_one_level(self):
72+
left_1st_lvl = await DoubleFK.create(name="1st")
73+
root = await DoubleFK.create(name="root", left=left_1st_lvl)
74+
75+
ret = (
76+
await DoubleFK.filter(pk=root.pk).only("name", "left__name", "left__left__name").first()
77+
)
78+
self.assertIsNotNone(ret)
79+
with self.assertRaises(AttributeError):
80+
_ = ret.id
81+
self.assertEqual(ret.name, "root")
82+
self.assertEqual(ret.left.name, "1st")
83+
with self.assertRaises(AttributeError):
84+
_ = ret.left.id
85+
with self.assertRaises(AttributeError):
86+
_ = ret.right
87+
88+
async def test_two_levels(self):
89+
left_2nd_lvl = await DoubleFK.create(name="second leaf")
90+
left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl)
91+
root = await DoubleFK.create(name="root", left=left_1st_lvl)
92+
93+
ret = (
94+
await DoubleFK.filter(pk=root.pk).only("name", "left__name", "left__left__name").first()
95+
)
96+
self.assertIsNotNone(ret)
97+
with self.assertRaises(AttributeError):
98+
_ = ret.id
99+
self.assertEqual(ret.name, "root")
100+
self.assertEqual(ret.left.name, "1st")
101+
with self.assertRaises(AttributeError):
102+
_ = ret.left.id
103+
self.assertEqual(ret.left.left.name, "second leaf")
104+
105+
async def test_two_levels_reverse_argument_order(self):
106+
left_2nd_lvl = await DoubleFK.create(name="second leaf")
107+
left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl)
108+
root = await DoubleFK.create(name="root", left=left_1st_lvl)
109+
110+
ret = (
111+
await DoubleFK.filter(pk=root.pk).only("left__left__name", "left__name", "name").first()
112+
)
113+
self.assertIsNotNone(ret)
114+
with self.assertRaises(AttributeError):
115+
_ = ret.id
116+
self.assertEqual(ret.name, "root")
117+
self.assertEqual(ret.left.name, "1st")
118+
with self.assertRaises(AttributeError):
119+
_ = ret.left.id
120+
self.assertEqual(ret.left.left.name, "second leaf")
121+
122+
123+
class TestOnlyRelated(test.TestCase):
124+
async def test_related_one_level(self):
125+
tournament = await Tournament.create(name="New Tournament", desc="New Description")
126+
await Event.create(name="Event 1", tournament=tournament)
127+
await Event.create(name="Event 2", tournament=tournament)
128+
129+
ret = (
130+
await Event.filter(tournament=tournament)
131+
.only("name", "tournament__name")
132+
.order_by("name")
133+
)
134+
self.assertEqual(len(ret), 2)
135+
self.assertEqual(ret[0].name, "Event 1")
136+
with self.assertRaises(AttributeError):
137+
_ = ret[0].alias
138+
self.assertEqual(ret[1].name, "Event 2")
139+
with self.assertRaises(AttributeError):
140+
_ = ret[1].alias
141+
self.assertEqual(ret[0].tournament.name, "New Tournament")
142+
with self.assertRaises(AttributeError):
143+
_ = ret[0].tournament.id
144+
with self.assertRaises(AttributeError):
145+
_ = ret[0].tournament.desc
146+
147+
async def test_related_one_level_reversed_argument_order(self):
148+
tournament = await Tournament.create(name="New Tournament", desc="New Description")
149+
await Event.create(name="Event 1", tournament=tournament)
150+
await Event.create(name="Event 2", tournament=tournament)
151+
152+
ret = (
153+
await Event.filter(tournament=tournament)
154+
.only("tournament__name", "name")
155+
.order_by("name")
156+
)
157+
self.assertEqual(len(ret), 2)
158+
self.assertEqual(ret[0].name, "Event 1")
159+
self.assertEqual(ret[0].tournament.name, "New Tournament")
160+
161+
async def test_just_related(self):
162+
tournament = await Tournament.create(name="New Tournament", desc="New Description")
163+
await Event.create(name="Event 1", tournament=tournament)
164+
await Event.create(name="Event 2", tournament=tournament)
165+
166+
ret = (
167+
await Event.filter(tournament=tournament)
168+
.only("tournament__name")
169+
.order_by("name")
170+
.all()
171+
)
172+
self.assertEqual(len(ret), 2)
173+
self.assertEqual(ret[0].tournament.name, "New Tournament")
174+
self.assertEqual(ret[1].tournament.name, "New Tournament")
175+
176+
177+
class TestOnlyAdvanced(test.TestCase):
178+
async def asyncSetUp(self) -> None:
179+
await super().asyncSetUp()
180+
self.tournament = await Tournament.create(name="Tournament A", desc="Description A")
181+
self.event1 = await Event.create(name="Event 1", tournament=self.tournament)
182+
self.event2 = await Event.create(name="Event 2", tournament=self.tournament)
183+
184+
async def test_exclude(self):
185+
"""Test .only() combined with .exclude()"""
186+
events = await Event.filter(tournament=self.tournament).exclude(name="Event 2").only("name")
187+
self.assertEqual(len(events), 1)
188+
self.assertEqual(events[0].name, "Event 1")
189+
with self.assertRaises(AttributeError):
190+
_ = events[0].modified
191+
192+
async def test_limit(self):
193+
"""Test .only() combined with .limit()"""
194+
events = await Event.all().only("name").limit(1)
195+
self.assertEqual(len(events), 1)
196+
self.assertEqual(events[0].name, "Event 1") # Assumes ordering by PK
197+
with self.assertRaises(AttributeError):
198+
_ = events[0].modified
199+
200+
async def test_distinct(self):
201+
"""Test .only() combined with .distinct()"""
202+
# Create duplicate event names
203+
await Event.create(name="Event 1", tournament=self.tournament)
204+
205+
events = await Event.all().only("name").distinct()
206+
# Should only have 2 distinct event names
207+
self.assertEqual(len(events), 2)
208+
event_names = {e.name for e in events}
209+
self.assertEqual(event_names, {"Event 1", "Event 2"})
210+
211+
async def test_values(self):
212+
"""Test .only() combined with .values()"""
213+
with self.assertRaises(ValueError, msg="values() cannot be used with .only()"):
214+
await Event.all().only("name").values("name")
215+
216+
async def test_pk_field(self):
217+
"""Test .only() with just the primary key field"""
218+
tournament = await Tournament.first().only("id")
219+
self.assertIsNotNone(tournament.id)
220+
with self.assertRaises(AttributeError):
221+
_ = tournament.name
222+
223+
async def test_empty(self):
224+
"""Test .only() with no fields (should raise an error)"""
225+
with self.assertRaises(ValueError):
226+
await Event.all().only()
227+
228+
async def test_annotate(self):
229+
tournaments = await Tournament.annotate(event_count=Count("events")).only(
230+
"name", "event_count"
231+
)
232+
233+
self.assertEqual(tournaments[0].name, "Tournament A")
234+
self.assertEqual(tournaments[0].event_count, 2)
235+
with self.assertRaises(AttributeError):
236+
_ = tournaments[0].desc
237+
238+
async def test_nonexistent_field(self):
239+
"""Test .only() with a field that doesn't exist"""
240+
with self.assertRaises(FieldError):
241+
await Event.all().only("nonexistent_field").all()
242+
243+
async def test_join_in_filter(self):
244+
event = await Event.filter(name="Event 1").only("name").first()
245+
self.assertEqual(event.name, "Event 1")
246+
with self.assertRaises(AttributeError):
247+
_ = event.tournament
248+
249+
event = await Event.filter(tournament__name="Tournament A").only("name").first()
250+
self.assertEqual(event.name, "Event 1")
251+
with self.assertRaises(AttributeError):
252+
_ = event.tournament
253+
254+
event = (
255+
await Event.filter(tournament__name="Tournament A")
256+
.only("name", "tournament__name")
257+
.first()
258+
)
259+
self.assertEqual(event.name, "Event 1")
260+
self.assertEqual(event.tournament.name, "Tournament A")
261+
262+
async def test_join_in_order_by(self):
263+
events = await Event.all().order_by("name").only("name")
264+
self.assertEqual(events[0].name, "Event 1")
265+
with self.assertRaises(AttributeError):
266+
_ = events[0].tournament
267+
268+
events = await Event.all().order_by("tournament__name", "name").only("name")
269+
self.assertEqual(events[0].name, "Event 1")
270+
with self.assertRaises(AttributeError):
271+
_ = events[0].tournament
272+
273+
events = (
274+
await Event.all().order_by("tournament__name", "name").only("name", "tournament__name")
275+
)
276+
self.assertEqual(events[0].name, "Event 1")
277+
self.assertEqual(events[0].tournament.name, "Tournament A")
278+
279+
async def test_select_related(self):
280+
"""Test .only() with .select_related() for basic functionality"""
281+
event = (
282+
await Event.filter(name="Event 1")
283+
.select_related("tournament")
284+
.only("name", "tournament__name")
285+
.first()
286+
)
287+
288+
self.assertEqual(event.name, "Event 1")
289+
self.assertEqual(event.tournament.name, "Tournament A")
290+
291+
with self.assertRaises(AttributeError):
292+
_ = event.id
293+
with self.assertRaises(AttributeError):
294+
_ = event.tournament.id

tortoise/backends/base/executor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,15 @@ async def execute_select(
114114
for model, index, *__, full_path in self.select_related_idx[1:]:
115115
(*path, attr) = full_path
116116
related_items = row_items[current_idx : current_idx + index]
117-
if not any((v for _, v in related_items)):
118-
obj = None
119-
else:
117+
if any(v for _, v in related_items):
120118
obj = model._init_from_db(**{k.split(".")[1]: v for k, v in related_items})
119+
elif index == 0:
120+
# 0 signals that an empty "filler" object should be created in the case
121+
# where a field of related model is selected but model itself isn't,
122+
# e.g. .only("relatedmodel__field")
123+
obj = model._init_from_db()
124+
else:
125+
obj = None
121126
target = instances.get(tuple(path))
122127
if target is not None:
123128
setattr(target, f"_{attr}", obj)

tortoise/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,9 @@ async def refresh_from_db(
10661066
if not self._saved_in_db:
10671067
raise OperationalError("Can't refresh unpersisted record")
10681068
db = using_db or self._choose_db()
1069-
qs = QuerySet(self.__class__).using_db(db).only(*(fields or []))
1069+
qs = QuerySet(self.__class__).using_db(db)
1070+
if fields:
1071+
qs = qs.only(*fields)
10701072
obj = await qs.get(pk=self.pk)
10711073

10721074
for field in fields or self._meta.db_fields:

0 commit comments

Comments
 (0)