Skip to content

Commit 10a69f5

Browse files
RamyHakamchr-hertel
authored andcommitted
[Store][InMemory] Fix cosine similarity sorting
1 parent d508af5 commit 10a69f5

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

src/store/src/InMemoryStore.php

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
*/
2121
final class InMemoryStore implements VectorStoreInterface
2222
{
23-
public const COSINE_SIMILARITY = 'cosine';
23+
public const COSINE_DISTANCE = 'cosine';
2424
public const ANGULAR_DISTANCE = 'angular';
2525
public const EUCLIDEAN_DISTANCE = 'euclidean';
2626
public const MANHATTAN_DISTANCE = 'manhattan';
@@ -32,7 +32,7 @@ final class InMemoryStore implements VectorStoreInterface
3232
private array $documents = [];
3333

3434
public function __construct(
35-
private readonly string $similarity = self::COSINE_SIMILARITY,
35+
private readonly string $distance = self::COSINE_DISTANCE,
3636
) {
3737
}
3838

@@ -48,13 +48,13 @@ public function add(VectorDocument ...$documents): void
4848
*/
4949
public function query(Vector $vector, array $options = [], ?float $minScore = null): array
5050
{
51-
$strategy = match ($this->similarity) {
52-
self::COSINE_SIMILARITY => $this->cosineSimilarity(...),
51+
$strategy = match ($this->distance) {
52+
self::COSINE_DISTANCE => $this->cosineDistance(...),
5353
self::ANGULAR_DISTANCE => $this->angularDistance(...),
5454
self::EUCLIDEAN_DISTANCE => $this->euclideanDistance(...),
5555
self::MANHATTAN_DISTANCE => $this->manhattanDistance(...),
5656
self::CHEBYSHEV_DISTANCE => $this->chebyshevDistance(...),
57-
default => throw new InvalidArgumentException(\sprintf('Unsupported similarity strategy "%s"', $this->similarity)),
57+
default => throw new InvalidArgumentException(\sprintf('Unsupported distance metric "%s"', $this->distance)),
5858
};
5959

6060
$currentEmbeddings = array_map(
@@ -80,6 +80,11 @@ public function query(Vector $vector, array $options = [], ?float $minScore = nu
8080
);
8181
}
8282

83+
private function cosineDistance(VectorDocument $embedding, Vector $against): float
84+
{
85+
return 1 - $this->cosineSimilarity($embedding, $against);
86+
}
87+
8388
private function cosineSimilarity(VectorDocument $embedding, Vector $against): float
8489
{
8590
$currentEmbeddingVectors = $embedding->vector->getData();

src/store/tests/InMemoryStoreTest.php

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
final class InMemoryStoreTest extends TestCase
2424
{
2525
#[Test]
26-
public function storeCanSearchUsingCosineSimilarity(): void
26+
public function storeCanSearchUsingCosineDistance(): void
2727
{
2828
$store = new InMemoryStore();
2929
$store->add(
@@ -32,19 +32,44 @@ public function storeCanSearchUsingCosineSimilarity(): void
3232
new VectorDocument(Uuid::v4(), new Vector([0.3, 0.7, 0.1])),
3333
);
3434

35-
self::assertCount(3, $store->query(new Vector([0.0, 0.1, 0.6])));
35+
$result = $store->query(new Vector([0.0, 0.1, 0.6]));
36+
self::assertCount(3, $result);
37+
self::assertSame([0.1, 0.1, 0.5], $result[0]->vector->getData());
3638

3739
$store->add(
3840
new VectorDocument(Uuid::v4(), new Vector([0.1, 0.1, 0.5])),
3941
new VectorDocument(Uuid::v4(), new Vector([0.7, -0.3, 0.0])),
4042
new VectorDocument(Uuid::v4(), new Vector([0.3, 0.7, 0.1])),
4143
);
4244

43-
self::assertCount(6, $store->query(new Vector([0.0, 0.1, 0.6])));
45+
$result = $store->query(new Vector([0.0, 0.1, 0.6]));
46+
self::assertCount(6, $result);
47+
self::assertSame([0.1, 0.1, 0.5], $result[0]->vector->getData());
4448
}
4549

4650
#[Test]
47-
public function storeCanSearchUsingCosineSimilarityWithMaxItems(): void
51+
public function storeCanSearchUsingCosineDistanceAndReturnCorrectOrder(): void
52+
{
53+
$store = new InMemoryStore();
54+
$store->add(
55+
new VectorDocument(Uuid::v4(), new Vector([0.1, 0.1, 0.5])),
56+
new VectorDocument(Uuid::v4(), new Vector([0.7, -0.3, 0.0])),
57+
new VectorDocument(Uuid::v4(), new Vector([0.3, 0.7, 0.1])),
58+
new VectorDocument(Uuid::v4(), new Vector([0.3, 0.1, 0.6])),
59+
new VectorDocument(Uuid::v4(), new Vector([0.0, 0.1, 0.6])),
60+
);
61+
62+
$result = $store->query(new Vector([0.0, 0.1, 0.6]));
63+
self::assertCount(5, $result);
64+
self::assertSame([0.0, 0.1, 0.6], $result[0]->vector->getData());
65+
self::assertSame([0.1, 0.1, 0.5], $result[1]->vector->getData());
66+
self::assertSame([0.3, 0.1, 0.6], $result[2]->vector->getData());
67+
self::assertSame([0.3, 0.7, 0.1], $result[3]->vector->getData());
68+
self::assertSame([0.7, -0.3, 0.0], $result[4]->vector->getData());
69+
}
70+
71+
#[Test]
72+
public function storeCanSearchUsingCosineDistanceWithMaxItems(): void
4873
{
4974
$store = new InMemoryStore();
5075
$store->add(

0 commit comments

Comments
 (0)