|
1 |
| -from tests.testmodels import SourceFields, StraightFields |
| 1 | +from tests.testmodels import DoubleFK, Event, SourceFields, StraightFields, Tournament |
2 | 2 | from tortoise.contrib import test
|
3 |
| -from tortoise.exceptions import IncompleteInstanceError |
| 3 | +from tortoise.functions import Count |
| 4 | +from tortoise.exceptions import FieldError, IncompleteInstanceError |
4 | 5 |
|
5 | 6 |
|
6 | 7 | class TestOnlyStraight(test.TestCase):
|
@@ -64,3 +65,230 @@ async def asyncSetUp(self) -> None:
|
64 | 65 | await super().asyncSetUp()
|
65 | 66 | self.model = SourceFields # type: ignore
|
66 | 67 | self.instance = await self.model.create(chars="Test")
|
| 68 | + |
| 69 | + |
| 70 | +class TestOnlyRecursive(test.TestCase): |
| 71 | + async def test_one_level(self): |
| 72 | + left_1st_lvl = await DoubleFK.create(name="1st") |
| 73 | + root = await DoubleFK.create(name="root", left=left_1st_lvl) |
| 74 | + |
| 75 | + ret = ( |
| 76 | + await DoubleFK.filter(pk=root.pk).only("name", "left__name", "left__left__name").first() |
| 77 | + ) |
| 78 | + self.assertIsNotNone(ret) |
| 79 | + with self.assertRaises(AttributeError): |
| 80 | + _ = ret.id |
| 81 | + self.assertEqual(ret.name, "root") |
| 82 | + self.assertEqual(ret.left.name, "1st") |
| 83 | + with self.assertRaises(AttributeError): |
| 84 | + _ = ret.left.id |
| 85 | + with self.assertRaises(AttributeError): |
| 86 | + _ = ret.right |
| 87 | + |
| 88 | + async def test_two_levels(self): |
| 89 | + left_2nd_lvl = await DoubleFK.create(name="second leaf") |
| 90 | + left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl) |
| 91 | + root = await DoubleFK.create(name="root", left=left_1st_lvl) |
| 92 | + |
| 93 | + ret = ( |
| 94 | + await DoubleFK.filter(pk=root.pk).only("name", "left__name", "left__left__name").first() |
| 95 | + ) |
| 96 | + self.assertIsNotNone(ret) |
| 97 | + with self.assertRaises(AttributeError): |
| 98 | + _ = ret.id |
| 99 | + self.assertEqual(ret.name, "root") |
| 100 | + self.assertEqual(ret.left.name, "1st") |
| 101 | + with self.assertRaises(AttributeError): |
| 102 | + _ = ret.left.id |
| 103 | + self.assertEqual(ret.left.left.name, "second leaf") |
| 104 | + |
| 105 | + async def test_two_levels_reverse_argument_order(self): |
| 106 | + left_2nd_lvl = await DoubleFK.create(name="second leaf") |
| 107 | + left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl) |
| 108 | + root = await DoubleFK.create(name="root", left=left_1st_lvl) |
| 109 | + |
| 110 | + ret = ( |
| 111 | + await DoubleFK.filter(pk=root.pk).only("left__left__name", "left__name", "name").first() |
| 112 | + ) |
| 113 | + self.assertIsNotNone(ret) |
| 114 | + with self.assertRaises(AttributeError): |
| 115 | + _ = ret.id |
| 116 | + self.assertEqual(ret.name, "root") |
| 117 | + self.assertEqual(ret.left.name, "1st") |
| 118 | + with self.assertRaises(AttributeError): |
| 119 | + _ = ret.left.id |
| 120 | + self.assertEqual(ret.left.left.name, "second leaf") |
| 121 | + |
| 122 | + |
| 123 | +class TestOnlyRelated(test.TestCase): |
| 124 | + async def test_related_one_level(self): |
| 125 | + tournament = await Tournament.create(name="New Tournament", desc="New Description") |
| 126 | + await Event.create(name="Event 1", tournament=tournament) |
| 127 | + await Event.create(name="Event 2", tournament=tournament) |
| 128 | + |
| 129 | + ret = ( |
| 130 | + await Event.filter(tournament=tournament) |
| 131 | + .only("name", "tournament__name") |
| 132 | + .order_by("name") |
| 133 | + ) |
| 134 | + self.assertEqual(len(ret), 2) |
| 135 | + self.assertEqual(ret[0].name, "Event 1") |
| 136 | + with self.assertRaises(AttributeError): |
| 137 | + _ = ret[0].alias |
| 138 | + self.assertEqual(ret[1].name, "Event 2") |
| 139 | + with self.assertRaises(AttributeError): |
| 140 | + _ = ret[1].alias |
| 141 | + self.assertEqual(ret[0].tournament.name, "New Tournament") |
| 142 | + with self.assertRaises(AttributeError): |
| 143 | + _ = ret[0].tournament.id |
| 144 | + with self.assertRaises(AttributeError): |
| 145 | + _ = ret[0].tournament.desc |
| 146 | + |
| 147 | + async def test_related_one_level_reversed_argument_order(self): |
| 148 | + tournament = await Tournament.create(name="New Tournament", desc="New Description") |
| 149 | + await Event.create(name="Event 1", tournament=tournament) |
| 150 | + await Event.create(name="Event 2", tournament=tournament) |
| 151 | + |
| 152 | + ret = ( |
| 153 | + await Event.filter(tournament=tournament) |
| 154 | + .only("tournament__name", "name") |
| 155 | + .order_by("name") |
| 156 | + ) |
| 157 | + self.assertEqual(len(ret), 2) |
| 158 | + self.assertEqual(ret[0].name, "Event 1") |
| 159 | + self.assertEqual(ret[0].tournament.name, "New Tournament") |
| 160 | + |
| 161 | + async def test_just_related(self): |
| 162 | + tournament = await Tournament.create(name="New Tournament", desc="New Description") |
| 163 | + await Event.create(name="Event 1", tournament=tournament) |
| 164 | + await Event.create(name="Event 2", tournament=tournament) |
| 165 | + |
| 166 | + ret = ( |
| 167 | + await Event.filter(tournament=tournament) |
| 168 | + .only("tournament__name") |
| 169 | + .order_by("name") |
| 170 | + .all() |
| 171 | + ) |
| 172 | + self.assertEqual(len(ret), 2) |
| 173 | + self.assertEqual(ret[0].tournament.name, "New Tournament") |
| 174 | + self.assertEqual(ret[1].tournament.name, "New Tournament") |
| 175 | + |
| 176 | + |
| 177 | +class TestOnlyAdvanced(test.TestCase): |
| 178 | + async def asyncSetUp(self) -> None: |
| 179 | + await super().asyncSetUp() |
| 180 | + self.tournament = await Tournament.create(name="Tournament A", desc="Description A") |
| 181 | + self.event1 = await Event.create(name="Event 1", tournament=self.tournament) |
| 182 | + self.event2 = await Event.create(name="Event 2", tournament=self.tournament) |
| 183 | + |
| 184 | + async def test_exclude(self): |
| 185 | + """Test .only() combined with .exclude()""" |
| 186 | + events = await Event.filter(tournament=self.tournament).exclude(name="Event 2").only("name") |
| 187 | + self.assertEqual(len(events), 1) |
| 188 | + self.assertEqual(events[0].name, "Event 1") |
| 189 | + with self.assertRaises(AttributeError): |
| 190 | + _ = events[0].modified |
| 191 | + |
| 192 | + async def test_limit(self): |
| 193 | + """Test .only() combined with .limit()""" |
| 194 | + events = await Event.all().only("name").limit(1) |
| 195 | + self.assertEqual(len(events), 1) |
| 196 | + self.assertEqual(events[0].name, "Event 1") # Assumes ordering by PK |
| 197 | + with self.assertRaises(AttributeError): |
| 198 | + _ = events[0].modified |
| 199 | + |
| 200 | + async def test_distinct(self): |
| 201 | + """Test .only() combined with .distinct()""" |
| 202 | + # Create duplicate event names |
| 203 | + await Event.create(name="Event 1", tournament=self.tournament) |
| 204 | + |
| 205 | + events = await Event.all().only("name").distinct() |
| 206 | + # Should only have 2 distinct event names |
| 207 | + self.assertEqual(len(events), 2) |
| 208 | + event_names = {e.name for e in events} |
| 209 | + self.assertEqual(event_names, {"Event 1", "Event 2"}) |
| 210 | + |
| 211 | + async def test_values(self): |
| 212 | + """Test .only() combined with .values()""" |
| 213 | + with self.assertRaises(ValueError, msg="values() cannot be used with .only()"): |
| 214 | + await Event.all().only("name").values("name") |
| 215 | + |
| 216 | + async def test_pk_field(self): |
| 217 | + """Test .only() with just the primary key field""" |
| 218 | + tournament = await Tournament.first().only("id") |
| 219 | + self.assertIsNotNone(tournament.id) |
| 220 | + with self.assertRaises(AttributeError): |
| 221 | + _ = tournament.name |
| 222 | + |
| 223 | + async def test_empty(self): |
| 224 | + """Test .only() with no fields (should raise an error)""" |
| 225 | + with self.assertRaises(ValueError): |
| 226 | + await Event.all().only() |
| 227 | + |
| 228 | + async def test_annotate(self): |
| 229 | + tournaments = await Tournament.annotate(event_count=Count("events")).only( |
| 230 | + "name", "event_count" |
| 231 | + ) |
| 232 | + |
| 233 | + self.assertEqual(tournaments[0].name, "Tournament A") |
| 234 | + self.assertEqual(tournaments[0].event_count, 2) |
| 235 | + with self.assertRaises(AttributeError): |
| 236 | + _ = tournaments[0].desc |
| 237 | + |
| 238 | + async def test_nonexistent_field(self): |
| 239 | + """Test .only() with a field that doesn't exist""" |
| 240 | + with self.assertRaises(FieldError): |
| 241 | + await Event.all().only("nonexistent_field").all() |
| 242 | + |
| 243 | + async def test_join_in_filter(self): |
| 244 | + event = await Event.filter(name="Event 1").only("name").first() |
| 245 | + self.assertEqual(event.name, "Event 1") |
| 246 | + with self.assertRaises(AttributeError): |
| 247 | + _ = event.tournament |
| 248 | + |
| 249 | + event = await Event.filter(tournament__name="Tournament A").only("name").first() |
| 250 | + self.assertEqual(event.name, "Event 1") |
| 251 | + with self.assertRaises(AttributeError): |
| 252 | + _ = event.tournament |
| 253 | + |
| 254 | + event = ( |
| 255 | + await Event.filter(tournament__name="Tournament A") |
| 256 | + .only("name", "tournament__name") |
| 257 | + .first() |
| 258 | + ) |
| 259 | + self.assertEqual(event.name, "Event 1") |
| 260 | + self.assertEqual(event.tournament.name, "Tournament A") |
| 261 | + |
| 262 | + async def test_join_in_order_by(self): |
| 263 | + events = await Event.all().order_by("name").only("name") |
| 264 | + self.assertEqual(events[0].name, "Event 1") |
| 265 | + with self.assertRaises(AttributeError): |
| 266 | + _ = events[0].tournament |
| 267 | + |
| 268 | + events = await Event.all().order_by("tournament__name", "name").only("name") |
| 269 | + self.assertEqual(events[0].name, "Event 1") |
| 270 | + with self.assertRaises(AttributeError): |
| 271 | + _ = events[0].tournament |
| 272 | + |
| 273 | + events = ( |
| 274 | + await Event.all().order_by("tournament__name", "name").only("name", "tournament__name") |
| 275 | + ) |
| 276 | + self.assertEqual(events[0].name, "Event 1") |
| 277 | + self.assertEqual(events[0].tournament.name, "Tournament A") |
| 278 | + |
| 279 | + async def test_select_related(self): |
| 280 | + """Test .only() with .select_related() for basic functionality""" |
| 281 | + event = ( |
| 282 | + await Event.filter(name="Event 1") |
| 283 | + .select_related("tournament") |
| 284 | + .only("name", "tournament__name") |
| 285 | + .first() |
| 286 | + ) |
| 287 | + |
| 288 | + self.assertEqual(event.name, "Event 1") |
| 289 | + self.assertEqual(event.tournament.name, "Tournament A") |
| 290 | + |
| 291 | + with self.assertRaises(AttributeError): |
| 292 | + _ = event.id |
| 293 | + with self.assertRaises(AttributeError): |
| 294 | + _ = event.tournament.id |
0 commit comments