Skip to content

Commit 5aae984

Browse files
authored
Implement __contains, __contained_by, __overlap and __len for ArrayField (#1877)
* Implement __contains filter for ArrayField * Implement __contained_by filter for ArrayField * Implement __overlap filter for ArrayField * Implement __len filter for ArrayField
1 parent 0898c2b commit 5aae984

File tree

9 files changed

+294
-49
lines changed

9 files changed

+294
-49
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@ Changelog
1111

1212
0.24.1 (unreleased)
1313
------
14+
Added
15+
^^^^^
16+
- Implement __contains, __contained_by, __overlap and __len for ArrayField (#1877)
17+
1418
Fixed
1519
^^^^^
1620
- Fix update pk field raises unfriendly error (#1873)
17-
- Fixed asyncio "no current event loop" deprecation warning by replacing `asyncio.get_event_loop()` with modern event loop handling using `get_running_loop()` with fallback to `new_event_loop()` (#1865)
1821

1922
Changed
2023
^^^^^^^
21-
- add benchmarks for `get_for_dialect` (#1862)
2224

2325
0.24.0
2426
------

docs/query.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ PostgreSQL and SQLite also support ``iposix_regex``, which makes case insensive
285285
obj = await DemoModel.filter(demo_text__iposix_regex="^hello world$").first()
286286
287287
288-
In PostgreSQL, ``filter`` supports additional lookup types:
288+
With PostgreSQL, for ``JSONField``, ``filter`` supports additional lookup types:
289289

290290
- ``in`` - ``await JSONModel.filter(data__filter={"breed__in": ["labrador", "poodle"]}).first()``
291291
- ``not_in``
@@ -301,6 +301,13 @@ In PostgreSQL, ``filter`` supports additional lookup types:
301301
- ``istartswith``
302302
- ``iendswith``
303303

304+
With PostgreSQL, ``ArrayField`` can be used with the following lookup types:
305+
306+
- ``contains`` - ``await ArrayFields.filter(array__contains=[1, 2, 3]).first()`` which will use the ``@>`` operator
307+
- ``contained_by`` - will use the ``<@`` operator
308+
- ``overlap`` - will use the ``&&`` operator
309+
- ``len`` - will use the ``array_length`` function, e.g. ``await ArrayFields.filter(array__len=3).first()``
310+
304311

305312
Complex prefetch
306313
================

tests/fields/test_array.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,105 @@ async def test_values_list(self):
4242
obj0 = await testmodels.ArrayFields.create(array=[0])
4343
values = await testmodels.ArrayFields.get(id=obj0.id).values_list("array", flat=True)
4444
self.assertEqual(values, [0])
45+
46+
async def test_eq_filter(self):
47+
obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3])
48+
obj2 = await testmodels.ArrayFields.create(array=[1, 2])
49+
50+
found = await testmodels.ArrayFields.filter(array=[1, 2, 3]).first()
51+
self.assertEqual(found, obj1)
52+
53+
found = await testmodels.ArrayFields.filter(array=[1, 2]).first()
54+
self.assertEqual(found, obj2)
55+
56+
async def test_not_filter(self):
57+
await testmodels.ArrayFields.create(array=[1, 2, 3])
58+
obj2 = await testmodels.ArrayFields.create(array=[1, 2])
59+
60+
found = await testmodels.ArrayFields.filter(array__not=[1, 2, 3]).first()
61+
self.assertEqual(found, obj2)
62+
63+
async def test_contains_ints(self):
64+
obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3])
65+
obj2 = await testmodels.ArrayFields.create(array=[2, 3])
66+
await testmodels.ArrayFields.create(array=[4, 5, 6])
67+
68+
found = await testmodels.ArrayFields.filter(array__contains=[2])
69+
self.assertEqual(found, [obj1, obj2])
70+
71+
found = await testmodels.ArrayFields.filter(array__contains=[10])
72+
self.assertEqual(found, [])
73+
74+
async def test_contains_smallints(self):
75+
obj1 = await testmodels.ArrayFields.create(array=[], array_smallint=[1, 2, 3])
76+
77+
found = await testmodels.ArrayFields.filter(array_smallint__contains=[2]).first()
78+
self.assertEqual(found, obj1)
79+
80+
async def test_contains_strs(self):
81+
obj1 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[])
82+
83+
found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c"])
84+
self.assertEqual(found, [obj1])
85+
86+
found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b"])
87+
self.assertEqual(found, [obj1])
88+
89+
found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c", "d"])
90+
self.assertEqual(found, [])
91+
92+
async def test_contained_by_ints(self):
93+
obj1 = await testmodels.ArrayFields.create(array=[1])
94+
obj2 = await testmodels.ArrayFields.create(array=[1, 2])
95+
obj3 = await testmodels.ArrayFields.create(array=[1, 2, 3])
96+
97+
found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2, 3])
98+
self.assertEqual(found, [obj1, obj2, obj3])
99+
100+
found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2])
101+
self.assertEqual(found, [obj1, obj2])
102+
103+
found = await testmodels.ArrayFields.filter(array__contained_by=[1])
104+
self.assertEqual(found, [obj1])
105+
106+
async def test_contained_by_strs(self):
107+
obj1 = await testmodels.ArrayFields.create(array_str=["a"], array=[])
108+
obj2 = await testmodels.ArrayFields.create(array_str=["a", "b"], array=[])
109+
obj3 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[])
110+
111+
found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b", "c", "d"])
112+
self.assertEqual(found, [obj1, obj2, obj3])
113+
114+
found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b"])
115+
self.assertEqual(found, [obj1, obj2])
116+
117+
found = await testmodels.ArrayFields.filter(array_str__contained_by=["x", "y", "z"])
118+
self.assertEqual(found, [])
119+
120+
async def test_overlap_ints(self):
121+
obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3])
122+
obj2 = await testmodels.ArrayFields.create(array=[2, 3, 4])
123+
obj3 = await testmodels.ArrayFields.create(array=[3, 4, 5])
124+
125+
found = await testmodels.ArrayFields.filter(array__overlap=[1, 2])
126+
self.assertEqual(found, [obj1, obj2])
127+
128+
found = await testmodels.ArrayFields.filter(array__overlap=[4])
129+
self.assertEqual(found, [obj2, obj3])
130+
131+
found = await testmodels.ArrayFields.filter(array__overlap=[1, 2, 3, 4, 5])
132+
self.assertEqual(found, [obj1, obj2, obj3])
133+
134+
async def test_array_length(self):
135+
await testmodels.ArrayFields.create(array=[1, 2, 3])
136+
await testmodels.ArrayFields.create(array=[1])
137+
await testmodels.ArrayFields.create(array=[1, 2])
138+
139+
found = await testmodels.ArrayFields.filter(array__len=3).values_list("array", flat=True)
140+
self.assertEqual(list(found), [[1, 2, 3]])
141+
142+
found = await testmodels.ArrayFields.filter(array__len=1).values_list("array", flat=True)
143+
self.assertEqual(list(found), [[1]])
144+
145+
found = await testmodels.ArrayFields.filter(array__len=0).values_list("array", flat=True)
146+
self.assertEqual(list(found), [])

