Skip to content

Commit e0f3b12

Browse files
committed
fix knn setup
1 parent 3676634 commit e0f3b12

File tree

8 files changed

+84
-55
lines changed

8 files changed

+84
-55
lines changed

aredis_om/__init__.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,3 @@
1616
RedisModelError,
1717
VectorFieldOptions,
1818
)
19-
20-
21-
__all__ = [
22-
"redis",
23-
"get_redis_connection",
24-
"Field",
25-
"HashModel",
26-
"JsonModel",
27-
"EmbeddedJsonModel",
28-
"RedisModel",
29-
"FindQuery",
30-
"KNNExpression",
31-
"VectorFieldOptions",
32-
"has_redis_json",
33-
"has_redisearch",
34-
"MigrationError",
35-
"Migrator",
36-
"RedisModelError",
37-
"NotFoundError",
38-
"QueryNotSupportedError",
39-
"QuerySyntaxError",
40-
]

aredis_om/async_redis.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
11
from redis import asyncio as redis
2-
3-
4-
__all__ = ["redis"]

aredis_om/model/encoders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from pydantic.deprecated.json import ENCODERS_BY_TYPE
3636
from pydantic_core import PydanticUndefined
3737

38+
3839
SetIntStr = Set[Union[int, str]]
3940
DictIntStrAny = Dict[Union[int, str], Any]
4041

aredis_om/model/model.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -279,20 +279,24 @@ def tree(self):
279279
@dataclasses.dataclass
280280
class KNNExpression:
281281
k: int
282-
vector_field_name: str
283-
score_field_name: str
282+
vector_field: "ExpressionProxy"
283+
score_field: "ExpressionProxy"
284284
reference_vector: bytes
285285

286286
def __str__(self):
287-
return f"KNN $K @{self.vector_field_name} $knn_ref_vector"
287+
return f"KNN $K @{self.vector_field_name} $knn_ref_vector AS {self.score_field_name}"
288288

289289
@property
290290
def query_params(self) -> Dict[str, Union[str, bytes]]:
291291
return {"K": str(self.k), "knn_ref_vector": self.reference_vector}
292292

293293
@property
294-
def score_field(self) -> str:
295-
return self.score_field_name or f"_{self.vector_field_name}_score"
294+
def score_field_name(self) -> str:
295+
return self.score_field.field.alias
296+
297+
@property
298+
def vector_field_name(self) -> str:
299+
return self.vector_field.field.alias
296300

297301

