Skip to content

Commit 721d91e

Browse files
committed
Update KNN results include score
1 parent 298f2e1 commit 721d91e

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

aredis_om/model/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,6 +1510,7 @@ def to_string(s):
15101510
# $ means a json entry
15111511
if fields.get("$"):
15121512
json_fields = json.loads(fields.pop("$"))
1513+
json_fields.update(fields)
15131514
doc = cls(**json_fields)
15141515
else:
15151516
doc = cls(**fields)

tests/test_knn_expression.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
# type: ignore
22
import abc
33
import time
4-
import random
4+
import struct
55

66
import pytest_asyncio
77

88
from aredis_om import Field, JsonModel, KNNExpression, Migrator, VectorFieldOptions
99

1010
from .conftest import py_test_mark_asyncio
1111

12+
DIMENSIONS = 768
13+
1214

1315
vector_field_options = VectorFieldOptions.flat(
1416
type=VectorFieldOptions.TYPE.FLOAT32,
15-
dimension=768,
17+
dimension=DIMENSIONS,
1618
distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE,
1719
)
1820

@@ -26,25 +28,25 @@ class Meta:
2628

2729
class Member(BaseJsonModel, index=True):
2830
name: str
29-
embeddings: list[list[float]] = Field([], vector_options=vector_field_options)
31+
embeddings: list[list[float]] | bytes = Field(
32+
[], vector_options=vector_field_options
33+
)
3034
embeddings_score: float | None = None
3135

3236
await Migrator().run()
3337

3438
return Member
3539

3640

37-
@pytest_asyncio.fixture
38-
async def embedding_bytes():
39-
return b"\x00" * 3072
41+
def to_bytes(vectors: list[float]) -> bytes:
42+
return struct.pack(f"<{len(vectors)}f", *vectors)
4043

4144

4245
@py_test_mark_asyncio
43-
async def test_vector_field(m: type[JsonModel], embedding_bytes):
46+
async def test_vector_field(m: type[JsonModel]):
4447
# Create a new instance of the Member model
45-
dimensions = m.embeddings.field.vector_options.dimension
46-
embeddings = [random.uniform(-1, 1) for _ in range(dimensions)]
47-
member = m(name="seth", embeddings=[embeddings])
48+
vectors = [0.3 for _ in range(DIMENSIONS)]
49+
member = m(name="seth", embeddings=[vectors])
4850

4951
# Save the member to Redis
5052
mt = await member.save()
@@ -57,7 +59,7 @@ async def test_vector_field(m: type[JsonModel], embedding_bytes):
5759
k=1,
5860
vector_field=m.embeddings,
5961
score_field=m.embeddings_score,
60-
reference_vector=embedding_bytes,
62+
reference_vector=to_bytes(vectors),
6163
)
6264

6365
query = m.find(knn=knn)

0 commit comments

Comments
 (0)