Skip to content

Commit bf5a1fe

Browse files
committed
feat(platform): Ollama prompt cache
1 parent 3fc2ab1 commit bf5a1fe

File tree

8 files changed

+180
-10
lines changed

8 files changed

+180
-10
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
use Symfony\AI\Agent\Agent;
13+
use Symfony\AI\Platform\Bridge\Ollama\Ollama;
14+
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
15+
use Symfony\AI\Platform\Message\Message;
16+
use Symfony\AI\Platform\Message\MessageBag;
17+
use Symfony\Component\Cache\Adapter\ArrayAdapter;
18+
19+
require_once dirname(__DIR__).'/bootstrap.php';
20+
21+
$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client(), cache: new ArrayAdapter());
22+
$model = new Ollama();
23+
24+
$agent = new Agent($platform, $model, logger: logger());
25+
$messages = new MessageBag(
26+
Message::forSystem('You are a helpful assistant.'),
27+
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
28+
);
29+
$result = $agent->call($messages, [
30+
'prompt_cache_key' => 'chat',
31+
]);
32+
33+
echo $result->getContent().\PHP_EOL;
34+
35+
$secondResult = $agent->call($messages, [
36+
'prompt_cache_key' => 'chat',
37+
]);
38+
39+
// Print the answer of the second call (same prompt_cache_key), not the first result again.
echo $secondResult->getContent().\PHP_EOL;

src/ai-bundle/config/options.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
->arrayNode('ollama')
8080
->children()
8181
->scalarNode('host_url')->defaultValue('http://127.0.0.1:11434')->end()
82+
->scalarNode('cache')->end()
8283
->end()
8384
->end()
8485
->arrayNode('cerebras')

src/ai-bundle/src/AiBundle.php

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,17 +378,23 @@ private function processPlatformConfig(string $type, array $platform, ContainerB
378378
}
379379

380380
if ('ollama' === $type) {
381+
$arguments = [
382+
$platform['host_url'],
383+
new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE),
384+
new Reference('ai.platform.contract.ollama'),
385+
];
386+
387+
// isset() also skips an explicit null value, which would otherwise be passed to Reference as its (string) service id.
if (isset($platform['cache'])) {
388+
$arguments[] = new Reference($platform['cache'], ContainerInterface::NULL_ON_INVALID_REFERENCE);
389+
}
390+
381391
$platformId = 'ai.platform.ollama';
382392
$definition = (new Definition(Platform::class))
383393
->setFactory(MistralPlatformFactory::class.'::create')
384394
->setFactory(OllamaPlatformFactory::class.'::create')
385395
->setLazy(true)
386396
->addTag('proxy', ['interface' => PlatformInterface::class])
387-
->setArguments([
388-
$platform['host_url'],
389-
new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE),
390-
new Reference('ai.platform.contract.ollama'),
391-
])
397+
->setArguments($arguments)
392398
->addTag('ai.platform');
393399

394400
$container->setDefinition($platformId, $definition);

src/ai-bundle/tests/DependencyInjection/AiBundleTest.php

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,39 @@ public function testConfigurationWithUseAttributeAsKeyWorksWithoutNormalizeKeys(
303303
$this->assertTrue($container->hasDefinition('ai.store.mongodb.Production_DB-v3'));
304304
}
305305

