Skip to content

Commit 72dd2d0

Browse files
authored
Merge branch 'main' into dependabot/github_actions/codecov/codecov-action-5
2 parents a86b3c2 + b00c9e0 commit 72dd2d0

File tree

10 files changed

+464
-30
lines changed

10 files changed

+464
-30
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676
strategy:
7777
matrix:
7878
os: [ ubuntu-latest ]
79-
pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ]
79+
pyver: [ "3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.9", "pypy-3.10" ]
8080
redisstack: [ "latest" ]
8181
fail-fast: false
8282
services:

.gitignore

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,9 @@ tests_sync/
143143
# spelling cruft
144144
*.dic
145145

146-
.idea
146+
.idea
147+
148+
# version files
149+
.tool-versions
150+
151+
.vscode/

aredis_om/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
RedisModelError,
1717
VectorFieldOptions,
1818
)
19+
from .model.types import Coordinates, GeoFilter

aredis_om/model/model.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import abc
22
import dataclasses
3-
import decimal
43
import json
54
import logging
65
import operator
@@ -41,10 +40,11 @@
4140
from .. import redis
4241
from ..checks import has_redis_json, has_redisearch
4342
from ..connections import get_redis_connection
44-
from ..util import ASYNC_MODE
43+
from ..util import ASYNC_MODE, has_numeric_inner_type, is_numeric_type
4544
from .encoders import jsonable_encoder
4645
from .render_tree import render_tree
4746
from .token_escaper import TokenEscaper
47+
from .types import Coordinates, CoordinateType, GeoFilter
4848

4949

5050
model_registry = {}
@@ -405,8 +405,6 @@ class RediSearchFieldTypes(Enum):
405405
GEO = "GEO"
406406

407407

408-
# TODO: How to handle Geo fields?
409-
NUMERIC_TYPES = (float, int, decimal.Decimal)
410408
DEFAULT_PAGE_SIZE = 1000
411409

412410

@@ -536,8 +534,12 @@ def validate_sort_fields(self, sort_fields: List[str]):
536534
def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes:
537535
field_info: Union[FieldInfo, PydanticFieldInfo] = field
538536

537+
typ = get_outer_type(field_info)
538+
539539
if getattr(field_info, "primary_key", None) is True:
540540
return RediSearchFieldTypes.TAG
541+
elif typ in [CoordinateType, Coordinates]:
542+
return RediSearchFieldTypes.GEO
541543
elif op is Operators.LIKE:
542544
fts = getattr(field_info, "full_text_search", None)
543545
if fts is not True: # Could be PydanticUndefined
@@ -553,7 +555,6 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
553555
if not isinstance(field_type, type):
554556
field_type = field_type.__origin__
555557

556-
# TODO: GEO fields
557558
container_type = get_origin(field_type)
558559

559560
if is_supported_container_type(container_type):
@@ -578,7 +579,7 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
578579
)
579580
elif field_type is bool:
580581
return RediSearchFieldTypes.TAG
581-
elif any(issubclass(field_type, t) for t in NUMERIC_TYPES):
582+
elif is_numeric_type(field_type):
582583
# Index numeric Python types as NUMERIC fields, so we can support
583584
# range queries.
584585
return RediSearchFieldTypes.NUMERIC
@@ -727,6 +728,15 @@ def resolve_value(
727728
field_name=field_name, expanded_value=expanded_value
728729
)
729730

731+
elif field_type is RediSearchFieldTypes.GEO:
732+
if not isinstance(value, GeoFilter):
733+
raise QuerySyntaxError(
734+
"You can only use a GeoFilter object with a GEO field."
735+
)
736+
737+
if op is Operators.EQ:
738+
result += f"@{field_name}:[{value}]"
739+
730740
return result
731741

732742
def resolve_redisearch_pagination(self):
@@ -1378,12 +1388,14 @@ def outer_type_or_annotation(field: FieldInfo):
13781388
def should_index_field(field_info: Union[FieldInfo, PydanticFieldInfo]) -> bool:
13791389
# for vector, full text search, and sortable fields, we always have to index
13801390
# We could require the user to set index=True, but that would be a breaking change
1381-
index = getattr(field_info, "index", None) is True
1391+
_index = getattr(field_info, "index", None)
1392+
1393+
index = _index is True
13821394
vector_options = getattr(field_info, "vector_options", None) is not None
13831395
full_text_search = getattr(field_info, "full_text_search", None) is True
13841396
sortable = getattr(field_info, "sortable", None) is True
13851397

