Skip to content

Add Geo Filtering #704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,5 @@ tests_sync/

# version files
.tool-versions

.vscode/
1 change: 1 addition & 0 deletions aredis_om/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
RedisModelError,
VectorFieldOptions,
)
from .model.types import Coordinates, GeoFilter
22 changes: 18 additions & 4 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
import dataclasses
import decimal
import json
import logging
import operator
Expand Down Expand Up @@ -45,6 +44,7 @@
from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper
from .types import Coordinates, CoordinateType, GeoFilter


model_registry = {}
Expand Down Expand Up @@ -405,7 +405,6 @@ class RediSearchFieldTypes(Enum):
GEO = "GEO"


# TODO: How to handle Geo fields?
DEFAULT_PAGE_SIZE = 1000


Expand Down Expand Up @@ -535,8 +534,12 @@ def validate_sort_fields(self, sort_fields: List[str]):
def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes:
field_info: Union[FieldInfo, PydanticFieldInfo] = field

typ = get_outer_type(field_info)

if getattr(field_info, "primary_key", None) is True:
return RediSearchFieldTypes.TAG
elif typ in [CoordinateType, Coordinates]:
return RediSearchFieldTypes.GEO
elif op is Operators.LIKE:
fts = getattr(field_info, "full_text_search", None)
if fts is not True: # Could be PydanticUndefined
Expand All @@ -552,7 +555,6 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
if not isinstance(field_type, type):
field_type = field_type.__origin__

# TODO: GEO fields
container_type = get_origin(field_type)

if is_supported_container_type(container_type):
Expand Down Expand Up @@ -726,6 +728,15 @@ def resolve_value(
field_name=field_name, expanded_value=expanded_value
)

elif field_type is RediSearchFieldTypes.GEO:
if not isinstance(value, GeoFilter):
raise QuerySyntaxError(
"You can only use a GeoFilter object with a GEO field."
)

if op is Operators.EQ:
result += f"@{field_name}:[{value}]"

return result

