diff --git a/.gitignore b/.gitignore index 5e278232..b16b6998 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,5 @@ tests_sync/ # version files .tool-versions + +.vscode/ diff --git a/aredis_om/__init__.py b/aredis_om/__init__.py index 813e3b04..847b124f 100644 --- a/aredis_om/__init__.py +++ b/aredis_om/__init__.py @@ -16,3 +16,4 @@ RedisModelError, VectorFieldOptions, ) +from .model.types import Coordinates, GeoFilter diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 5a5c75ee..f95d4827 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1,6 +1,5 @@ import abc import dataclasses -import decimal import json import logging import operator @@ -45,6 +44,7 @@ from .encoders import jsonable_encoder from .render_tree import render_tree from .token_escaper import TokenEscaper +from .types import Coordinates, CoordinateType, GeoFilter model_registry = {} @@ -405,7 +405,6 @@ class RediSearchFieldTypes(Enum): GEO = "GEO" -# TODO: How to handle Geo fields? DEFAULT_PAGE_SIZE = 1000 @@ -535,8 +534,12 @@ def validate_sort_fields(self, sort_fields: List[str]): def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes: field_info: Union[FieldInfo, PydanticFieldInfo] = field + typ = get_outer_type(field_info) + if getattr(field_info, "primary_key", None) is True: return RediSearchFieldTypes.TAG + elif typ in [CoordinateType, Coordinates]: + return RediSearchFieldTypes.GEO elif op is Operators.LIKE: fts = getattr(field_info, "full_text_search", None) if fts is not True: # Could be PydanticUndefined @@ -552,7 +555,6 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType if not isinstance(field_type, type): field_type = field_type.__origin__ - # TODO: GEO fields container_type = get_origin(field_type) if is_supported_container_type(container_type): @@ -726,6 +728,15 @@ def resolve_value( field_name=field_name, expanded_value=expanded_value ) + elif field_type is RediSearchFieldTypes.GEO: + if not isinstance(value, GeoFilter): + raise QuerySyntaxError( + "You can only use a GeoFilter object with a GEO field." + ) + + if op is Operators.EQ: + result += f"@{field_name}:[{value}]" + return result def resolve_redisearch_pagination(self): @@ -1804,6 +1815,8 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): schema = cls.schema_for_type(name, embedded_cls, field_info) elif typ is bool: schema = f"{name} TAG" + elif typ in [CoordinateType, Coordinates]: + schema = f"{name} GEO" elif is_numeric_type(typ): vector_options: Optional[VectorFieldOptions] = getattr( field_info, "vector_options", None @@ -2107,7 +2120,6 @@ def schema_for_type( else typ ) - # TODO: GEO field if is_vector and vector_options: schema = f"{path} AS {index_field_name} {vector_options.schema}" elif parent_is_container_type or parent_is_model_in_container: @@ -2128,6 +2140,8 @@ def schema_for_type( schema += " CASESENSITIVE" elif typ is bool: schema = f"{path} AS {index_field_name} TAG" + elif typ in [CoordinateType, Coordinates]: + schema = f"{path} AS {index_field_name} GEO" elif is_numeric_type(typ): schema = f"{path} AS {index_field_name} NUMERIC" elif issubclass(typ, str): diff --git a/aredis_om/model/types.py b/aredis_om/model/types.py new file mode 100644 index 00000000..3e9029ca --- /dev/null +++ b/aredis_om/model/types.py @@ -0,0 +1,116 @@ +from typing import Annotated, Any, Literal, Tuple, Union + +from pydantic import BeforeValidator, PlainSerializer +from pydantic_extra_types.coordinate import Coordinate + + +RadiusUnit = Literal["m", "km", "mi", "ft"] + + +class GeoFilter: + """ + A geographic filter for searching within a radius of a coordinate point. + + This filter is used with GEO fields to find models within a specified + distance from a given location. + + Args: + longitude: The longitude of the center point (-180 to 180) + latitude: The latitude of the center point (-90 to 90) + radius: The search radius (must be positive) + unit: The unit of measurement ('m', 'km', 'mi', or 'ft') + + Example: + >>> # Find all locations within 10 miles of Portland, OR + >>> filter = GeoFilter( + ... longitude=-122.6765, + ... latitude=45.5231, + ... radius=10, + ... unit="mi" + ... ) + >>> results = await Location.find( + ... Location.coordinates == filter + ... ).all() + """ + + def __init__( + self, longitude: float, latitude: float, radius: float, unit: RadiusUnit + ): + # Validate coordinates + if not -180 <= longitude <= 180: + raise ValueError(f"Longitude must be between -180 and 180, got {longitude}") + if not -90 <= latitude <= 90: + raise ValueError(f"Latitude must be between -90 and 90, got {latitude}") + if radius <= 0: + raise ValueError(f"Radius must be positive, got {radius}") + + self.longitude = longitude + self.latitude = latitude + self.radius = radius + self.unit = unit + + def __str__(self) -> str: + return f"{self.longitude} {self.latitude} {self.radius} {self.unit}" + + @classmethod + def from_coordinates( + cls, coords: Coordinate, radius: float, unit: RadiusUnit + ) -> "GeoFilter": + """ + Create a GeoFilter from a Coordinates object. + + Args: + coords: A Coordinate object with latitude and longitude + radius: The search radius + unit: The unit of measurement + + Returns: + A new GeoFilter instance + """ + return cls(coords.longitude, coords.latitude, radius, unit) + + +CoordinateType = Coordinate + + +def parse_redis(v: Any) -> Union[Tuple[str, str], Any]: + """ + Transform Redis coordinate format to Pydantic coordinate format. + + The pydantic coordinate type expects a string in the format 'latitude,longitude'. + Redis stores coordinates in the format 'longitude,latitude'. + + This validator transforms the input from Redis into the expected format for pydantic. + + Args: + v: The value from Redis (typically a string like "longitude,latitude") + + Returns: + A tuple of (latitude, longitude) strings if input is a coordinate string, + otherwise returns the input unchanged. + + Raises: + ValueError: If the coordinate string format is invalid + """ + if isinstance(v, str): + parts = v.split(",") + + if len(parts) != 2: + raise ValueError( + f"Invalid coordinate format. Expected 'longitude,latitude' but got: {v}" + ) + + return (parts[1], parts[0]) # Swap to (latitude, longitude) + + return v + + +Coordinates = Annotated[ + CoordinateType, + PlainSerializer( + lambda v: f"{v.longitude},{v.latitude}", + return_type=str, + when_used="unless-none", + ), + BeforeValidator(parse_redis), +] diff --git a/pyproject.toml b/pyproject.toml index d89f3264..62599806 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "redis-om" -version = "1.0.2-beta" +version = "1.0.3-beta" description = "Object mappings, and more, for Redis." authors = ["Redis OSS "] maintainers = ["Redis OSS "] @@ -46,6 +46,7 @@ typing-extensions = "^4.4.0" hiredis = ">=2.2.3,<4.0.0" more-itertools = ">=8.14,<11.0" setuptools = ">=70.0" +pydantic-extra-types = "^2.10.5" [tool.poetry.group.dev.dependencies] mypy = "^1.9.0" diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index c3b578ac..928986c4 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -13,7 +13,9 @@ import pytest_asyncio from aredis_om import ( + Coordinates, Field, + GeoFilter, HashModel, Migrator, NotFoundError, @@ -1054,3 +1056,113 @@ class Meta: rematerialized = await Model.find(Model.first_name == "Steve").first() assert rematerialized.pk == model.pk + + +@py_test_mark_asyncio +async def test_can_search_on_coordinates(key_prefix, redis): + class Location(HashModel, index=True): + coordinates: Coordinates = Field(index=True) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + latitude = 45.5231 + longitude = -122.6765 + + loc = Location(coordinates=(latitude, longitude)) + + await loc.save() + + rematerialized: Location = await Location.find( + Location.coordinates + == GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi") + ).first() + + assert rematerialized.pk == loc.pk + assert rematerialized.coordinates.latitude == latitude + assert rematerialized.coordinates.longitude == longitude + + +@py_test_mark_asyncio +async def test_does_not_return_coordinates_if_outside_radius(key_prefix, redis): + class Location(HashModel, index=True): + coordinates: Coordinates = Field(index=True) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + latitude = 45.5231 + longitude = -122.6765 + + loc = Location(coordinates=(latitude, longitude)) + + await loc.save() + + with pytest.raises(NotFoundError): + await Location.find( + Location.coordinates + == GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi") + ).first() + + +@py_test_mark_asyncio +async def test_does_not_return_coordinates_if_location_is_none(key_prefix, redis): + class Location(HashModel, index=True): + coordinates: Optional[Coordinates] = Field(index=True) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + loc = Location(coordinates=None) + + await loc.save() + + with pytest.raises(NotFoundError): + await Location.find( + Location.coordinates + == GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi") + ).first() + + +@py_test_mark_asyncio +async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis): + class Location(HashModel, index=True): + coordinates: Coordinates = Field(index=True) + name: str = Field(index=True) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + latitude = 45.5231 + longitude = -122.6765 + + loc1 = Location(coordinates=(latitude, longitude), name="Portland") + # Offset by 0.01 degrees (~1.1 km at this latitude) to create a nearby location + # This ensures "Nearby" is within the 10 mile search radius but not at the exact same location + loc2 = Location(coordinates=(latitude + 0.01, longitude + 0.01), name="Nearby") + + await loc1.save() + await loc2.save() + + rematerialized: List[Location] = await Location.find( + ( + Location.coordinates + == GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi") + ) + & (Location.name == "Portland") + ).all() + + assert len(rematerialized) == 1 + assert rematerialized[0].pk == loc1.pk diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 44ae9c61..c8bd4031 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -14,8 +14,10 @@ import pytest_asyncio from aredis_om import ( + Coordinates, EmbeddedJsonModel, Field, + GeoFilter, JsonModel, Migrator, NotFoundError, @@ -1364,3 +1366,113 @@ class Meta: rematerialized = await Model.find(Model.first_name == "Steve").first() assert rematerialized.pk == model.pk + + +@py_test_mark_asyncio +async def test_can_search_on_coordinates(key_prefix, redis): + class Location(JsonModel, index=True): + coordinates: Coordinates = Field(index=True) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + latitude = 45.5231 + longitude = -122.6765 + + loc = Location(coordinates=(latitude, longitude)) + + await loc.save() + + rematerialized: Location = await Location.find( + Location.coordinates + == GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi") + ).first() + + assert rematerialized.pk == loc.pk + assert rematerialized.coordinates.latitude == latitude + assert rematerialized.coordinates.longitude == longitude + + +@py_test_mark_asyncio +async def test_does_not_return_coordinates_if_outside_radius(key_prefix, redis): + class Location(JsonModel, index=True): + coordinates: Coordinates = Field(index=True) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + latitude = 45.5231 + longitude = -122.6765 + + loc = Location(coordinates=(latitude, longitude)) + + await loc.save() + + with pytest.raises(NotFoundError): + await Location.find( + Location.coordinates + == GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi") + ).first() + + +@py_test_mark_asyncio +async def test_does_not_return_coordinates_if_location_is_none(key_prefix, redis): + class Location(JsonModel, index=True): + coordinates: Optional[Coordinates] = Field(index=True) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + loc = Location(coordinates=None) + + await loc.save() + + with pytest.raises(NotFoundError): + await Location.find( + Location.coordinates + == GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi") + ).first() + + +@py_test_mark_asyncio +async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis): + class Location(JsonModel, index=True): + coordinates: Coordinates = Field(index=True) + name: str = Field(index=True) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + latitude = 45.5231 + longitude = -122.6765 + + loc1 = Location(coordinates=(latitude, longitude), name="Portland") + # Offset by 0.01 degrees (~1.1 km at this latitude) to create a nearby location + # This ensures "Nearby" is within the 10 mile search radius but not at the exact same location + loc2 = Location(coordinates=(latitude + 0.01, longitude + 0.01), name="Nearby") + + await loc1.save() + await loc2.save() + + rematerialized: List[Location] = await Location.find( + ( + Location.coordinates + == GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi") + ) + & (Location.name == "Portland") + ).all() + + assert len(rematerialized) == 1 + assert rematerialized[0].pk == loc1.pk