1386-
if index is False and any([vector_options, full_text_search, sortable]):
1398+
if _index is False and any([vector_options, full_text_search, sortable]):
13871399
log.warning(
13881400
"Field is marked as index=False, but it is a vector, full text search, or sortable field. "
13891401
"This will be ignored and the field will be indexed.",
@@ -1803,7 +1815,9 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
18031815
schema = cls.schema_for_type(name, embedded_cls, field_info)
18041816
elif typ is bool:
18051817
schema = f"{name} TAG"
1806-
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
1818+
elif typ in [CoordinateType, Coordinates]:
1819+
schema = f"{name} GEO"
1820+
elif is_numeric_type(typ):
18071821
vector_options: Optional[VectorFieldOptions] = getattr(
18081822
field_info, "vector_options", None
18091823
)
@@ -1965,7 +1979,7 @@ def schema_for_type(
19651979
json_path: str,
19661980
name: str,
19671981
name_prefix: str,
1968-
typ: Union[type[RedisModel], Any],
1982+
typ: Union[Type[RedisModel], Any],
19691983
field_info: PydanticFieldInfo,
19701984
parent_type: Optional[Any] = None,
19711985
) -> str:
@@ -2002,9 +2016,7 @@ def schema_for_type(
20022016
field_info, "vector_options", None
20032017
)
20042018
try:
2005-
is_vector = vector_options and any(
2006-
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
2007-
)
2019+
is_vector = vector_options and has_numeric_inner_type(typ)
20082020
except IndexError:
20092021
raise RedisModelError(
20102022
f"Vector field '{name}' must be annotated as a container type"
@@ -2102,9 +2114,12 @@ def schema_for_type(
21022114
# a proper type, we can pull the type information from the origin of the first argument.
21032115
if not isinstance(typ, type):
21042116
type_args = typing_get_args(field_info.annotation)
2105-
typ = type_args[0].__origin__
2117+
typ = (
2118+
getattr(type_args[0], "__origin__", type_args[0])
2119+
if type_args
2120+
else typ
2121+
)
21062122

2107-
# TODO: GEO field
21082123
if is_vector and vector_options:
21092124
schema = f"{path} AS {index_field_name} {vector_options.schema}"
21102125
elif parent_is_container_type or parent_is_model_in_container:
@@ -2125,7 +2140,9 @@ def schema_for_type(
21252140
schema += " CASESENSITIVE"
21262141
elif typ is bool:
21272142
schema = f"{path} AS {index_field_name} TAG"
2128-
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
2143+
elif typ in [CoordinateType, Coordinates]:
2144+
schema = f"{path} AS {index_field_name} GEO"
2145+
elif is_numeric_type(typ):
21292146
schema = f"{path} AS {index_field_name} NUMERIC"
21302147
elif issubclass(typ, str):
21312148
if full_text_search is True:

aredis_om/model/types.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from typing import Annotated, Any, Literal, Tuple, Union
2+
3+
from pydantic import BeforeValidator, PlainSerializer
4+
from pydantic_extra_types.coordinate import Coordinate
5+
6+
7+
RadiusUnit = Literal["m", "km", "mi", "ft"]
8+
9+
10+
class GeoFilter:
11+
"""
12+
A geographic filter for searching within a radius of a coordinate point.
13+
14+
This filter is used with GEO fields to find models within a specified
15+
distance from a given location.
16+
17+
Args:
18+
longitude: The longitude of the center point (-180 to 180)
19+
latitude: The latitude of the center point (-90 to 90)
20+
radius: The search radius (must be positive)
21+
unit: The unit of measurement ('m', 'km', 'mi', or 'ft')
22+
23+
Example:
24+
>>> # Find all locations within 10 miles of Portland, OR
25+
>>> filter = GeoFilter(
26+
... longitude=-122.6765,
27+
... latitude=45.5231,
28+
... radius=10,
29+
... unit="mi"
30+
... )
31+
>>> results = await Location.find(
32+
... Location.coordinates == filter
33+
... ).all()
34+
"""
35+
36+
def __init__(
37+
self, longitude: float, latitude: float, radius: float, unit: RadiusUnit
38+
):
39+
# Validate coordinates
40+
if not -180 <= longitude <= 180:
41+
raise ValueError(f"Longitude must be between -180 and 180, got {longitude}")
42+
if not -90 <= latitude <= 90:
43+
raise ValueError(f"Latitude must be between -90 and 90, got {latitude}")
44+
if radius <= 0:
45+
raise ValueError(f"Radius must be positive, got {radius}")
46+
47+
self.longitude = longitude
48+
self.latitude = latitude
49+
self.radius = radius
50+
self.unit = unit
51+
52+
def __str__(self) -> str:
53+
return f"{self.longitude} {self.latitude} {self.radius} {self.unit}"
54+
55+
@classmethod
56+
def from_coordinates(
57+
cls, coords: Coordinate, radius: float, unit: RadiusUnit
58+
) -> "GeoFilter":
59+
"""
60+
Create a GeoFilter from a Coordinates object.
61+
62+
Args:
63+
coords: A Coordinate object with latitude and longitude
64+
radius: The search radius
65+
unit: The unit of measurement
66+
67+
Returns:
68+
A new GeoFilter instance
69+
"""
70+
return cls(coords.longitude, coords.latitude, radius, unit)
71+
72+
73+
CoordinateType = Coordinate
74+
75+
76+
def parse_redis(v: Any) -> Union[Tuple[str, str], Any]:
77+
"""
78+
Transform Redis coordinate format to Pydantic coordinate format.
79+
80+
The pydantic coordinate type expects a string in the format 'latitude,longitude'.
81+
Redis stores coordinates in the format 'longitude,latitude'.
82+
83+
This validator transforms the input from Redis into the expected format for pydantic.
84+
85+
Args:
86+
v: The value from Redis (typically a string like "longitude,latitude")
87+
88+
Returns:
89+
A tuple of (latitude, longitude) strings if input is a coordinate string,
90+
otherwise returns the input unchanged.
91+
92+
Raises:
93+
ValueError: If the coordinate string format is invalid
94+
"""
95+
if isinstance(v, str):
96+
parts = v.split(",")
97+
98+
if len(parts) != 2:
99+
raise ValueError(
100+
f"Invalid coordinate format. Expected 'longitude,latitude' but got: {v}"
101+
)
102+
103+
return (parts[1], parts[0]) # Swap to (latitude, longitude)
104+
105+
return v
106+
107+
108+
Coordinates = Annotated[
109+
CoordinateType,
110+
PlainSerializer(
111+
lambda v: f"{v.longitude},{v.latitude}",
112+
return_type=str,
113+
when_used="unless-none",
114+
),
115+
BeforeValidator(parse_redis),
116+
]

aredis_om/util.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import decimal
12
import inspect
3+
from typing import Any, Type, get_args
24

35

46
def is_async_mode() -> bool:
@@ -10,3 +12,27 @@ async def f() -> None:
1012

1113

1214
ASYNC_MODE = is_async_mode()
15+
16+
NUMERIC_TYPES = (float, int, decimal.Decimal)
17+
18+
19+
def is_numeric_type(type_: Type[Any]) -> bool:
20+
try:
21+
return issubclass(type_, NUMERIC_TYPES)
22+
except TypeError:
23+
return False
24+
25+
26+
def has_numeric_inner_type(type_: Type[Any]) -> bool:
27+
"""
28+
Check if the type has a numeric inner type.
29+
"""
30+
args = get_args(type_)
31+
32+
if not args:
33+
return False
34+
35+
try:
36+
return issubclass(args[0], NUMERIC_TYPES)
37+
except TypeError:
38+
return False

pyproject.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "redis-om"
3-
version = "1.0.0-beta"
3+
version = "1.0.3-beta"
44
description = "Object mappings, and more, for Redis."
55
authors = ["Redis OSS <[email protected]>"]
66
maintainers = ["Redis OSS <[email protected]>"]
@@ -22,6 +22,7 @@ classifiers = [
2222
'Programming Language :: Python :: 3.10',
2323
'Programming Language :: Python :: 3.11',
2424
'Programming Language :: Python :: 3.12',
25+
'Programming Language :: Python :: 3.13',
2526
'Programming Language :: Python',
2627
]
2728
include=[
@@ -36,7 +37,7 @@ include=[
3637

3738
[tool.poetry.dependencies]
3839
python = ">=3.8,<4.0"
39-
redis = ">=3.5.3,<6.0.0"
40+
redis = ">=3.5.3,<7.0.0"
4041
pydantic = ">=2.0.0,<3.0.0"
4142
click = "^8.0.1"
4243
types-redis = ">=3.5.9,<5.0.0"
@@ -45,8 +46,9 @@ typing-extensions = "^4.4.0"
4546
hiredis = ">=2.2.3,<4.0.0"
4647
more-itertools = ">=8.14,<11.0"
4748
setuptools = ">=70.0"
49+
pydantic-extra-types = "^2.10.5"
4850

49-
[tool.poetry.dev-dependencies]
51+
[tool.poetry.group.dev.dependencies]
5052
mypy = "^1.9.0"
5153
pytest = "^8.0.2"
5254
ipdb = "^0.13.9"

0 commit comments

Comments
 (0)