diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0f2562186..8e76212f7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,14 +11,16 @@ Changelog 0.24.1 (unreleased) ------ +Added +^^^^^ +- Implement __contains, __contained_by, __overlap and __len for ArrayField (#1877) + Fixed ^^^^^ - Fix update pk field raises unfriendly error (#1873) -- Fixed asyncio "no current event loop" deprecation warning by replacing `asyncio.get_event_loop()` with modern event loop handling using `get_running_loop()` with fallback to `new_event_loop()` (#1865) Changed ^^^^^^^ -- add benchmarks for `get_for_dialect` (#1862) 0.24.0 ------ diff --git a/docs/query.rst b/docs/query.rst index 75cc7975b..a25e84df0 100644 --- a/docs/query.rst +++ b/docs/query.rst @@ -285,7 +285,7 @@ PostgreSQL and SQLite also support ``iposix_regex``, which makes case insensive obj = await DemoModel.filter(demo_text__iposix_regex="^hello world$").first() -In PostgreSQL, ``filter`` supports additional lookup types: +With PostgreSQL, for ``JSONField``, ``filter`` supports additional lookup types: - ``in`` - ``await JSONModel.filter(data__filter={"breed__in": ["labrador", "poodle"]}).first()`` - ``not_in`` @@ -301,6 +301,13 @@ In PostgreSQL, ``filter`` supports additional lookup types: - ``istartswith`` - ``iendswith`` +With PostgreSQL, ``ArrayField`` can be used with the following lookup types: + +- ``contains`` - ``await ArrayFields.filter(array__contains=[1, 2, 3]).first()`` which will use the ``@>`` operator +- ``contained_by`` - will use the ``<@`` operator +- ``overlap`` - will use the ``&&`` operator +- ``len`` - will use the ``array_length`` function, e.g. 
``await ArrayFields.filter(array__len=3).first()`` + Complex prefetch ================ diff --git a/tests/fields/test_array.py b/tests/fields/test_array.py index 7fabf90e0..2bf192dd2 100644 --- a/tests/fields/test_array.py +++ b/tests/fields/test_array.py @@ -42,3 +42,105 @@ async def test_values_list(self): obj0 = await testmodels.ArrayFields.create(array=[0]) values = await testmodels.ArrayFields.get(id=obj0.id).values_list("array", flat=True) self.assertEqual(values, [0]) + + async def test_eq_filter(self): + obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) + obj2 = await testmodels.ArrayFields.create(array=[1, 2]) + + found = await testmodels.ArrayFields.filter(array=[1, 2, 3]).first() + self.assertEqual(found, obj1) + + found = await testmodels.ArrayFields.filter(array=[1, 2]).first() + self.assertEqual(found, obj2) + + async def test_not_filter(self): + await testmodels.ArrayFields.create(array=[1, 2, 3]) + obj2 = await testmodels.ArrayFields.create(array=[1, 2]) + + found = await testmodels.ArrayFields.filter(array__not=[1, 2, 3]).first() + self.assertEqual(found, obj2) + + async def test_contains_ints(self): + obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) + obj2 = await testmodels.ArrayFields.create(array=[2, 3]) + await testmodels.ArrayFields.create(array=[4, 5, 6]) + + found = await testmodels.ArrayFields.filter(array__contains=[2]) + self.assertEqual(found, [obj1, obj2]) + + found = await testmodels.ArrayFields.filter(array__contains=[10]) + self.assertEqual(found, []) + + async def test_contains_smallints(self): + obj1 = await testmodels.ArrayFields.create(array=[], array_smallint=[1, 2, 3]) + + found = await testmodels.ArrayFields.filter(array_smallint__contains=[2]).first() + self.assertEqual(found, obj1) + + async def test_contains_strs(self): + obj1 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[]) + + found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c"]) + 
self.assertEqual(found, [obj1]) + + found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b"]) + self.assertEqual(found, [obj1]) + + found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c", "d"]) + self.assertEqual(found, []) + + async def test_contained_by_ints(self): + obj1 = await testmodels.ArrayFields.create(array=[1]) + obj2 = await testmodels.ArrayFields.create(array=[1, 2]) + obj3 = await testmodels.ArrayFields.create(array=[1, 2, 3]) + + found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2, 3]) + self.assertEqual(found, [obj1, obj2, obj3]) + + found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2]) + self.assertEqual(found, [obj1, obj2]) + + found = await testmodels.ArrayFields.filter(array__contained_by=[1]) + self.assertEqual(found, [obj1]) + + async def test_contained_by_strs(self): + obj1 = await testmodels.ArrayFields.create(array_str=["a"], array=[]) + obj2 = await testmodels.ArrayFields.create(array_str=["a", "b"], array=[]) + obj3 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[]) + + found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b", "c", "d"]) + self.assertEqual(found, [obj1, obj2, obj3]) + + found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b"]) + self.assertEqual(found, [obj1, obj2]) + + found = await testmodels.ArrayFields.filter(array_str__contained_by=["x", "y", "z"]) + self.assertEqual(found, []) + + async def test_overlap_ints(self): + obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) + obj2 = await testmodels.ArrayFields.create(array=[2, 3, 4]) + obj3 = await testmodels.ArrayFields.create(array=[3, 4, 5]) + + found = await testmodels.ArrayFields.filter(array__overlap=[1, 2]) + self.assertEqual(found, [obj1, obj2]) + + found = await testmodels.ArrayFields.filter(array__overlap=[4]) + self.assertEqual(found, [obj2, obj3]) + + found = await 
testmodels.ArrayFields.filter(array__overlap=[1, 2, 3, 4, 5]) + self.assertEqual(found, [obj1, obj2, obj3]) + + async def test_array_length(self): + await testmodels.ArrayFields.create(array=[1, 2, 3]) + await testmodels.ArrayFields.create(array=[1]) + await testmodels.ArrayFields.create(array=[1, 2]) + + found = await testmodels.ArrayFields.filter(array__len=3).values_list("array", flat=True) + self.assertEqual(list(found), [[1, 2, 3]]) + + found = await testmodels.ArrayFields.filter(array__len=1).values_list("array", flat=True) + self.assertEqual(list(found), [[1]]) + + found = await testmodels.ArrayFields.filter(array__len=0).values_list("array", flat=True) + self.assertEqual(list(found), []) diff --git a/tests/testmodels_postgres.py b/tests/testmodels_postgres.py index 97fbce976..d2c460f3c 100644 --- a/tests/testmodels_postgres.py +++ b/tests/testmodels_postgres.py @@ -6,3 +6,5 @@ class ArrayFields(Model): id = fields.IntField(primary_key=True) array = ArrayField() array_null = ArrayField(null=True) + array_str = ArrayField(element_type="varchar(1)", null=True) + array_smallint = ArrayField(element_type="smallint", null=True) diff --git a/tortoise/backends/base_postgres/executor.py b/tortoise/backends/base_postgres/executor.py index cc6dafda7..e51632e30 100644 --- a/tortoise/backends/base_postgres/executor.py +++ b/tortoise/backends/base_postgres/executor.py @@ -7,6 +7,12 @@ from tortoise import Model from tortoise.backends.base.executor import BaseExecutor +from tortoise.contrib.postgres.array_functions import ( + postgres_array_contained_by, + postgres_array_contains, + postgres_array_length, + postgres_array_overlap, +) from tortoise.contrib.postgres.json_functions import ( postgres_json_contained_by, postgres_json_contains, @@ -18,6 +24,10 @@ ) from tortoise.contrib.postgres.search import SearchCriterion from tortoise.filters import ( + array_contained_by, + array_contains, + array_length, + array_overlap, insensitive_posix_regex, json_contained_by, 
json_contains, @@ -36,11 +46,15 @@ class BasePostgresExecutor(BaseExecutor): DB_NATIVE = BaseExecutor.DB_NATIVE | {bool, uuid.UUID} FILTER_FUNC_OVERRIDE = { search: postgres_search, + array_contains: postgres_array_contains, + array_contained_by: postgres_array_contained_by, + array_overlap: postgres_array_overlap, json_contains: postgres_json_contains, json_contained_by: postgres_json_contained_by, json_filter: postgres_json_filter, posix_regex: postgres_posix_regex, insensitive_posix_regex: postgres_insensitive_posix_regex, + array_length: postgres_array_length, } def _prepare_insert_statement( diff --git a/tortoise/contrib/postgres/array_functions.py b/tortoise/contrib/postgres/array_functions.py new file mode 100644 index 000000000..a57a27274 --- /dev/null +++ b/tortoise/contrib/postgres/array_functions.py @@ -0,0 +1,31 @@ +from enum import Enum + +from pypika_tortoise.terms import BasicCriterion, Criterion, Function, Term + + +class PostgresArrayOperators(str, Enum): + CONTAINS = "@>" + CONTAINED_BY = "<@" + OVERLAP = "&&" + + +# The value in the functions below is casted to the exact type of the field with value_encoder +# to avoid issues with psycopg that tries to use the smallest possible type which can lead to errors, +# e.g. {1,2} will be casted to smallint[] instead of integer[]. 
+ + +def postgres_array_contains(field: Term, value: Term) -> Criterion: + return BasicCriterion(PostgresArrayOperators.CONTAINS, field, value) + + +def postgres_array_contained_by(field: Term, value: Term) -> Criterion: + return BasicCriterion(PostgresArrayOperators.CONTAINED_BY, field, value) + + +def postgres_array_overlap(field: Term, value: Term) -> Criterion: + return BasicCriterion(PostgresArrayOperators.OVERLAP, field, value) + + +def postgres_array_length(field: Term, value: int) -> Criterion: + """Returns a criterion that checks if array length equals the given value""" + return Function("array_length", field, 1).eq(value) diff --git a/tortoise/expressions.py b/tortoise/expressions.py index 9e0cc05be..f91d1c124 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -355,33 +355,30 @@ def _process_filter_kwarg( join = None if value is None and f"{key}__isnull" in model._meta.filters: - param = model._meta.get_filter(f"{key}__isnull") + filter_info = model._meta.get_filter(f"{key}__isnull") value = True else: - param = model._meta.get_filter(key) + filter_info = model._meta.get_filter(key) - pk_db_field = model._meta.db_pk_column - if param.get("table"): + if "table" in filter_info: + # join the table join = ( - param["table"], - table[pk_db_field] == param["table"][param["backward_key"]], + filter_info["table"], + table[model._meta.db_pk_column] + == filter_info["table"][filter_info["backward_key"]], ) - if param.get("value_encoder"): - value = param["value_encoder"](value, model) - op = param["operator"] - criterion = op(param["table"][param["field"]], value) - else: - if isinstance(value, Term): - encoded_value = value - else: - field_object = model._meta.fields_map[param["field"]] - encoded_value = ( - param["value_encoder"](value, model, field_object) - if param.get("value_encoder") - else field_object.to_db_value(value, model) - ) - op = param["operator"] - criterion = op(table[param["source_field"]], encoded_value) + if "value_encoder" 
in filter_info: + value = filter_info["value_encoder"](value, model) + table = filter_info["table"] + elif not isinstance(value, Term): + field_object = model._meta.fields_map[filter_info["field"]] + value = ( + filter_info["value_encoder"](value, model, field_object) + if "value_encoder" in filter_info + else field_object.to_db_value(value, model) + ) + op = filter_info["operator"] + criterion = op(table[filter_info.get("source_field", filter_info["field"])], value) return criterion, join def _resolve_regular_kwarg( diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 062864a29..c7626e055 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -317,6 +317,13 @@ def _get_dialects(self) -> dict[str, dict]: return ret + def get_db_field_type(self) -> str: + """ + Returns the DB field type for this field for the current dialect. + """ + dialect = self.model._meta.db.capabilities.dialect + return self.get_for_dialect(dialect, "SQL_TYPE") + def get_db_field_types(self) -> Optional[dict[str, str]]: """ Returns the DB types for this field. 
diff --git a/tortoise/filters.py b/tortoise/filters.py index e9535d203..b97a81212 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -3,12 +3,13 @@ import operator from collections.abc import Callable, Iterable from functools import partial -from typing import TYPE_CHECKING, Any, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Optional, Sequence, TypedDict, Union from pypika_tortoise import SqlContext, Table from pypika_tortoise.enums import DatePart, Matching, SqlTypes from pypika_tortoise.functions import Cast, Extract, Upper from pypika_tortoise.terms import ( + Array, BasicCriterion, Criterion, Equality, @@ -17,6 +18,7 @@ ) from typing_extensions import NotRequired +from tortoise.contrib.postgres.fields import ArrayField from tortoise.fields import Field, JSONField from tortoise.fields.relational import BackwardFKRelation, ManyToManyFieldInstance @@ -72,10 +74,21 @@ def string_encoder(value: Any, instance: "Model", field: Field) -> str: return str(value) +def int_encoder(value: Any, instance: "Model", field: Field) -> int: + return int(value) + + def json_encoder(value: Any, instance: "Model", field: Field) -> dict: return value +def array_encoder(value: Union[Any, Sequence[Any]], instance: "Model", field: Field) -> Any: + # Casting to the exact type of the field to avoid issues with psycopg that tries + # to use the smallest possible type which can lead to errors, + # e.g. {1,2} will be casted to smallint[] instead of integer[]. 
+    return Cast(Array(*value), field.get_db_field_type()) + + ############################################################################## # Operators # Should be type: (field: Term, value: Any) -> Criterion: @@ -213,19 +226,32 @@ def extract_microsecond_equal(field: Term, value: int) -> Criterion: return Extract(DatePart.microsecond, field).eq(value) -def json_contains(field: Term, value: str) -> Criterion: # type:ignore[empty-body] - # will be override in each executor - pass +def json_contains(field: Term, value: str) -> Criterion: + raise NotImplementedError("must be overridden in each executor") -def json_contained_by(field: Term, value: str) -> Criterion: # type:ignore[empty-body] - # will be override in each executor - pass +def json_contained_by(field: Term, value: str) -> Criterion: + raise NotImplementedError("must be overridden in each executor") -def json_filter(field: Term, value: dict) -> Criterion: # type:ignore[empty-body] - # will be override in each executor - pass +def json_filter(field: Term, value: dict) -> Criterion: + raise NotImplementedError("must be overridden in each executor") + + +def array_contains(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion: + raise NotImplementedError("must be overridden in each executor") + + +def array_contained_by(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion: + raise NotImplementedError("must be overridden in each executor") + + +def array_overlap(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion: + raise NotImplementedError("must be overridden in each executor") + + +def array_length(field: Term, value: int) -> Criterion: + raise NotImplementedError("must be overridden in each executor") ############################################################################## @@ -325,42 +351,41 @@ def get_backward_fk_filters( def get_json_filter(field_name: str, source_field: str) -> dict[str, FilterInfoDict]: - actual_field_name = field_name return { field_name: { - "field": 
actual_field_name, + "field": field_name, "source_field": source_field, "operator": operator.eq, }, f"{field_name}__not": { - "field": actual_field_name, + "field": field_name, "source_field": source_field, "operator": not_equal, }, f"{field_name}__isnull": { - "field": actual_field_name, + "field": field_name, "source_field": source_field, "operator": is_null, "value_encoder": bool_encoder, }, f"{field_name}__not_isnull": { - "field": actual_field_name, + "field": field_name, "source_field": source_field, "operator": not_null, "value_encoder": bool_encoder, }, f"{field_name}__contains": { - "field": actual_field_name, + "field": field_name, "source_field": source_field, "operator": json_contains, }, f"{field_name}__contained_by": { - "field": actual_field_name, + "field": field_name, "source_field": source_field, "operator": json_contained_by, }, f"{field_name}__filter": { - "field": actual_field_name, + "field": field_name, "source_field": source_field, "operator": json_filter, "value_encoder": json_encoder, @@ -381,15 +406,73 @@ def get_json_filter_operator( return key_parts, filter_value, operator_ +def get_array_filter( + field_name: str, source_field: str, field: ArrayField +) -> dict[str, FilterInfoDict]: + return { + field_name: { + "field": field_name, + "source_field": source_field, + "operator": operator.eq, + "value_encoder": array_encoder, + }, + f"{field_name}__not": { + "field": field_name, + "source_field": source_field, + "operator": not_equal, + "value_encoder": array_encoder, + }, + f"{field_name}__isnull": { + "field": field_name, + "source_field": source_field, + "operator": is_null, + "value_encoder": bool_encoder, + }, + f"{field_name}__not_isnull": { + "field": field_name, + "source_field": source_field, + "operator": not_null, + "value_encoder": bool_encoder, + }, + f"{field_name}__contains": { + "field": field_name, + "source_field": source_field, + "operator": array_contains, + "value_encoder": array_encoder, + }, + 
f"{field_name}__contained_by": { + "field": field_name, + "source_field": source_field, + "operator": array_contained_by, + "value_encoder": array_encoder, + }, + f"{field_name}__overlap": { + "field": field_name, + "source_field": source_field, + "operator": array_overlap, + "value_encoder": array_encoder, + }, + f"{field_name}__len": { + "field": field_name, + "source_field": source_field, + "operator": array_length, + "value_encoder": int_encoder, + }, + } + + def get_filters_for_field( field_name: str, field: Optional[Field], source_field: str ) -> dict[str, FilterInfoDict]: - if isinstance(field, ManyToManyFieldInstance): - return get_m2m_filters(field_name, field) - if isinstance(field, BackwardFKRelation): - return get_backward_fk_filters(field_name, field) - if isinstance(field, JSONField): - return get_json_filter(field_name, source_field) + if field is not None: + if isinstance(field, ManyToManyFieldInstance): + return get_m2m_filters(field_name, field) + if isinstance(field, BackwardFKRelation): + return get_backward_fk_filters(field_name, field) + if isinstance(field, JSONField): + return get_json_filter(field_name, source_field) + if isinstance(field, ArrayField): + return get_array_filter(field_name, source_field, field) actual_field_name = field_name if field_name == "pk" and field: