Skip to content

Commit 7f077c1

Browse files
authored
Parametrize SELECT queries (#1777)
* Parametrize SELECT queries * Make sure _execute() uses the same query as returned from sql() * Parametrize .values, .values_list, .exists and .count queries * Fix Postgres issues * Add params_inline arg to QuerySet.sql() * Use pypika-tortoise 0.3.0
1 parent 916d6cb commit 7f077c1

File tree

23 files changed

+492
-231
lines changed

23 files changed

+492
-231
lines changed

poetry.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ classifiers = [
3636

3737
[tool.poetry.dependencies]
3838
python = "^3.8"
39-
pypika-tortoise = "^0.2.2"
39+
pypika-tortoise = "^0.3.0"
4040
iso8601 = "^2.1.0"
4141
aiosqlite = ">=0.16.0, <0.21.0"
4242
pytz = "*"

tests/contrib/test_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ async def test_mysql_func_rand(self):
2121
@test.requireCapability(dialect="mysql")
2222
async def test_mysql_func_rand_with_seed(self):
2323
sql = IntFields.all().annotate(randnum=Rand(0)).values("intnum", "randnum").sql()
24-
expected_sql = "SELECT `intnum` `intnum`,RAND(0) `randnum` FROM `intfields`"
24+
expected_sql = "SELECT `intnum` `intnum`,RAND(%s) `randnum` FROM `intfields`"
2525
self.assertEqual(sql, expected_sql)
2626

2727
@test.requireCapability(dialect="postgres")

tests/test_case_when.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ async def asyncSetUp(self):
1414

1515
async def test_single_when(self):
1616
category = Case(When(intnum__gte=8, then="big"), default="default")
17-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
17+
sql = (
18+
IntFields.all()
19+
.annotate(category=category)
20+
.values("intnum", "category")
21+
.sql(params_inline=True)
22+
)
1823

1924
dialect = self.db.schema_generator.DIALECT
2025
if dialect == "mysql":
@@ -27,7 +32,12 @@ async def test_multi_when(self):
2732
category = Case(
2833
When(intnum__gte=8, then="big"), When(intnum__lte=2, then="small"), default="default"
2934
)
30-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
35+
sql = (
36+
IntFields.all()
37+
.annotate(category=category)
38+
.values("intnum", "category")
39+
.sql(params_inline=True)
40+
)
3141

3242
dialect = self.db.schema_generator.DIALECT
3343
if dialect == "mysql":
@@ -38,7 +48,12 @@ async def test_multi_when(self):
3848

3949
async def test_q_object_when(self):
4050
category = Case(When(Q(intnum__gt=2, intnum__lt=8), then="middle"), default="default")
41-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
51+
sql = (
52+
IntFields.all()
53+
.annotate(category=category)
54+
.values("intnum", "category")
55+
.sql(params_inline=True)
56+
)
4257

4358
dialect = self.db.schema_generator.DIALECT
4459
if dialect == "mysql":
@@ -49,7 +64,12 @@ async def test_q_object_when(self):
4964

5065
async def test_F_then(self):
5166
category = Case(When(intnum__gte=8, then=F("intnum_null")), default="default")
52-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
67+
sql = (
68+
IntFields.all()
69+
.annotate(category=category)
70+
.values("intnum", "category")
71+
.sql(params_inline=True)
72+
)
5373

5474
dialect = self.db.schema_generator.DIALECT
5575
if dialect == "mysql":
@@ -61,7 +81,12 @@ async def test_F_then(self):
6181
async def test_AE_then(self):
6282
# AE: ArithmeticExpression
6383
category = Case(When(intnum__gte=8, then=F("intnum") + 1), default="default")
64-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
84+
sql = (
85+
IntFields.all()
86+
.annotate(category=category)
87+
.values("intnum", "category")
88+
.sql(params_inline=True)
89+
)
6590

6691
dialect = self.db.schema_generator.DIALECT
6792
if dialect == "mysql":
@@ -72,7 +97,12 @@ async def test_AE_then(self):
7297

7398
async def test_func_then(self):
7499
category = Case(When(intnum__gte=8, then=Coalesce("intnum_null", 10)), default="default")
75-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
100+
sql = (
101+
IntFields.all()
102+
.annotate(category=category)
103+
.values("intnum", "category")
104+
.sql(params_inline=True)
105+
)
76106

77107
dialect = self.db.schema_generator.DIALECT
78108
if dialect == "mysql":
@@ -83,7 +113,12 @@ async def test_func_then(self):
83113

84114
async def test_F_default(self):
85115
category = Case(When(intnum__gte=8, then="big"), default=F("intnum_null"))
86-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
116+
sql = (
117+
IntFields.all()
118+
.annotate(category=category)
119+
.values("intnum", "category")
120+
.sql(params_inline=True)
121+
)
87122

88123
dialect = self.db.schema_generator.DIALECT
89124
if dialect == "mysql":
@@ -95,7 +130,12 @@ async def test_F_default(self):
95130
async def test_AE_default(self):
96131
# AE: ArithmeticExpression
97132
category = Case(When(intnum__gte=8, then=8), default=F("intnum") + 1)
98-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
133+
sql = (
134+
IntFields.all()
135+
.annotate(category=category)
136+
.values("intnum", "category")
137+
.sql(params_inline=True)
138+
)
99139

100140
dialect = self.db.schema_generator.DIALECT
101141
if dialect == "mysql":
@@ -106,7 +146,12 @@ async def test_AE_default(self):
106146

107147
async def test_func_default(self):
108148
category = Case(When(intnum__gte=8, then=8), default=Coalesce("intnum_null", 10))
109-
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
149+
sql = (
150+
IntFields.all()
151+
.annotate(category=category)
152+
.values("intnum", "category")
153+
.sql(params_inline=True)
154+
)
110155

111156
dialect = self.db.schema_generator.DIALECT
112157
if dialect == "mysql":
@@ -124,7 +169,7 @@ async def test_case_when_in_where(self):
124169
.annotate(category=category)
125170
.filter(category__in=["big", "small"])
126171
.values("intnum")
127-
.sql()
172+
.sql(params_inline=True)
128173
)
129174
dialect = self.db.schema_generator.DIALECT
130175
if dialect == "mysql":
@@ -139,7 +184,7 @@ async def test_annotation_in_when_annotation(self):
139184
.annotate(intnum_plus_1=F("intnum") + 1)
140185
.annotate(bigger_than_10=Case(When(Q(intnum_plus_1__gte=10), then=True), default=False))
141186
.values("id", "intnum", "intnum_plus_1", "bigger_than_10")
142-
.sql()
187+
.sql(params_inline=True)
143188
)
144189

