Skip to content

Commit 7e0aa2c

Browse files
committed
Add tests for polymorphic foreign key optimization
1 parent 47147b7 commit 7e0aa2c

File tree

4 files changed

+119
-4
lines changed

4 files changed

+119
-4
lines changed

tests/polymorphism/test_optimizer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_polymorphic_relation():
193193
"""
194194

195195
# Company, ContentType, base table, two subtables = 5 queries
196-
with assert_num_queries(1):
196+
with assert_num_queries(5):
197197
result = schema.execute_sync(query)
198198
assert not result.errors
199199
assert result.data == {
@@ -218,7 +218,6 @@ def test_polymorphic_relation():
218218
}
219219

220220

221-
222221
@pytest.mark.django_db(transaction=True)
223222
def test_polymorphic_nested_list():
224223
company = Company.objects.create(name="Company")

tests/polymorphism_custom/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
from django.db import models
22

33

4+
class Company(models.Model):
5+
name = models.CharField(max_length=100)
6+
main_project = models.ForeignKey('CustomPolyProject', null=True, blank=True, on_delete=models.CASCADE)
7+
8+
class Meta:
9+
ordering = ('name',)
10+
11+
412
class CustomPolyProject(models.Model):
13+
company = models.ForeignKey(Company, null=True, blank=True, on_delete=models.CASCADE, related_name='projects')
514
topic = models.CharField(max_length=30)
615
artist = models.CharField(max_length=30, blank=True)
716
supervisor = models.CharField(max_length=30, blank=True)

tests/polymorphism_custom/schema.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from strawberry_django.pagination import OffsetPaginated
99
from strawberry_django.relay import ListConnectionWithTotalCount
1010

11-
from .models import CustomPolyProject
11+
from .models import CustomPolyProject, Company
1212

1313

1414
@strawberry_django.interface(CustomPolyProject)
@@ -37,8 +37,16 @@ class ResearchProjectType(ProjectType):
3737
supervisor: strawberry.auto
3838

3939

40+
@strawberry_django.type(Company)
41+
class CompanyType:
42+
name: strawberry.auto
43+
main_project: ProjectType | None
44+
projects: list[ProjectType]
45+
46+
4047
@strawberry.type
4148
class Query:
49+
companies: list[CompanyType] = strawberry_django.field()
4250
projects: list[ProjectType] = strawberry_django.field()
4351
projects_paginated: list[ProjectType] = strawberry_django.field(pagination=True)
4452
projects_offset_paginated: OffsetPaginated[ProjectType] = (

tests/polymorphism_custom/test_optimizer.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from tests.utils import assert_num_queries
66

7-
from .models import CustomPolyProject
7+
from .models import CustomPolyProject, Company
88
from .schema import schema
99

1010

@@ -209,3 +209,102 @@ def test_polymorphic_interface_connection():
209209
],
210210
}
211211
}
212+
213+
214+
@pytest.mark.django_db(transaction=True)
215+
def test_polymorphic_relation():
216+
ap = CustomPolyProject.objects.create(topic="Art", artist="Artist")
217+
art_company = Company.objects.create(name="ArtCompany", main_project=ap)
218+
219+
rp = CustomPolyProject.objects.create(topic="Research", supervisor="Supervisor")
220+
research_company = Company.objects.create(name="ResearchCompany", main_project=rp)
221+
222+
query = """\
223+
query {
224+
companies {
225+
name
226+
mainProject {
227+
__typename
228+
topic
229+
... on ArtProjectType {
230+
artist
231+
}
232+
... on ResearchProjectType {
233+
supervisor
234+
}
235+
}
236+
}
237+
}
238+
"""
239+
240+
with assert_num_queries(2):
241+
result = schema.execute_sync(query)
242+
assert not result.errors
243+
assert result.data == {
244+
"companies": [
245+
{
246+
"name": art_company.name,
247+
"mainProject": {
248+
"__typename": "ArtProjectType",
249+
"topic": ap.topic,
250+
"artist": ap.artist,
251+
},
252+
},
253+
{
254+
"name": research_company.name,
255+
"mainProject": {
256+
"__typename": "ResearchProjectType",
257+
"topic": rp.topic,
258+
"supervisor": rp.supervisor,
259+
},
260+
}
261+
]
262+
}
263+
264+
265+
@pytest.mark.django_db(transaction=True)
266+
def test_polymorphic_nested_list():
267+
company = Company.objects.create(name="Company")
268+
ap = CustomPolyProject.objects.create(company=company, topic="Art", artist="Artist")
269+
rp = CustomPolyProject.objects.create(company=company, topic="Research", supervisor="Supervisor")
270+
271+
query = """\
272+
query {
273+
companies {
274+
name
275+
projects {
276+
__typename
277+
topic
278+
... on ArtProjectType {
279+
artist
280+
}
281+
... on ResearchProjectType {
282+
supervisor
283+
}
284+
}
285+
}
286+
}
287+
"""
288+
289+
with assert_num_queries(2):
290+
result = schema.execute_sync(query)
291+
assert not result.errors
292+
assert result.data == {
293+
"companies": [
294+
{
295+
"name": "Company",
296+
"projects": [
297+
{
298+
"__typename": "ArtProjectType",
299+
"topic": ap.topic,
300+
"artist": ap.artist,
301+
},
302+
{
303+
"__typename": "ResearchProjectType",
304+
"topic": rp.topic,
305+
"supervisor": rp.supervisor,
306+
},
307+
],
308+
}
309+
]
310+
}

0 commit comments

Comments
 (0)