Skip to content

Commit dd03c20

Browse files
committed
feature #234 [Platform] Add Ollama embeddings (Guikingone)
This PR was squashed before being merged into the main branch. Discussion ---------- [Platform] Add Ollama embeddings | Q | A | ------------- | --- | Bug fix? | no | New feature? | yes | Docs? | yes | Issues | None | License | MIT Hi 👋🏻 This PR aims to add the support for embeddings via `Ollama`, the platform is updated to handle the new embedding client. Commits ------- 005ef2f [Platform] Add Ollama embeddings
2 parents 655244e + 005ef2f commit dd03c20

File tree

8 files changed

+198
-46
lines changed

8 files changed

+198
-46
lines changed

examples/ollama/embeddings.php

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
use Symfony\AI\Platform\Bridge\Ollama\Ollama;
13+
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
14+
15+
require_once dirname(__DIR__).'/bootstrap.php';
16+
17+
$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client());
18+
19+
$response = $platform->invoke(new Ollama(Ollama::NOMIC_EMBED_TEXT), <<<TEXT
20+
Once upon a time, there was a country called Japan. It was a beautiful country with a lot of mountains and rivers.
21+
The people of Japan were very kind and hardworking. They loved their country very much and took care of it. The
22+
country was very peaceful and prosperous. The people lived happily ever after.
23+
TEXT);
24+
25+
echo 'Dimensions: '.$response->asVectors()[0]->getDimensions().\PHP_EOL;

src/platform/src/Bridge/Ollama/Ollama.php

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class Ollama extends Model
3737
public const QWEN = 'qwen';
3838
public const QWEN_2 = 'qwen2';
3939
public const LLAMA_2 = 'llama2';
40+
public const NOMIC_EMBED_TEXT = 'nomic-embed-text';
41+
public const BGE_M3 = 'bge-m3';
42+
public const ALL_MINILM = 'all-minilm';
4043

4144
private const TOOL_PATTERNS = [
4245
'/./' => [
@@ -52,6 +55,9 @@ class Ollama extends Model
5255
'/^(deepseek|mistral)/' => [
5356
Capability::TOOL_CALLING,
5457
],
58+
'/^(nomic|bge|all-minilm).*/' => [
59+
Capability::INPUT_MULTIPLE,
60+
],
5561
];
5662

5763
/**
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Platform\Bridge\Ollama;
13+
14+
use Symfony\AI\Platform\Exception\InvalidArgumentException;
15+
use Symfony\AI\Platform\Model;
16+
use Symfony\AI\Platform\ModelClientInterface;
17+
use Symfony\AI\Platform\Result\RawHttpResult;
18+
use Symfony\Contracts\HttpClient\HttpClientInterface;
19+
20+
/**
21+
* @author Christopher Hertel <[email protected]>
22+
*/
23+
final readonly class OllamaClient implements ModelClientInterface
24+
{
25+
public function __construct(
26+
private HttpClientInterface $httpClient,
27+
private string $hostUrl,
28+
) {
29+
}
30+
31+
public function supports(Model $model): bool
32+
{
33+
return $model instanceof Ollama;
34+
}
35+
36+
public function request(Model $model, array|string $payload, array $options = []): RawHttpResult
37+
{
38+
$response = $this->httpClient->request('POST', \sprintf('%s/api/show', $this->hostUrl), [
39+
'json' => [
40+
'model' => $model->getName(),
41+
],
42+
]);
43+
44+
$modelInformationsPayload = $response->toArray();
45+
46+
return match (true) {
47+
\in_array('completion', $modelInformationsPayload['capabilities'], true) => $this->doCompletionRequest($payload, $options),
48+
\in_array('embedding', $modelInformationsPayload['capabilities'], true) => $this->doEmbeddingsRequest($model, $payload, $options),
49+
default => throw new InvalidArgumentException(\sprintf('Unsupported model "%s".', $model::class)),
50+
};
51+
}
52+
53+
/**
54+
* @param array<string|int, mixed> $payload
55+
* @param array<string, mixed> $options
56+
*/
57+
private function doCompletionRequest(array|string $payload, array $options = []): RawHttpResult
58+
{
59+
// Revert Ollama's default streaming behavior
60+
$options['stream'] ??= false;
61+
62+
return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
63+
'headers' => ['Content-Type' => 'application/json'],
64+
'json' => array_merge($options, $payload),
65+
]));
66+
}
67+
68+
/**
69+
* @param array<string|int, mixed> $payload
70+
* @param array<string, mixed> $options
71+
*/
72+
public function doEmbeddingsRequest(Model $model, array|string $payload, array $options = []): RawHttpResult
73+
{
74+
return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/embed', $this->hostUrl), [
75+
'json' => array_merge($options, [
76+
'model' => $model->getName(),
77+
'input' => $payload,
78+
]),
79+
]));
80+
}
81+
}

src/platform/src/Bridge/Ollama/OllamaModelClient.php

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/platform/src/Bridge/Ollama/OllamaResultConverter.php

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
use Symfony\AI\Platform\Result\TextResult;
1919
use Symfony\AI\Platform\Result\ToolCall;
2020
use Symfony\AI\Platform\Result\ToolCallResult;
21+
use Symfony\AI\Platform\Result\VectorResult;
2122
use Symfony\AI\Platform\ResultConverterInterface;
23+
use Symfony\AI\Platform\Vector\Vector;
2224