tests/testmodels_postgres.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ class ArrayFields(Model):
66
id = fields.IntField(primary_key=True)
77
array = ArrayField()
88
array_null = ArrayField(null=True)
9+
array_str = ArrayField(element_type="varchar(1)", null=True)
10+
array_smallint = ArrayField(element_type="smallint", null=True)

tortoise/backends/base_postgres/executor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
from tortoise import Model
99
from tortoise.backends.base.executor import BaseExecutor
10+
from tortoise.contrib.postgres.array_functions import (
11+
postgres_array_contained_by,
12+
postgres_array_contains,
13+
postgres_array_length,
14+
postgres_array_overlap,
15+
)
1016
from tortoise.contrib.postgres.json_functions import (
1117
postgres_json_contained_by,
1218
postgres_json_contains,
@@ -18,6 +24,10 @@
1824
)
1925
from tortoise.contrib.postgres.search import SearchCriterion
2026
from tortoise.filters import (
27+
array_contained_by,
28+
array_contains,
29+
array_length,
30+
array_overlap,
2131
insensitive_posix_regex,
2232
json_contained_by,
2333
json_contains,
@@ -36,11 +46,15 @@ class BasePostgresExecutor(BaseExecutor):
3646
DB_NATIVE = BaseExecutor.DB_NATIVE | {bool, uuid.UUID}
3747
FILTER_FUNC_OVERRIDE = {
3848
search: postgres_search,
49+
array_contains: postgres_array_contains,
50+
array_contained_by: postgres_array_contained_by,
51+
array_overlap: postgres_array_overlap,
3952
json_contains: postgres_json_contains,
4053
json_contained_by: postgres_json_contained_by,
4154
json_filter: postgres_json_filter,
4255
posix_regex: postgres_posix_regex,
4356
insensitive_posix_regex: postgres_insensitive_posix_regex,
57+
array_length: postgres_array_length,
4458
}
4559

4660
def _prepare_insert_statement(
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from enum import Enum
2+
3+
from pypika_tortoise.terms import BasicCriterion, Criterion, Function, Term
4+
5+
6+
class PostgresArrayOperators(str, Enum):
7+
CONTAINS = "@>"
8+
CONTAINED_BY = "<@"
9+
OVERLAP = "&&"
10+
11+
12+
# The value in the functions below is casted to the exact type of the field with value_encoder
13+
# to avoid issues with psycopg that tries to use the smallest possible type which can lead to errors,
14+
# e.g. {1,2} will be casted to smallint[] instead of integer[].
15+
16+
17+
def postgres_array_contains(field: Term, value: Term) -> Criterion:
18+
return BasicCriterion(PostgresArrayOperators.CONTAINS, field, value)
19+
20+
21+
def postgres_array_contained_by(field: Term, value: Term) -> Criterion:
22+
return BasicCriterion(PostgresArrayOperators.CONTAINED_BY, field, value)
23+
24+
25+
def postgres_array_overlap(field: Term, value: Term) -> Criterion:
26+
return BasicCriterion(PostgresArrayOperators.OVERLAP, field, value)
27+
28+
29+
def postgres_array_length(field: Term, value: int) -> Criterion:
30+
"""Returns a criterion that checks if array length equals the given value"""
31+
return Function("array_length", field, 1).eq(value)

tortoise/expressions.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -355,33 +355,30 @@ def _process_filter_kwarg(
355355
join = None
356356

357357
if value is None and f"{key}__isnull" in model._meta.filters:
358-
param = model._meta.get_filter(f"{key}__isnull")
358+
filter_info = model._meta.get_filter(f"{key}__isnull")
359359
value = True
360360
else:
361-
param = model._meta.get_filter(key)
361+
filter_info = model._meta.get_filter(key)
362362

363-
pk_db_field = model._meta.db_pk_column
364-
if param.get("table"):
363+
if "table" in filter_info:
364+
# join the table
365365
join = (
366-
param["table"],
367-
table[pk_db_field] == param["table"][param["backward_key"]],
366+
filter_info["table"],
367+
table[model._meta.db_pk_column]
368+
== filter_info["table"][filter_info["backward_key"]],
368369
)
369-
if param.get("value_encoder"):
370-
value = param["value_encoder"](value, model)
371-
op = param["operator"]
372-
criterion = op(param["table"][param["field"]], value)
373-
else:
374-
if isinstance(value, Term):
375-
encoded_value = value
376-
else:
377-
field_object = model._meta.fields_map[param["field"]]
378-
encoded_value = (
379-
param["value_encoder"](value, model, field_object)
380-
if param.get("value_encoder")
381-
else field_object.to_db_value(value, model)
382-
)
383-
op = param["operator"]
384-
criterion = op(table[param["source_field"]], encoded_value)
370+
if "value_encoder" in filter_info:
371+
value = filter_info["value_encoder"](value, model)
372+
table = filter_info["table"]
373+
elif not isinstance(value, Term):
374+
field_object = model._meta.fields_map[filter_info["field"]]
375+
value = (
376+
filter_info["value_encoder"](value, model, field_object)
377+
if "value_encoder" in filter_info
378+
else field_object.to_db_value(value, model)
379+
)
380+
op = filter_info["operator"]
381+
criterion = op(table[filter_info.get("source_field", filter_info["field"])], value)
385382
return criterion, join
386383

387384
def _resolve_regular_kwarg(

tortoise/fields/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,13 @@ def _get_dialects(self) -> dict[str, dict]:
317317

318318
return ret
319319

320+
def get_db_field_type(self) -> str:
321+
"""
322+
Returns the DB field type for this field for the current dialect.
323+
"""
324+
dialect = self.model._meta.db.capabilities.dialect
325+
return self.get_for_dialect(dialect, "SQL_TYPE")
326+
320327
def get_db_field_types(self) -> Optional[dict[str, str]]:
321328
"""
322329
Returns the DB types for this field.

0 commit comments

Comments
 (0)