Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ Changelog

0.24.1 (unreleased)
------
Added
^^^^^
- Implement __contains, __contained_by, __overlap and __len for ArrayField (#1877)

Fixed
^^^^^
- Fix update pk field raises unfriendly error (#1873)
- Fixed asyncio "no current event loop" deprecation warning by replacing `asyncio.get_event_loop()` with modern event loop handling using `get_running_loop()` with fallback to `new_event_loop()` (#1865)

Changed
^^^^^^^
- add benchmarks for `get_for_dialect` (#1862)
Copy link
Contributor Author

@henadzit henadzit Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the internal notes that important only for the development process, not for the package users.


0.24.0
------
Expand Down
9 changes: 8 additions & 1 deletion docs/query.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ PostgreSQL and SQLite also support ``iposix_regex``, which makes case insensive
obj = await DemoModel.filter(demo_text__iposix_regex="^hello world$").first()


In PostgreSQL, ``filter`` supports additional lookup types:
With PostgreSQL, for ``JSONField``, ``filter`` supports additional lookup types:

- ``in`` - ``await JSONModel.filter(data__filter={"breed__in": ["labrador", "poodle"]}).first()``
- ``not_in``
Expand All @@ -301,6 +301,13 @@ In PostgreSQL, ``filter`` supports additional lookup types:
- ``istartswith``
- ``iendswith``

With PostgreSQL, ``ArrayField`` can be used with the following lookup types:

- ``contains`` - ``await ArrayFields.filter(array__contains=[1, 2, 3]).first()`` which will use the ``@>`` operator
- ``contained_by`` - will use the ``<@`` operator
- ``overlap`` - will use the ``&&`` operator
- ``len`` - will use the ``array_length`` function, e.g. ``await ArrayFields.filter(array__len=3).first()``


Complex prefetch
================
Expand Down
102 changes: 102 additions & 0 deletions tests/fields/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,105 @@ async def test_values_list(self):
obj0 = await testmodels.ArrayFields.create(array=[0])
values = await testmodels.ArrayFields.get(id=obj0.id).values_list("array", flat=True)
self.assertEqual(values, [0])

async def test_eq_filter(self):
obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3])
obj2 = await testmodels.ArrayFields.create(array=[1, 2])

found = await testmodels.ArrayFields.filter(array=[1, 2, 3]).first()
self.assertEqual(found, obj1)

found = await testmodels.ArrayFields.filter(array=[1, 2]).first()
self.assertEqual(found, obj2)

async def test_not_filter(self):
await testmodels.ArrayFields.create(array=[1, 2, 3])
obj2 = await testmodels.ArrayFields.create(array=[1, 2])

found = await testmodels.ArrayFields.filter(array__not=[1, 2, 3]).first()
self.assertEqual(found, obj2)

async def test_contains_ints(self):
obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3])
obj2 = await testmodels.ArrayFields.create(array=[2, 3])
await testmodels.ArrayFields.create(array=[4, 5, 6])

found = await testmodels.ArrayFields.filter(array__contains=[2])
self.assertEqual(found, [obj1, obj2])

found = await testmodels.ArrayFields.filter(array__contains=[10])
self.assertEqual(found, [])

async def test_contains_smallints(self):
obj1 = await testmodels.ArrayFields.create(array=[], array_smallint=[1, 2, 3])

found = await testmodels.ArrayFields.filter(array_smallint__contains=[2]).first()
self.assertEqual(found, obj1)

async def test_contains_strs(self):
obj1 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[])

found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c"])
self.assertEqual(found, [obj1])

found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b"])
self.assertEqual(found, [obj1])

found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c", "d"])
self.assertEqual(found, [])

async def test_contained_by_ints(self):
obj1 = await testmodels.ArrayFields.create(array=[1])
obj2 = await testmodels.ArrayFields.create(array=[1, 2])
obj3 = await testmodels.ArrayFields.create(array=[1, 2, 3])

found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2, 3])
self.assertEqual(found, [obj1, obj2, obj3])

found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2])
self.assertEqual(found, [obj1, obj2])

found = await testmodels.ArrayFields.filter(array__contained_by=[1])
self.assertEqual(found, [obj1])

async def test_contained_by_strs(self):
obj1 = await testmodels.ArrayFields.create(array_str=["a"], array=[])
obj2 = await testmodels.ArrayFields.create(array_str=["a", "b"], array=[])
obj3 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[])

found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b", "c", "d"])
self.assertEqual(found, [obj1, obj2, obj3])

found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b"])
self.assertEqual(found, [obj1, obj2])

found = await testmodels.ArrayFields.filter(array_str__contained_by=["x", "y", "z"])
self.assertEqual(found, [])

async def test_overlap_ints(self):
obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3])
obj2 = await testmodels.ArrayFields.create(array=[2, 3, 4])
obj3 = await testmodels.ArrayFields.create(array=[3, 4, 5])

found = await testmodels.ArrayFields.filter(array__overlap=[1, 2])
self.assertEqual(found, [obj1, obj2])

found = await testmodels.ArrayFields.filter(array__overlap=[4])
self.assertEqual(found, [obj2, obj3])

found = await testmodels.ArrayFields.filter(array__overlap=[1, 2, 3, 4, 5])
self.assertEqual(found, [obj1, obj2, obj3])

async def test_array_length(self):
await testmodels.ArrayFields.create(array=[1, 2, 3])
await testmodels.ArrayFields.create(array=[1])
await testmodels.ArrayFields.create(array=[1, 2])