2325
/**
2426
* @author Christopher Hertel <[email protected]>
@@ -34,6 +36,16 @@ public function convert(RawResultInterface $result, array $options = []): Result
3436
{
3537
$data = $result->getData();
3638

39+
return \array_key_exists('embeddings', $data)
40+
? $this->doConvertEmbeddings($data)
41+
: $this->doConvertCompletion($data);
42+
}
43+
44+
/**
45+
* @param array<string, mixed> $data
46+
*/
47+
public function doConvertCompletion(array $data): ResultInterface
48+
{
3749
if (!isset($data['message'])) {
3850
throw new RuntimeException('Response does not contain message.');
3951
}
@@ -54,4 +66,21 @@ public function convert(RawResultInterface $result, array $options = []): Result
5466

5567
return new TextResult($data['message']['content']);
5668
}
69+
70+
/**
71+
* @param array<string, mixed> $data
72+
*/
73+
public function doConvertEmbeddings(array $data): ResultInterface
74+
{
75+
if ([] === $data['embeddings']) {
76+
throw new RuntimeException('Response does not contain embeddings.');
77+
}
78+
79+
return new VectorResult(
80+
...array_map(
81+
static fn (array $embedding): Vector => new Vector($embedding),
82+
$data['embeddings'],
83+
),
84+
);
85+
}
5786
}

src/platform/src/Bridge/Ollama/PlatformFactory.php

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ public static function create(
2929
): Platform {
3030
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
3131

32-
return new Platform([new OllamaModelClient($httpClient, $hostUrl)], [new OllamaResultConverter()], $contract ?? OllamaContract::create());
32+
return new Platform(
33+
[new OllamaClient($httpClient, $hostUrl)],
34+
[new OllamaResultConverter()],
35+
$contract ?? OllamaContract::create()
36+
);
3337
}
3438
}

src/platform/tests/Bridge/Ollama/OllamaResultConverterTest.php

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,22 @@
2020
use Symfony\AI\Platform\Exception\RuntimeException;
2121
use Symfony\AI\Platform\Model;
2222
use Symfony\AI\Platform\Result\InMemoryRawResult;
23+
use Symfony\AI\Platform\Result\RawHttpResult;
2324
use Symfony\AI\Platform\Result\TextResult;
2425
use Symfony\AI\Platform\Result\ToolCall;
2526
use Symfony\AI\Platform\Result\ToolCallResult;
27+
use Symfony\AI\Platform\Result\VectorResult;
28+
use Symfony\AI\Platform\Vector\Vector;
29+
use Symfony\Contracts\HttpClient\ResponseInterface;
2630

2731
#[CoversClass(OllamaResultConverter::class)]
2832
#[Small]
2933
#[UsesClass(Ollama::class)]
3034
#[UsesClass(TextResult::class)]
3135
#[UsesClass(ToolCall::class)]
3236
#[UsesClass(ToolCallResult::class)]
37+
#[UsesClass(Vector::class)]
38+
#[UsesClass(VectorResult::class)]
3339
final class OllamaResultConverterTest extends TestCase
3440
{
3541
public function testSupportsLlamaModel()
@@ -143,4 +149,29 @@ public function testThrowsExceptionWhenNoContent()
143149

144150
$converter->convert($rawResult);
145151
}
152+
153+
public function testItConvertsAResponseToAVectorResult()
154+
{
155+
$result = $this->createStub(ResponseInterface::class);
156+
$result
157+
->method('toArray')
158+
->willReturn([
159+
'model' => 'all-minilm',
160+
'embeddings' => [
161+
[0.3, 0.4, 0.4],
162+
[0.0, 0.0, 0.2],
163+
],
164+
'total_duration' => 14143917,
165+
'load_duration' => 1019500,
166+
'prompt_eval_count' => 8,
167+
]);
168+
169+
$vectorResult = (new OllamaResultConverter())->convert(new RawHttpResult($result));
170+
$convertedContent = $vectorResult->getContent();
171+
172+
$this->assertCount(2, $convertedContent);
173+
174+
$this->assertSame([0.3, 0.4, 0.4], $convertedContent[0]->getData());
175+
$this->assertSame([0.0, 0.0, 0.2], $convertedContent[1]->getData());
176+
}
146177
}

src/platform/tests/Bridge/Ollama/OllamaTest.php

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ public function testModelsWithoutToolCallingCapability(string $modelName)
4444
);
4545
}
4646

47+
#[DataProvider('provideModelsWithMultipleInputCapabilities')]
48+
public function testModelsWithMultipleInputCapabilities(string $modelName)
49+
{
50+
$model = new Ollama($modelName);
51+
52+
$this->assertTrue(
53+
$model->supports(Capability::INPUT_MULTIPLE),
54+
\sprintf('Model "%s" should not support multiple input capabilities', $modelName)
55+
);
56+
}
57+
4758
/**
4859
* @return iterable<array{string}>
4960
*/
@@ -82,4 +93,14 @@ public static function provideModelsWithoutToolCallingCapability(): iterable
8293
yield 'llava' => [Ollama::LLAVA];
8394
yield 'qwen2.5vl' => [Ollama::QWEN_2_5_VL]; // This has 'vl' suffix which doesn't match the pattern
8495
}
96+
97+
/**
98+
* @return iterable<array{string}>
99+
*/
100+
public static function provideModelsWithMultipleInputCapabilities(): iterable
101+
{
102+
yield 'nomic-embed-text' => [Ollama::NOMIC_EMBED_TEXT];
103+
yield 'bge-m3' => [Ollama::BGE_M3];
104+
yield 'all-minilm' => [Ollama::ALL_MINILM];
105+
}
85106
}

0 commit comments

Comments
 (0)