Skip to content

Commit 59f1422

Browse files
committed
feature #197 [Store][Postgres] allow store initialization with utilized distance (DZunke)
This PR was merged into the main branch. Discussion ---------- [Store][Postgres] allow store initialization with utilized distance | Q | A | ------------- | --- | Bug fix? | no | New feature? | yes | Docs? | no | Issues | #195 | License | MIT According to the [pgvector](https://github.com/pgvector/pgvector) documentation, multiple distance calculations are supported. The current implementation in the store only uses the L2 distance, via the `<->` operator. Allowing the other distance calculation variants to be used would be helpful here, as most of the discussion seems to revolve around the cosine algorithm. Commits ------- a8b4fba [Store][Postgres] allow store initialization with utilized distance
2 parents edf2f33 + a8b4fba commit 59f1422

File tree

3 files changed

+146
-13
lines changed

3 files changed

+146
-13
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Store\Bridge\Postgres;
13+
14+
use OskarStark\Enum\Trait\Comparable;
15+
16+
/**
17+
* @author Denis Zunke <[email protected]>
18+
*/
19+
// Distance metrics a Postgres store can be initialised with. Each case is
// backed by a string identifier and maps to exactly one SQL comparison
// operator via getComparisonSign().
enum Distance: string
20+
{
21+
// Trait imported from OskarStark\Enum\Trait; its behaviour is not visible here.
use Comparable;
22+
23+
case Cosine = 'cosine';
24+
case InnerProduct = 'inner_product';
25+
case L1 = 'l1';
26+
case L2 = 'l2';
27+
28+
/**
 * Returns the operator string that the store's query places between the
 * vector column and the bound `:embedding` parameter to compute the score.
 *
 * The match lists every case explicitly and has no default arm, so a case
 * added later without a matching arm here results in an
 * \UnhandledMatchError when this method is called on it.
 *
 * NOTE(review): pgvector is understood to return the *negative* inner
 * product for `<#>` — confirm before interpreting InnerProduct scores.
 *
 * @return string the SQL operator, e.g. `<->` for L2
 */
public function getComparisonSign(): string
29+
{
30+
return match ($this) {
31+
self::Cosine => '<=>',
32+
self::InnerProduct => '<#>',
33+
self::L1 => '<+>',
34+
self::L2 => '<->',
35+
};
36+
}
37+
}

src/store/src/Bridge/Postgres/Store.php

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,32 @@ public function __construct(
3434
private \PDO $connection,
3535
private string $tableName,
3636
private string $vectorFieldName = 'embedding',
37+
private Distance $distance = Distance::L2,
3738
) {
3839
}
3940