found = await testmodels.ArrayFields.filter(array__len=3).values_list("array", flat=True)
self.assertEqual(list(found), [[1, 2, 3]])

found = await testmodels.ArrayFields.filter(array__len=1).values_list("array", flat=True)
self.assertEqual(list(found), [[1]])

found = await testmodels.ArrayFields.filter(array__len=0).values_list("array", flat=True)
self.assertEqual(list(found), [])
2 changes: 2 additions & 0 deletions tests/testmodels_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ class ArrayFields(Model):
id = fields.IntField(primary_key=True)
array = ArrayField()
array_null = ArrayField(null=True)
array_str = ArrayField(element_type="varchar(1)", null=True)
array_smallint = ArrayField(element_type="smallint", null=True)
14 changes: 14 additions & 0 deletions tortoise/backends/base_postgres/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor
from tortoise.contrib.postgres.array_functions import (
postgres_array_contained_by,
postgres_array_contains,
postgres_array_length,
postgres_array_overlap,
)
from tortoise.contrib.postgres.json_functions import (
postgres_json_contained_by,
postgres_json_contains,
Expand All @@ -18,6 +24,10 @@
)
from tortoise.contrib.postgres.search import SearchCriterion
from tortoise.filters import (
array_contained_by,
array_contains,
array_length,
array_overlap,
insensitive_posix_regex,
json_contained_by,
json_contains,
Expand All @@ -36,11 +46,15 @@ class BasePostgresExecutor(BaseExecutor):
DB_NATIVE = BaseExecutor.DB_NATIVE | {bool, uuid.UUID}
FILTER_FUNC_OVERRIDE = {
search: postgres_search,
array_contains: postgres_array_contains,
array_contained_by: postgres_array_contained_by,
array_overlap: postgres_array_overlap,
json_contains: postgres_json_contains,
json_contained_by: postgres_json_contained_by,
json_filter: postgres_json_filter,
posix_regex: postgres_posix_regex,
insensitive_posix_regex: postgres_insensitive_posix_regex,
array_length: postgres_array_length,
}

def _prepare_insert_statement(
Expand Down
31 changes: 31 additions & 0 deletions tortoise/contrib/postgres/array_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from enum import Enum

from pypika_tortoise.terms import BasicCriterion, Criterion, Function, Term


class PostgresArrayOperators(str, Enum):
CONTAINS = "@>"
CONTAINED_BY = "<@"
OVERLAP = "&&"


# The value in the functions below is casted to the exact type of the field with value_encoder
# to avoid issues with psycopg that tries to use the smallest possible type which can lead to errors,
# e.g. {1,2} will be casted to smallint[] instead of integer[].


def postgres_array_contains(field: Term, value: Term) -> Criterion:
return BasicCriterion(PostgresArrayOperators.CONTAINS, field, value)


def postgres_array_contained_by(field: Term, value: Term) -> Criterion:
return BasicCriterion(PostgresArrayOperators.CONTAINED_BY, field, value)


def postgres_array_overlap(field: Term, value: Term) -> Criterion:
return BasicCriterion(PostgresArrayOperators.OVERLAP, field, value)


def postgres_array_length(field: Term, value: int) -> Criterion:
"""Returns a criterion that checks if array length equals the given value"""
return Function("array_length", field, 1).eq(value)
41 changes: 19 additions & 22 deletions tortoise/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,33 +355,30 @@ def _process_filter_kwarg(
join = None

if value is None and f"{key}__isnull" in model._meta.filters:
param = model._meta.get_filter(f"{key}__isnull")
filter_info = model._meta.get_filter(f"{key}__isnull")
value = True
else:
param = model._meta.get_filter(key)
filter_info = model._meta.get_filter(key)

pk_db_field = model._meta.db_pk_column
if param.get("table"):
if "table" in filter_info:
# join the table
join = (
param["table"],
table[pk_db_field] == param["table"][param["backward_key"]],
filter_info["table"],
table[model._meta.db_pk_column]
== filter_info["table"][filter_info["backward_key"]],
)
if param.get("value_encoder"):
value = param["value_encoder"](value, model)
op = param["operator"]
criterion = op(param["table"][param["field"]], value)
else:
if isinstance(value, Term):
encoded_value = value
else:
field_object = model._meta.fields_map[param["field"]]
encoded_value = (
param["value_encoder"](value, model, field_object)
if param.get("value_encoder")
else field_object.to_db_value(value, model)
)
op = param["operator"]
criterion = op(table[param["source_field"]], encoded_value)
if "value_encoder" in filter_info:
value = filter_info["value_encoder"](value, model)
table = filter_info["table"]
elif not isinstance(value, Term):
field_object = model._meta.fields_map[filter_info["field"]]
value = (
filter_info["value_encoder"](value, model, field_object)
if "value_encoder" in filter_info
else field_object.to_db_value(value, model)
)
op = filter_info["operator"]
criterion = op(table[filter_info.get("source_field", filter_info["field"])], value)
return criterion, join

def _resolve_regular_kwarg(
Expand Down
7 changes: 7 additions & 0 deletions tortoise/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,13 @@ def _get_dialects(self) -> dict[str, dict]:

return ret

def get_db_field_type(self) -> str:
"""
Returns the DB field type for this field for the current dialect.
"""
dialect = self.model._meta.db.capabilities.dialect
return self.get_for_dialect(dialect, "SQL_TYPE")

def get_db_field_types(self) -> Optional[dict[str, str]]:
"""
Returns the DB types for this field.
Expand Down
Loading