def resolve_redisearch_pagination(self):
Expand Down Expand Up @@ -1804,6 +1815,8 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
schema = cls.schema_for_type(name, embedded_cls, field_info)
elif typ is bool:
schema = f"{name} TAG"
elif typ in [CoordinateType, Coordinates]:
schema = f"{name} GEO"
elif is_numeric_type(typ):
vector_options: Optional[VectorFieldOptions] = getattr(
field_info, "vector_options", None
Expand Down Expand Up @@ -2107,7 +2120,6 @@ def schema_for_type(
else typ
)

# TODO: GEO field
if is_vector and vector_options:
schema = f"{path} AS {index_field_name} {vector_options.schema}"
elif parent_is_container_type or parent_is_model_in_container:
Expand All @@ -2128,6 +2140,8 @@ def schema_for_type(
schema += " CASESENSITIVE"
elif typ is bool:
schema = f"{path} AS {index_field_name} TAG"
elif typ in [CoordinateType, Coordinates]:
schema = f"{path} AS {index_field_name} GEO"
elif is_numeric_type(typ):
schema = f"{path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
Expand Down
116 changes: 116 additions & 0 deletions aredis_om/model/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import Annotated, Any, Literal, Tuple, Union

from pydantic import BeforeValidator, PlainSerializer
from pydantic_extra_types.coordinate import Coordinate


RadiusUnit = Literal["m", "km", "mi", "ft"]


class GeoFilter:
"""
A geographic filter for searching within a radius of a coordinate point.

This filter is used with GEO fields to find models within a specified
distance from a given location.

Args:
longitude: The longitude of the center point (-180 to 180)
latitude: The latitude of the center point (-90 to 90)
radius: The search radius (must be positive)
unit: The unit of measurement ('m', 'km', 'mi', or 'ft')

Example:
>>> # Find all locations within 10 miles of Portland, OR
>>> filter = GeoFilter(
... longitude=-122.6765,
... latitude=45.5231,
... radius=10,
... unit="mi"
... )
>>> results = await Location.find(
... Location.coordinates == filter
... ).all()
"""

def __init__(
self, longitude: float, latitude: float, radius: float, unit: RadiusUnit
):
# Validate coordinates
if not -180 <= longitude <= 180:
raise ValueError(f"Longitude must be between -180 and 180, got {longitude}")
if not -90 <= latitude <= 90:
raise ValueError(f"Latitude must be between -90 and 90, got {latitude}")
if radius <= 0:
raise ValueError(f"Radius must be positive, got {radius}")

self.longitude = longitude
self.latitude = latitude
self.radius = radius
self.unit = unit

def __str__(self) -> str:
return f"{self.longitude} {self.latitude} {self.radius} {self.unit}"

@classmethod
def from_coordinates(
cls, coords: Coordinate, radius: float, unit: RadiusUnit
) -> "GeoFilter":
"""
Create a GeoFilter from a Coordinates object.

Args:
coords: A Coordinate object with latitude and longitude
radius: The search radius
unit: The unit of measurement

Returns:
A new GeoFilter instance
"""
return cls(coords.longitude, coords.latitude, radius, unit)


CoordinateType = Coordinate


def parse_redis(v: Any) -> Union[Tuple[str, str], Any]:
"""
Transform Redis coordinate format to Pydantic coordinate format.

The pydantic coordinate type expects a string in the format 'latitude,longitude'.
Redis stores coordinates in the format 'longitude,latitude'.

This validator transforms the input from Redis into the expected format for pydantic.

Args:
v: The value from Redis (typically a string like "longitude,latitude")

Returns:
A tuple of (latitude, longitude) strings if input is a coordinate string,
otherwise returns the input unchanged.

Raises:
ValueError: If the coordinate string format is invalid
"""
if isinstance(v, str):
parts = v.split(",")

if len(parts) != 2:
raise ValueError(
f"Invalid coordinate format. Expected 'longitude,latitude' but got: {v}"
)

return (parts[1], parts[0]) # Swap to (latitude, longitude)

return v


Coordinates = Annotated[
CoordinateType,
PlainSerializer(
lambda v: f"{v.longitude},{v.latitude}",
return_type=str,
when_used="unless-none",
),
BeforeValidator(parse_redis),
]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redis-om"
version = "1.0.2-beta"
version = "1.0.3-beta"
description = "Object mappings, and more, for Redis."
authors = ["Redis OSS <[email protected]>"]
maintainers = ["Redis OSS <[email protected]>"]
Expand Down Expand Up @@ -46,6 +46,7 @@ typing-extensions = "^4.4.0"
hiredis = ">=2.2.3,<4.0.0"
more-itertools = ">=8.14,<11.0"
setuptools = ">=70.0"
pydantic-extra-types = "^2.10.5"

[tool.poetry.group.dev.dependencies]
mypy = "^1.9.0"
Expand Down
112 changes: 112 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import pytest_asyncio

from aredis_om import (
Coordinates,
Field,
GeoFilter,
HashModel,
Migrator,
NotFoundError,
Expand Down Expand Up @@ -1054,3 +1056,113 @@ class Meta:

rematerialized = await Model.find(Model.first_name == "Steve").first()
assert rematerialized.pk == model.pk


@py_test_mark_asyncio
async def test_can_search_on_coordinates(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc = Location(coordinates=(latitude, longitude))

await loc.save()

rematerialized: Location = await Location.find(
Location.coordinates
== GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi")
).first()

assert rematerialized.pk == loc.pk
assert rematerialized.coordinates.latitude == latitude
assert rematerialized.coordinates.longitude == longitude


@py_test_mark_asyncio
async def test_does_not_return_coordinates_if_outside_radius(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc = Location(coordinates=(latitude, longitude))

await loc.save()

with pytest.raises(NotFoundError):
await Location.find(
Location.coordinates
== GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi")
).first()


@py_test_mark_asyncio
async def test_does_not_return_coordinates_if_location_is_none(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Optional[Coordinates] = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

loc = Location(coordinates=None)

await loc.save()

with pytest.raises(NotFoundError):
await Location.find(
Location.coordinates
== GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi")
).first()


@py_test_mark_asyncio
async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)
name: str = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc1 = Location(coordinates=(latitude, longitude), name="Portland")
# Offset by 0.01 degrees (~1.1 km at this latitude) to create a nearby location
# This ensures "Nearby" is within the 10 mile search radius but not at the exact same location
loc2 = Location(coordinates=(latitude + 0.01, longitude + 0.01), name="Nearby")

await loc1.save()
await loc2.save()

rematerialized: List[Location] = await Location.find(
(
Location.coordinates
== GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi")
)
& (Location.name == "Portland")
).all()

assert len(rematerialized) == 1
assert rematerialized[0].pk == loc1.pk
Loading