298302
ExpressionOrNegated = Union[Expression, NegatedExpression]
@@ -438,7 +442,7 @@ def __init__(
438442
if sort_fields:
439443
self.sort_fields = self.validate_sort_fields(sort_fields)
440444
elif self.knn:
441-
self.sort_fields = [self.knn.score_field]
445+
self.sort_fields = [self.knn.score_field_name]
442446
else:
443447
self.sort_fields = []
444448

@@ -511,7 +515,7 @@ def query_params(self):
511515
def validate_sort_fields(self, sort_fields: List[str]):
512516
for sort_field in sort_fields:
513517
field_name = sort_field.lstrip("-")
514-
if self.knn and field_name == self.knn.score_field:
518+
if self.knn and field_name == self.knn.score_field_name:
515519
continue
516520
if field_name not in self.model.model_fields: # type: ignore
517521
raise QueryNotSupportedError(
@@ -1307,12 +1311,6 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
13071311
new_class._meta.primary_key = PrimaryKey(
13081312
name=field_name, field=field
13091313
)
1310-
if getattr(field, "vector_options", None) is not None:
1311-
score_attr = f"_{field_name}_score"
1312-
setattr(new_class, score_attr, None)
1313-
new_class.__annotations__[score_attr] = Union[float, None]
1314-
1315-
new_class.model_config["from_attributes"] = True
13161314

13171315
if not getattr(new_class._meta, "global_key_prefix", None):
13181316
new_class._meta.global_key_prefix = getattr(
@@ -1371,15 +1369,10 @@ def outer_type_or_annotation(field: FieldInfo):
13711369

13721370

13731371
class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
1374-
pk: Optional[str] = Field(default=None, primary_key=True)
1372+
pk: Optional[str] = Field(default=None, primary_key=True, validate_default=True)
13751373
Meta = DefaultMeta
13761374

1377-
model_config = ConfigDict(
1378-
from_attributes=True,
1379-
arbitrary_types_allowed=True,
1380-
extra="allow",
1381-
validate_default=True,
1382-
)
1375+
model_config = ConfigDict(from_attributes=True)
13831376

13841377
def __init__(__pydantic_self__, **data: Any) -> None:
13851378
__pydantic_self__.validate_primary_key()
@@ -1518,9 +1511,6 @@ def to_string(s):
15181511
if fields.get("$"):
15191512
json_fields = json.loads(fields.pop("$"))
15201513
doc = cls(**json_fields)
1521-
for k, v in fields.items():
1522-
if k.startswith("__") and k.endswith("_score"):
1523-
setattr(doc, k[1:], float(v))
15241514
else:
15251515
doc = cls(**fields)
15261516

aredis_om/sync_redis.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
11
import redis
2-
3-
4-
__all__ = ["redis"]

tests/test_json_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import pytest
1313
import pytest_asyncio
14-
from pydantic import field_validator
1514

1615
from aredis_om import (
1716
EmbeddedJsonModel,
@@ -1269,15 +1268,18 @@ class Child(Model):
12691268
assert child_validate.bio is None
12701269
assert child_validate.other_name == "Maria"
12711270

1271+
12721272
@py_test_mark_asyncio
12731273
async def test_model_raises_error_if_inherited_from_indexed_model():
12741274
class Model(JsonModel, index=True):
1275-
pass
1275+
pass
12761276

12771277
with pytest.raises(RedisModelError):
1278+
12781279
class Child(Model):
12791280
pass
12801281

1282+
12811283
@py_test_mark_asyncio
12821284
async def test_non_indexed_model_raises_error_on_save():
12831285
class Model(JsonModel):

tests/test_knn_expression.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# type: ignore
2+
import abc
3+
import time
4+
5+
import pytest_asyncio
6+
7+
from aredis_om import Field, JsonModel, KNNExpression, Migrator, VectorFieldOptions
8+
9+
from .conftest import py_test_mark_asyncio
10+
11+
12+
vector_field_options = VectorFieldOptions.flat(
13+
type=VectorFieldOptions.TYPE.FLOAT32,
14+
dimension=768,
15+
distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE,
16+
)
17+
18+
19+
@pytest_asyncio.fixture
20+
async def m(key_prefix, redis):
21+
class BaseJsonModel(JsonModel, abc.ABC):
22+
class Meta:
23+
global_key_prefix = key_prefix
24+
database = redis
25+
26+
class Member(BaseJsonModel, index=True):
27+
name: str
28+
embeddings: list[list[float]] = Field([], vector_options=vector_field_options)
29+
embeddings_score: float | None = None
30+
31+
await Migrator().run()
32+
33+
return Member
34+
35+
36+
@pytest_asyncio.fixture
37+
async def embedding_bytes():
38+
return b"\x00" * 3072
39+
40+
41+
@py_test_mark_asyncio
42+
async def test_vector_field(m: type[JsonModel], embedding_bytes):
43+
# Create a new instance of the Member model
44+
member = m(name="seth", embeddings=[[0.1, 0.2, 0.3]])
45+
46+
# Save the member to Redis
47+
mt = await member.save()
48+
49+
assert m.get(mt.pk)
50+
51+
time.sleep(1)
52+
53+
knn = KNNExpression(
54+
k=1,
55+
vector_field=m.embeddings,
56+
score_field=m.embeddings_score,
57+
reference_vector=embedding_bytes,
58+
)
59+
60+
query = m.find()
61+
62+
members = await query.all()
63+
64+
assert len(members) == 1
65+
assert members[0].embeddings_score is not None

tests/test_pydantic_integrations.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def test_email_str(m):
5151
)
5252

5353

54-
5554
def test_validator_sets_value_on_init():
5655
value = "bar"
5756

@@ -62,6 +61,6 @@ class ModelWithValidator(HashModel):
6261
def set_field(cls, v):
6362
return value
6463

65-
m = ModelWithValidator()
64+
m = ModelWithValidator(field="foo")
6665

6766
assert m.field == value

0 commit comments

Comments
 (0)