Skip to content

Add Geo Filtering #704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,5 @@ tests_sync/

# version files
.tool-versions

.vscode/
1 change: 1 addition & 0 deletions aredis_om/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
RedisModelError,
VectorFieldOptions,
)
from .model.types import Coordinates, GeoFilter
22 changes: 18 additions & 4 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
import dataclasses
import decimal
import json
import logging
import operator
Expand Down Expand Up @@ -45,6 +44,7 @@
from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper
from .types import Coordinates, CoordinateType, GeoFilter


model_registry = {}
Expand Down Expand Up @@ -405,7 +405,6 @@ class RediSearchFieldTypes(Enum):
GEO = "GEO"


# TODO: How to handle Geo fields?
DEFAULT_PAGE_SIZE = 1000


Expand Down Expand Up @@ -535,8 +534,12 @@ def validate_sort_fields(self, sort_fields: List[str]):
def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes:
field_info: Union[FieldInfo, PydanticFieldInfo] = field

typ = get_outer_type(field_info)

if getattr(field_info, "primary_key", None) is True:
return RediSearchFieldTypes.TAG
elif typ in [CoordinateType, Coordinates]:
return RediSearchFieldTypes.GEO
elif op is Operators.LIKE:
fts = getattr(field_info, "full_text_search", None)
if fts is not True: # Could be PydanticUndefined
Expand All @@ -552,7 +555,6 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
if not isinstance(field_type, type):
field_type = field_type.__origin__

# TODO: GEO fields
container_type = get_origin(field_type)

if is_supported_container_type(container_type):
Expand Down Expand Up @@ -726,6 +728,15 @@ def resolve_value(
field_name=field_name, expanded_value=expanded_value
)

elif field_type is RediSearchFieldTypes.GEO:
if not isinstance(value, GeoFilter):
raise QuerySyntaxError(
"You can only use a GeoFilter object with a GEO field."
)

if op is Operators.EQ:
result += f"@{field_name}:[{value}]"

return result

def resolve_redisearch_pagination(self):
Expand Down Expand Up @@ -1804,6 +1815,8 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
schema = cls.schema_for_type(name, embedded_cls, field_info)
elif typ is bool:
schema = f"{name} TAG"
elif typ in [CoordinateType, Coordinates]:
schema = f"{name} GEO"
elif is_numeric_type(typ):
vector_options: Optional[VectorFieldOptions] = getattr(
field_info, "vector_options", None
Expand Down Expand Up @@ -2107,7 +2120,6 @@ def schema_for_type(
else typ
)

# TODO: GEO field
if is_vector and vector_options:
schema = f"{path} AS {index_field_name} {vector_options.schema}"
elif parent_is_container_type or parent_is_model_in_container:
Expand All @@ -2128,6 +2140,8 @@ def schema_for_type(
schema += " CASESENSITIVE"
elif typ is bool:
schema = f"{path} AS {index_field_name} TAG"
elif typ in [CoordinateType, Coordinates]:
schema = f"{path} AS {index_field_name} GEO"
elif is_numeric_type(typ):
schema = f"{path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
Expand Down
50 changes: 50 additions & 0 deletions aredis_om/model/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Annotated, Any, Literal

from pydantic import BeforeValidator, PlainSerializer
from pydantic_extra_types.coordinate import Coordinate


RadiusUnit = Literal["m", "km", "mi", "ft"]


class GeoFilter:
def __init__(self, longitude: float, latitude: float, radius: float, unit: RadiusUnit):
self.longitude = longitude
self.latitude = latitude
self.radius = radius
self.unit = unit

def __str__(self):
return f"{self.longitude} {self.latitude} {self.radius} {self.unit}"


CoordinateType = Coordinate


def parse_redis(v: Any):
"""
The pydantic coordinate type expects a string in the format 'latitude,longitude'.
Redis expects a string in the format 'longitude,latitude'.

This validator transforms the input from Redis into the expected format for pydantic.
"""
if isinstance(v, str):
parts = v.split(",")

if len(parts) != 2:
raise ValueError("Invalid coordinate format")

return (parts[1], parts[0])

return v


Coordinates = Annotated[
CoordinateType,
PlainSerializer(
lambda v: f"{v.longitude},{v.latitude}",
return_type=str,
when_used="unless-none",
),
BeforeValidator(parse_redis),
]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redis-om"
version = "1.0.2-beta"
version = "1.0.3-beta"
description = "Object mappings, and more, for Redis."
authors = ["Redis OSS <[email protected]>"]
maintainers = ["Redis OSS <[email protected]>"]
Expand Down Expand Up @@ -46,6 +46,7 @@ typing-extensions = "^4.4.0"
hiredis = ">=2.2.3,<4.0.0"
more-itertools = ">=8.14,<11.0"
setuptools = ">=70.0"
pydantic-extra-types = "^2.10.5"

[tool.poetry.group.dev.dependencies]
mypy = "^1.9.0"
Expand Down
110 changes: 110 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import pytest_asyncio

from aredis_om import (
Coordinates,
Field,
GeoFilter,
HashModel,
Migrator,
NotFoundError,
Expand Down Expand Up @@ -1054,3 +1056,111 @@ class Meta:

rematerialized = await Model.find(Model.first_name == "Steve").first()
assert rematerialized.pk == model.pk


@py_test_mark_asyncio
async def test_can_search_on_coordinates(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc = Location(coordinates=(latitude, longitude))

await loc.save()

rematerialized: Location = await Location.find(
Location.coordinates
== GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi")
).first()

assert rematerialized.pk == loc.pk
assert rematerialized.coordinates.latitude == latitude
assert rematerialized.coordinates.longitude == longitude


@py_test_mark_asyncio
async def test_does_not_return_coordinates_if_outside_radius(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc = Location(coordinates=(latitude, longitude))

await loc.save()

with pytest.raises(NotFoundError):
await Location.find(
Location.coordinates
== GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi")
).first()


@py_test_mark_asyncio
async def test_does_not_return_coordinates_if_location_is_none(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Optional[Coordinates] = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

loc = Location(coordinates=None)

await loc.save()

with pytest.raises(NotFoundError):
await Location.find(
Location.coordinates
== GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi")
).first()


@py_test_mark_asyncio
async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)
name: str = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc1 = Location(coordinates=(latitude, longitude), name="Portland")
loc2 = Location(coordinates=(latitude + 0.01, longitude + 0.01), name="Nearby")

await loc1.save()
await loc2.save()

rematerialized: List[Location] = await Location.find(
(
Location.coordinates
== GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi")
)
& (Location.name == "Portland")
).all()

assert len(rematerialized) == 1
assert rematerialized[0].pk == loc1.pk
Loading