145190
dialect = self.db.schema_generator.DIALECT
@@ -155,7 +200,7 @@ async def test_func_annotation_in_when_annotation(self):
155200
.annotate(intnum_col=Coalesce("intnum", 0))
156201
.annotate(is_zero=Case(When(Q(intnum_col=0), then=True), default=False))
157202
.values("id", "intnum_col", "is_zero")
158-
.sql()
203+
.sql(params_inline=True)
159204
)
160205

161206
dialect = self.db.schema_generator.DIALECT
@@ -172,7 +217,7 @@ async def test_case_when_in_group_by(self):
172217
.annotate(count=Count("id"))
173218
.group_by("is_zero")
174219
.values("is_zero", "count")
175-
.sql()
220+
.sql(params_inline=True)
176221
)
177222

178223
dialect = self.db.schema_generator.DIALECT
@@ -188,4 +233,4 @@ async def test_unknown_field_in_when_annotation(self):
188233
with self.assertRaisesRegex(FieldError, "Unknown filter param 'unknown'.+"):
189234
IntFields.all().annotate(intnum_col=Coalesce("intnum", 0)).annotate(
190235
is_zero=Case(When(Q(unknown=0), then="1"), default="2")
191-
).sql()
236+
).sql(params_inline=True)

tests/test_filters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,14 @@ async def test_between_and(self):
268268
[Decimal("1.2345")],
269269
)
270270

271+
async def test_in(self):
272+
self.assertEqual(
273+
await DecimalFields.filter(
274+
decimal__in=[Decimal("1.2345"), Decimal("1000")]
275+
).values_list("decimal", flat=True),
276+
[Decimal("1.2345")],
277+
)
278+
271279

272280
class TestCharFkFieldFilters(test.TestCase):
273281
async def asyncSetUp(self):

tests/test_fuzz.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from tests.testmodels import CharFields
22
from tortoise.contrib import test
33
from tortoise.contrib.test.condition import NotEQ
4+
from tortoise.functions import Upper
45

56
DODGY_STRINGS = [
67
"a/",
@@ -9,6 +10,11 @@
910
"a\\x39",
1011
"a'",
1112
'"',
13+
'""',
14+
"'",
15+
"''",
16+
"\\_",
17+
"\\\\_",
1218
"‘a",
1319
"a’",
1420
"‘a’",
@@ -134,3 +140,12 @@ async def test_char_fuzz(self):
134140
)
135141
self.assertEqual(obj1.pk, obj5.pk)
136142
self.assertEqual(char, obj5.char)
143+
144+
# Filter by a function
145+
obj6 = (
146+
await CharFields.annotate(upper_char=Upper("char"))
147+
.filter(id=obj1.pk, upper_char=Upper("char"))
148+
.first()
149+
)
150+
self.assertEqual(obj1.pk, obj6.pk)
151+
self.assertEqual(char, obj6.char)

tests/test_model_methods.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,14 +296,14 @@ async def test_index_access(self):
296296

297297
async def test_index_badval(self):
298298
with self.assertRaises(ObjectDoesNotExistError) as cm:
299-
await self.cls[100000]
299+
await self.cls[32767]
300300
the_exception = cm.exception
301301
# For compatibility reasons this should be an instance of KeyError
302302
self.assertIsInstance(the_exception, KeyError)
303303
self.assertIs(the_exception.model, self.cls)
304304
self.assertEqual(the_exception.pk_name, "id")
305-
self.assertEqual(the_exception.pk_val, 100000)
306-
self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=100000")
305+
self.assertEqual(the_exception.pk_val, 32767)
306+
self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=32767")
307307

308308
async def test_index_badtype(self):
309309
with self.assertRaises(ObjectDoesNotExistError) as cm:

0 commit comments

Comments
 (0)