Skip to content

Commit b8678d9

Browse files
authored
Merge branch 'main' into dependabot/github_actions/rojopolis/spellcheck-github-actions-0.48.0
2 parents ec67f34 + b00c9e0 commit b8678d9

File tree

7 files changed

+363
-5
lines changed

7 files changed

+363
-5
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,5 @@ tests_sync/
147147

148148
# version files
149149
.tool-versions
150+
151+
.vscode/

aredis_om/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
RedisModelError,
1717
VectorFieldOptions,
1818
)
19+
from .model.types import Coordinates, GeoFilter

aredis_om/model/model.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import abc
22
import dataclasses
3-
import decimal
43
import json
54
import logging
65
import operator
@@ -45,6 +44,7 @@
4544
from .encoders import jsonable_encoder
4645
from .render_tree import render_tree
4746
from .token_escaper import TokenEscaper
47+
from .types import Coordinates, CoordinateType, GeoFilter
4848

4949

5050
model_registry = {}
@@ -405,7 +405,6 @@ class RediSearchFieldTypes(Enum):
405405
GEO = "GEO"
406406

407407

408-
# TODO: How to handle Geo fields?
409408
DEFAULT_PAGE_SIZE = 1000
410409

411410

@@ -535,8 +534,12 @@ def validate_sort_fields(self, sort_fields: List[str]):
535534
def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes:
536535
field_info: Union[FieldInfo, PydanticFieldInfo] = field
537536

537+
typ = get_outer_type(field_info)
538+
538539
if getattr(field_info, "primary_key", None) is True:
539540
return RediSearchFieldTypes.TAG
541+
elif typ in [CoordinateType, Coordinates]:
542+
return RediSearchFieldTypes.GEO
540543
elif op is Operators.LIKE:
541544
fts = getattr(field_info, "full_text_search", None)
542545
if fts is not True: # Could be PydanticUndefined
@@ -552,7 +555,6 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
552555
if not isinstance(field_type, type):
553556
field_type = field_type.__origin__
554557

555-
# TODO: GEO fields
556558
container_type = get_origin(field_type)
557559

558560
if is_supported_container_type(container_type):
@@ -726,6 +728,15 @@ def resolve_value(
726728
field_name=field_name, expanded_value=expanded_value
727729
)
728730

731+
elif field_type is RediSearchFieldTypes.GEO:
732+
if not isinstance(value, GeoFilter):
733+
raise QuerySyntaxError(
734+
"You can only use a GeoFilter object with a GEO field."
735+
)
736+
737+
if op is Operators.EQ:
738+
result += f"@{field_name}:[{value}]"
739+
729740
return result
730741

