Skip to content

Commit 26c6e68

Browse files
authored
Fix InnerProductSimilarity queries (ydb-platform/ydb#19584) (#15)
1 parent bb13749 commit 26c6e68

File tree

1 file changed

+5 additions, -5 deletions (lines changed)

1 file changed

+5 additions, -5 deletions (lines changed)

tests/fake_embeddings.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@ class FakeEmbeddings(Embeddings):
99
def embed_documents(self, texts: List[str]) -> List[List[float]]:
1010
"""Return simple embeddings.
1111
Embeddings encode each text as its index."""
12-
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
12+
return [[float(0.0)] * (i-1) + [float(1.0)] +
13+
[float(0.0)] * (9-i) for i in range(len(texts))]
1314

1415
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
1516
return self.embed_documents(texts)
@@ -19,7 +20,7 @@ def embed_query(self, text: str) -> List[float]:
1920
Embeddings are identical to embed_documents(texts)[0].
2021
Distance to each text will be that text's index,
2122
as it was passed to embed_documents."""
22-
return [float(1.0)] * 9 + [float(0.0)]
23+
return [float(1.0)] + [float(0.0)] * 9
2324

2425
async def aembed_query(self, text: str) -> List[float]:
2526
return self.embed_query(text)
@@ -39,9 +40,8 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
3940
for text in texts:
4041
if text not in self.known_texts:
4142
self.known_texts.append(text)
42-
vector = [float(1.0)] * (self.dimensionality - 1) + [
43-
float(self.known_texts.index(text))
44-
]
43+
vector = [float(0.0)] * self.dimensionality
44+
vector[self.known_texts.index(text) % self.dimensionality] = float(1.0)
4545
out_vectors.append(vector)
4646
return out_vectors
4747

0 commit comments

Comments
 (0)