40-
public static function fromPdo(\PDO $connection, string $tableName, string $vectorFieldName = 'embedding'): self
41-
{
42-
return new self($connection, $tableName, $vectorFieldName);
41+
public static function fromPdo(
42+
\PDO $connection,
43+
string $tableName,
44+
string $vectorFieldName = 'embedding',
45+
Distance $distance = Distance::L2,
46+
): self {
47+
return new self($connection, $tableName, $vectorFieldName, $distance);
4348
}
4449

45-
public static function fromDbal(Connection $connection, string $tableName, string $vectorFieldName = 'embedding'): self
46-
{
50+
public static function fromDbal(
51+
Connection $connection,
52+
string $tableName,
53+
string $vectorFieldName = 'embedding',
54+
Distance $distance = Distance::L2,
55+
): self {
4756
$pdo = $connection->getNativeConnection();
4857

4958
if (!$pdo instanceof \PDO) {
5059
throw new InvalidArgumentException('Only DBAL connections using PDO driver are supported.');
5160
}
5261

53-
return self::fromPdo($pdo, $tableName, $vectorFieldName);
62+
return self::fromPdo($pdo, $tableName, $vectorFieldName, $distance);
5463
}
5564

5665
public function add(VectorDocument ...$documents): void
@@ -84,16 +93,18 @@ public function add(VectorDocument ...$documents): void
8493
*/
8594
public function query(Vector $vector, array $options = [], ?float $minScore = null): array
8695
{
87-
$sql = \sprintf(
88-
'SELECT id, %s AS embedding, metadata, (%s <-> :embedding) AS score
89-
FROM %s
90-
%s
91-
ORDER BY score ASC
92-
LIMIT %d',
96+
$sql = \sprintf(<<<SQL
97+
SELECT id, %s AS embedding, metadata, (%s %s :embedding) AS score
98+
FROM %s
99+
%s
100+
ORDER BY score ASC
101+
LIMIT %d
102+
SQL,
93103
$this->vectorFieldName,
94104
$this->vectorFieldName,
105+
$this->distance->getComparisonSign(),
95106
$this->tableName,
96-
null !== $minScore ? "WHERE ({$this->vectorFieldName} <-> :embedding) >= :minScore" : '',
107+
null !== $minScore ? "WHERE ({$this->vectorFieldName} {$this->distance->getComparisonSign()} :embedding) >= :minScore" : '',
97108
$options['limit'] ?? 5,
98109
);
99110
$statement = $this->connection->prepare($sql);

src/store/tests/Bridge/Postgres/StoreTest.php

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use PHPUnit\Framework\Attributes\CoversClass;
1616
use PHPUnit\Framework\TestCase;
1717
use Symfony\AI\Platform\Vector\Vector;
18+
use Symfony\AI\Store\Bridge\Postgres\Distance;
1819
use Symfony\AI\Store\Bridge\Postgres\Store;
1920
use Symfony\AI\Store\Document\Metadata;
2021
use Symfony\AI\Store\Document\VectorDocument;
@@ -152,6 +153,53 @@ public function testQueryWithoutMinScore()
152153
$this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
153154
}
154155

156+
/**
 * A store constructed with Distance::Cosine must prepare a query that uses
 * the `<=>` operator, must not add a WHERE clause when no minimum score is
 * given, and must turn the fetched row into a VectorDocument.
 */
public function testQueryChangedDistanceMethodWithoutMinScore()
157+
{
158+
$pdo = $this->createMock(\PDO::class);
159+
$statement = $this->createMock(\PDOStatement::class);
160+
161+
// The fourth constructor argument selects the distance; it defaults to Distance::L2.
$store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine);
162+
163+
// Cosine maps to `<=>`. No minScore is passed below, so the WHERE placeholder stays empty.
$expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score
164+
FROM embeddings_table
165+
166+
ORDER BY score ASC
167+
LIMIT 5';
168+
169+
// Both sides go through normalizeQuery() (defined elsewhere in this test class;
// presumably it evens out whitespace — confirm) before being compared.
$pdo->expects($this->once())
170+
->method('prepare')
171+
->with($this->callback(function ($sql) use ($expectedSql) {
172+
return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql);
173+
}))
174+
->willReturn($statement);
175+
176+
$uuid = Uuid::v4();
177+
178+
// Only the embedding is bound; there is no minScore parameter in this scenario.
$statement->expects($this->once())
179+
->method('execute')
180+
->with(['embedding' => '[0.1,0.2,0.3]']);
181+
182+
// Simulate a single matching row coming back from the database.
$statement->expects($this->once())
183+
->method('fetchAll')
184+
->with(\PDO::FETCH_ASSOC)
185+
->willReturn([
186+
[
187+
'id' => $uuid->toRfc4122(),
188+
'embedding' => '[0.1,0.2,0.3]',
189+
'metadata' => json_encode(['title' => 'Test Document']),
190+
'score' => 0.95,
191+
],
192+
]);
193+
194+
$results = $store->query(new Vector([0.1, 0.2, 0.3]));
195+
196+
// The row must be mapped to one VectorDocument carrying the id, score and metadata.
$this->assertCount(1, $results);
197+
$this->assertInstanceOf(VectorDocument::class, $results[0]);
198+
$this->assertEquals($uuid, $results[0]->id);
199+
$this->assertSame(0.95, $results[0]->score);
200+
$this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
201+
}
202+
155203
public function testQueryWithMinScore()
156204
{
157205
$pdo = $this->createMock(\PDO::class);
@@ -189,6 +237,43 @@ public function testQueryWithMinScore()
189237
$this->assertCount(0, $results);
190238
}
191239

240+
/**
 * A store constructed with Distance::Cosine must use the `<=>` operator in
 * both the score expression and the WHERE filter when a minimum score is
 * passed, and must bind that minimum score as `minScore`.
 */
public function testQueryWithMinScoreAndDifferentDistance()
241+
{
242+
$pdo = $this->createMock(\PDO::class);
243+
$statement = $this->createMock(\PDOStatement::class);
244+
245+
$store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine);
246+
247+
// The operator must match in the SELECT expression and the WHERE filter.
// NOTE(review): rows are kept when the operator result is >= :minScore while
// results are ordered ascending by that same value — confirm this is the
// intended meaning of "minimum score" for distance operators.
$expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score
248+
FROM embeddings_table
249+
WHERE (embedding <=> :embedding) >= :minScore
250+
ORDER BY score ASC
251+
LIMIT 5';
252+
253+
// Both sides go through normalizeQuery() (defined elsewhere in this test class;
// presumably it evens out whitespace — confirm) before being compared.
$pdo->expects($this->once())
254+
->method('prepare')
255+
->with($this->callback(function ($sql) use ($expectedSql) {
256+
return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql);
257+
}))
258+
->willReturn($statement);
259+
260+
// The third argument of query() below must arrive as the bound `minScore` value.
$statement->expects($this->once())
261+
->method('execute')
262+
->with([
263+
'embedding' => '[0.1,0.2,0.3]',
264+
'minScore' => 0.8,
265+
]);
266+
267+
// No rows are returned, so the result list is expected to be empty.
$statement->expects($this->once())
268+
->method('fetchAll')
269+
->with(\PDO::FETCH_ASSOC)
270+
->willReturn([]);
271+
272+
$results = $store->query(new Vector([0.1, 0.2, 0.3]), [], 0.8);
273+
274+
$this->assertCount(0, $results);
275+
}
276+
192277
public function testQueryWithCustomLimit()
193278
{
194279
$pdo = $this->createMock(\PDO::class);

0 commit comments

Comments
 (0)