731742
def resolve_redisearch_pagination(self):
@@ -1804,6 +1815,8 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
18041815
schema = cls.schema_for_type(name, embedded_cls, field_info)
18051816
elif typ is bool:
18061817
schema = f"{name} TAG"
1818+
elif typ in [CoordinateType, Coordinates]:
1819+
schema = f"{name} GEO"
18071820
elif is_numeric_type(typ):
18081821
vector_options: Optional[VectorFieldOptions] = getattr(
18091822
field_info, "vector_options", None
@@ -2107,7 +2120,6 @@ def schema_for_type(
21072120
else typ
21082121
)
21092122

2110-
# TODO: GEO field
21112123
if is_vector and vector_options:
21122124
schema = f"{path} AS {index_field_name} {vector_options.schema}"
21132125
elif parent_is_container_type or parent_is_model_in_container:
@@ -2128,6 +2140,8 @@ def schema_for_type(
21282140
schema += " CASESENSITIVE"
21292141
elif typ is bool:
21302142
schema = f"{path} AS {index_field_name} TAG"
2143+
elif typ in [CoordinateType, Coordinates]:
2144+
schema = f"{path} AS {index_field_name} GEO"
21312145
elif is_numeric_type(typ):
21322146
schema = f"{path} AS {index_field_name} NUMERIC"
21332147
elif issubclass(typ, str):

aredis_om/model/types.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from typing import Annotated, Any, Literal, Tuple, Union
2+
3+
from pydantic import BeforeValidator, PlainSerializer
4+
from pydantic_extra_types.coordinate import Coordinate
5+
6+
7+
RadiusUnit = Literal["m", "km", "mi", "ft"]
8+
9+
10+
class GeoFilter:
11+
"""
12+
A geographic filter for searching within a radius of a coordinate point.
13+
14+
This filter is used with GEO fields to find models within a specified
15+
distance from a given location.
16+
17+
Args:
18+
longitude: The longitude of the center point (-180 to 180)
19+
latitude: The latitude of the center point (-90 to 90)
20+
radius: The search radius (must be positive)
21+
unit: The unit of measurement ('m', 'km', 'mi', or 'ft')
22+
23+
Example:
24+
>>> # Find all locations within 10 miles of Portland, OR
25+
>>> filter = GeoFilter(
26+
... longitude=-122.6765,
27+
... latitude=45.5231,
28+
... radius=10,
29+
... unit="mi"
30+
... )
31+
>>> results = await Location.find(
32+
... Location.coordinates == filter
33+
... ).all()
34+
"""
35+
36+
def __init__(
37+
self, longitude: float, latitude: float, radius: float, unit: RadiusUnit
38+
):
39+
# Validate coordinates
40+
if not -180 <= longitude <= 180:
41+
raise ValueError(f"Longitude must be between -180 and 180, got {longitude}")
42+
if not -90 <= latitude <= 90:
43+
raise ValueError(f"Latitude must be between -90 and 90, got {latitude}")
44+
if radius <= 0:
45+
raise ValueError(f"Radius must be positive, got {radius}")
46+
47+
self.longitude = longitude
48+
self.latitude = latitude
49+
self.radius = radius
50+
self.unit = unit
51+
52+
def __str__(self) -> str:
53+
return f"{self.longitude} {self.latitude} {self.radius} {self.unit}"
54+
55+
@classmethod
56+
def from_coordinates(
57+
cls, coords: Coordinate, radius: float, unit: RadiusUnit
58+
) -> "GeoFilter":
59+
"""
60+
Create a GeoFilter from a Coordinates object.
61+
62+
Args:
63+
coords: A Coordinate object with latitude and longitude
64+
radius: The search radius
65+
unit: The unit of measurement
66+
67+
Returns:
68+
A new GeoFilter instance
69+
"""
70+
return cls(coords.longitude, coords.latitude, radius, unit)
71+
72+
73+
CoordinateType = Coordinate
74+
75+
76+
def parse_redis(v: Any) -> Union[Tuple[str, str], Any]:
77+
"""
78+
Transform Redis coordinate format to Pydantic coordinate format.
79+
80+
The pydantic coordinate type expects a string in the format 'latitude,longitude'.
81+
Redis stores coordinates in the format 'longitude,latitude'.
82+
83+
This validator transforms the input from Redis into the expected format for pydantic.
84+
85+
Args:
86+
v: The value from Redis (typically a string like "longitude,latitude")
87+
88+
Returns:
89+
A tuple of (latitude, longitude) strings if input is a coordinate string,
90+
otherwise returns the input unchanged.
91+
92+
Raises:
93+
ValueError: If the coordinate string format is invalid
94+
"""
95+
if isinstance(v, str):
96+
parts = v.split(",")
97+
98+
if len(parts) != 2:
99+
raise ValueError(
100+
f"Invalid coordinate format. Expected 'longitude,latitude' but got: {v}"
101+
)
102+
103+
return (parts[1], parts[0]) # Swap to (latitude, longitude)
104+
105+
return v
106+
107+
108+
Coordinates = Annotated[
109+
CoordinateType,
110+
PlainSerializer(
111+
lambda v: f"{v.longitude},{v.latitude}",
112+
return_type=str,
113+
when_used="unless-none",
114+
),
115+
BeforeValidator(parse_redis),
116+
]

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "redis-om"
3-
version = "1.0.2-beta"
3+
version = "1.0.3-beta"
44
description = "Object mappings, and more, for Redis."
55
authors = ["Redis OSS <[email protected]>"]
66
maintainers = ["Redis OSS <[email protected]>"]
@@ -46,6 +46,7 @@ typing-extensions = "^4.4.0"
4646
hiredis = ">=2.2.3,<4.0.0"
4747
more-itertools = ">=8.14,<11.0"
4848
setuptools = ">=70.0"
49+
pydantic-extra-types = "^2.10.5"
4950

5051
[tool.poetry.group.dev.dependencies]
5152
mypy = "^1.9.0"

tests/test_hash_model.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
import pytest_asyncio
1414

1515
from aredis_om import (
16+
Coordinates,
1617
Field,
18+
GeoFilter,
1719
HashModel,
1820
Migrator,
1921
NotFoundError,
@@ -1054,3 +1056,113 @@ class Meta:
10541056

10551057
rematerialized = await Model.find(Model.first_name == "Steve").first()
10561058
assert rematerialized.pk == model.pk
1059+
1060+
1061+
@py_test_mark_asyncio
1062+
async def test_can_search_on_coordinates(key_prefix, redis):
1063+
class Location(HashModel, index=True):
1064+
coordinates: Coordinates = Field(index=True)
1065+
1066+
class Meta:
1067+
global_key_prefix = key_prefix
1068+
database = redis
1069+
1070+
await Migrator().run()
1071+
1072+
latitude = 45.5231
1073+
longitude = -122.6765
1074+
1075+
loc = Location(coordinates=(latitude, longitude))
1076+
1077+
await loc.save()
1078+
1079+
rematerialized: Location = await Location.find(
1080+
Location.coordinates
1081+
== GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi")
1082+
).first()
1083+
1084+
assert rematerialized.pk == loc.pk
1085+
assert rematerialized.coordinates.latitude == latitude
1086+
assert rematerialized.coordinates.longitude == longitude
1087+
1088+
1089+
@py_test_mark_asyncio
1090+
async def test_does_not_return_coordinates_if_outside_radius(key_prefix, redis):
1091+
class Location(HashModel, index=True):
1092+
coordinates: Coordinates = Field(index=True)
1093+
1094+
class Meta:
1095+
global_key_prefix = key_prefix
1096+
database = redis
1097+
1098+
await Migrator().run()
1099+
1100+
latitude = 45.5231
1101+
longitude = -122.6765
1102+
1103+
loc = Location(coordinates=(latitude, longitude))
1104+
1105+
await loc.save()
1106+
1107+
with pytest.raises(NotFoundError):
1108+
await Location.find(
1109+
Location.coordinates
1110+
== GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi")
1111+
).first()
1112+
1113+
1114+
@py_test_mark_asyncio
1115+
async def test_does_not_return_coordinates_if_location_is_none(key_prefix, redis):
1116+
class Location(HashModel, index=True):
1117+
coordinates: Optional[Coordinates] = Field(index=True)
1118+
1119+
class Meta:
1120+
global_key_prefix = key_prefix
1121+
database = redis
1122+
1123+
await Migrator().run()
1124+
1125+
loc = Location(coordinates=None)
1126+
1127+
await loc.save()
1128+
1129+
with pytest.raises(NotFoundError):
1130+
await Location.find(
1131+
Location.coordinates
1132+
== GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi")
1133+
).first()
1134+
1135+
1136+
@py_test_mark_asyncio
1137+
async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis):
1138+
class Location(HashModel, index=True):
1139+
coordinates: Coordinates = Field(index=True)
1140+
name: str = Field(index=True)
1141+
1142+
class Meta:
1143+
global_key_prefix = key_prefix
1144+
database = redis
1145+
1146+
await Migrator().run()
1147+
1148+
latitude = 45.5231
1149+
longitude = -122.6765
1150+
1151+
loc1 = Location(coordinates=(latitude, longitude), name="Portland")
1152+
# Offset by 0.01 degrees (~1.1 km at this latitude) to create a nearby location
1153+
# This ensures "Nearby" is within the 10 mile search radius but not at the exact same location
1154+
loc2 = Location(coordinates=(latitude + 0.01, longitude + 0.01), name="Nearby")
1155+
1156+
await loc1.save()
1157+
await loc2.save()
1158+
1159+
rematerialized: List[Location] = await Location.find(
1160+
(
1161+
Location.coordinates
1162+
== GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi")
1163+
)
1164+
& (Location.name == "Portland")
1165+
).all()
1166+
1167+
assert len(rematerialized) == 1
1168+
assert rematerialized[0].pk == loc1.pk

0 commit comments

Comments
 (0)