306+
public function testPromptCachingCanBeUsedWithOllama()
307+
{
308+
$container = $this->buildContainer([
309+
'ai' => [
310+
'platform' => [
311+
'ollama' => [
312+
'host_url' => 'http://127.0.0.1:11434',
313+
'cache' => 'cache.app',
314+
],
315+
],
316+
],
317+
]);
318+
319+
$this->assertTrue($container->hasDefinition('ai.platform.ollama'));
320+
321+
$definition = $container->getDefinition('ai.platform.ollama');
322+
$this->assertCount(4, $definition->getArguments());
323+
324+
$this->assertSame('http://127.0.0.1:11434', $definition->getArgument(0));
325+
326+
$this->assertInstanceOf(Reference::class, $definition->getArgument(1));
327+
$httpClientArgument = $definition->getArgument(1);
328+
$this->assertSame('http_client', (string) $httpClientArgument);
329+
330+
$this->assertInstanceOf(Reference::class, $definition->getArgument(2));
331+
$contractArgument = $definition->getArgument(2);
332+
$this->assertSame('ai.platform.contract.ollama', (string) $contractArgument);
333+
334+
$this->assertInstanceOf(Reference::class, $definition->getArgument(3));
335+
$cacheArgument = $definition->getArgument(3);
336+
$this->assertSame('cache.app', (string) $cacheArgument);
337+
}
338+
306339
private function buildContainer(array $configuration): ContainerBuilder
307340
{
308341
$container = new ContainerBuilder();

src/platform/composer.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"phpstan/phpstan-symfony": "^2.0.6",
4343
"phpunit/phpunit": "^11.5",
4444
"symfony/ai-agent": "@dev",
45+
"symfony/cache": "^6.4 || ^7.1",
4546
"symfony/console": "^6.4 || ^7.1",
4647
"symfony/dotenv": "^6.4 || ^7.1",
4748
"symfony/event-dispatcher": "^6.4 || ^7.1",

src/platform/src/Bridge/Ollama/OllamaClient.php

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
use Symfony\AI\Platform\Model;
1616
use Symfony\AI\Platform\ModelClientInterface;
1717
use Symfony\AI\Platform\Result\RawHttpResult;
18+
use Symfony\Component\HttpClient\Response\MockResponse;
19+
use Symfony\Contracts\Cache\CacheInterface;
1820
use Symfony\Contracts\HttpClient\HttpClientInterface;
21+
use Symfony\Contracts\HttpClient\ResponseInterface;
1922

2023
/**
2124
* @author Christopher Hertel <[email protected]>
@@ -25,6 +28,7 @@
2528
public function __construct(
2629
private HttpClientInterface $httpClient,
2730
private string $hostUrl,
31+
private ?CacheInterface $cache = null,
2832
) {
2933
}
3034

@@ -68,10 +72,40 @@ private function doCompletionRequest(array|string $payload, array $options = [])
6872
unset($options['response_format']);
6973
}
7074

71-
return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
72-
'headers' => ['Content-Type' => 'application/json'],
73-
'json' => array_merge($options, $payload),
74-
]));
75+
$requestCallback = fn ($options, $payload): ResponseInterface => $this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
76+
'headers' => [
77+
'Content-Type' => 'application/json',
78+
],
79+
'json' => [
80+
...$options,
81+
...$payload,
82+
],
83+
]);
84+
85+
if ($this->cache instanceof CacheInterface && (\array_key_exists('prompt_cache_key', $options) && $options['prompt_cache_key'])) {
86+
// Hash the JSON form of the payload; a string payload is wrapped in an array first because md5() only accepts strings.
$cacheKey = \sprintf('%s_%s', $options['prompt_cache_key'], md5(json_encode(\is_array($payload) ? $payload : ['context' => $payload])));
87+
88+
unset($options['prompt_cache_key']);
89+
90+
$cachedResponse = $this->cache->get($cacheKey, static function () use ($requestCallback, $options, $payload): array {
91+
$response = $requestCallback($options, $payload);
92+
93+
return [
94+
'content' => $response->getContent(),
95+
'headers' => $response->getHeaders(),
96+
'http_code' => $response->getStatusCode(),
97+
];
98+
});
99+
100+
$mockedResponse = new MockResponse($cachedResponse['content'], [
101+
'http_code' => $cachedResponse['http_code'],
102+
'response_headers' => $cachedResponse['headers'],
103+
]);
104+
105+
return new RawHttpResult(MockResponse::fromRequest('POST', \sprintf('%s/api/chat', $this->hostUrl), $options, $mockedResponse));
106+
}
107+
108+
return new RawHttpResult($requestCallback($options, $payload));
75109
}
76110

77111
/**

src/platform/src/Bridge/Ollama/PlatformFactory.php

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use Symfony\AI\Platform\Contract;
1616
use Symfony\AI\Platform\Platform;
1717
use Symfony\Component\HttpClient\EventSourceHttpClient;
18+
use Symfony\Contracts\Cache\CacheInterface;
1819
use Symfony\Contracts\HttpClient\HttpClientInterface;
1920

2021
/**
@@ -26,11 +27,12 @@ public static function create(
2627
string $hostUrl = 'http://localhost:11434',
2728
?HttpClientInterface $httpClient = null,
2829
?Contract $contract = null,
30+
?CacheInterface $cache = null,
2931
): Platform {
3032
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
3133

3234
return new Platform(
33-
[new OllamaClient($httpClient, $hostUrl)],
35+
[new OllamaClient($httpClient, $hostUrl, $cache)],
3436
[new OllamaResultConverter()],
3537
$contract ?? OllamaContract::create()
3638
);

src/platform/tests/Bridge/Ollama/OllamaClientTest.php

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
2020
use Symfony\AI\Platform\Model;
2121
use Symfony\AI\Platform\Result\StreamResult;
22+
use Symfony\Component\Cache\Adapter\ArrayAdapter;
2223
use Symfony\Component\HttpClient\MockHttpClient;
2324
use Symfony\Component\HttpClient\Response\JsonMockResponse;
2425
use Symfony\Component\HttpClient\Response\MockResponse;
@@ -175,4 +176,57 @@ public function testStreamingConverterWithDirectResponse()
175176

176177
$this->assertNotInstanceOf(StreamResult::class, $regularResult);
177178
}
179+
180+
public function testPromptCachingIsSupported()
181+
{
182+
$httpClient = new MockHttpClient([
183+
new JsonMockResponse([
184+
'capabilities' => ['completion'],
185+
]),
186+
new JsonMockResponse([
187+
'model' => 'llama3.2',
188+
'created_at' => '2025-08-23T10:00:01Z',
189+
'message' => ['role' => 'assistant', 'content' => 'Hello world'],
190+
'done' => true,
191+
]),
192+
new JsonMockResponse([
193+
'capabilities' => ['completion'],
194+
]),
195+
]);
196+
197+
$platform = PlatformFactory::create('http://127.0.0.1:1234', $httpClient, cache: new ArrayAdapter());
198+
199+
$firstCall = $platform->invoke(new Ollama(), [
200+
'messages' => [
201+
[
202+
'role' => 'user',
203+
'content' => 'Say hello world',
204+
],
205+
],
206+
'model' => 'llama3.2',
207+
], [
208+
'prompt_cache_key' => 'foo',
209+
]);
210+
211+
$result = $firstCall->getResult();
212+
213+
$this->assertSame('Hello world', $result->getContent());
214+
215+
$secondCall = $platform->invoke(new Ollama(), [
216+
'messages' => [
217+
[
218+
'role' => 'user',
219+
'content' => 'Say hello world',
220+
],
221+
],
222+
'model' => 'llama3.2',
223+
], [
224+
'prompt_cache_key' => 'foo',
225+
]);
226+
227+
$result = $secondCall->getResult();
228+
229+
$this->assertSame('Hello world', $result->getContent());
230+
$this->assertSame(3, $httpClient->getRequestsCount());
231+
}
178232
}

0 commit comments

Comments
 (0)