Skip to content

Commit 914eddf

Browse files
committed
feat(platform): Ollama prompt cache
1 parent d8dfb30 commit 914eddf

File tree

9 files changed

+212
-13
lines changed

9 files changed

+212
-13
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
<?php

/*
 * This file is part of the Symfony package.
 *
 * (c) Fabien Potencier <[email protected]>
 *
 * For the full copyright and license information, please view the LICENSE
 * file that was distributed with this source code.
 */

use Symfony\AI\Agent\Agent;
use Symfony\AI\Platform\Bridge\Ollama\Ollama;
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
use Symfony\AI\Platform\Message\Message;
use Symfony\AI\Platform\Message\MessageBag;
use Symfony\Component\Cache\Adapter\ArrayAdapter;

require_once dirname(__DIR__).'/bootstrap.php';

// Example: calling Ollama twice with the same "prompt_cache_key" and the same
// messages. The platform is given an in-memory ArrayAdapter as its cache, so the
// cache only lives for the duration of this script.
$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client(), cache: new ArrayAdapter());
$model = new Ollama('llama');

$agent = new Agent($platform, $model, logger: logger());
$messages = new MessageBag(
    Message::forSystem('You are a helpful assistant.'),
    Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
);

// First call: passes the cache key along with the messages.
$result = $agent->call($messages, [
    'prompt_cache_key' => 'chat',
]);

echo $result->getContent().\PHP_EOL;

// NOTE: assert() can be disabled by the PHP configuration, in which case these
// checks are skipped silently.
assert($result->getMetadata()->get('cached'));
assert('chat' === $result->getMetadata()->get('prompt_cache_key'));
assert(0 !== $result->getMetadata()->get('cached_prompt_count'));
assert(0 !== $result->getMetadata()->get('cached_completion_count'));

// Second call: same key and same messages as the first one.
$secondResult = $agent->call($messages, [
    'prompt_cache_key' => 'chat',
]);

echo $secondResult->getContent().\PHP_EOL;

// Every check below must inspect the metadata of the *second* result; checking
// $result again would only re-verify the first call.
assert($secondResult->getMetadata()->get('cached'));
assert('chat' === $secondResult->getMetadata()->get('prompt_cache_key'));
assert(0 !== $secondResult->getMetadata()->get('cached_prompt_count'));
assert(0 !== $secondResult->getMetadata()->get('cached_completion_count'));

src/ai-bundle/config/options.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
->arrayNode('ollama')
9090
->children()
9191
->scalarNode('host_url')->defaultValue('http://127.0.0.1:11434')->end()
92+
->scalarNode('cache')->end()
9293
->end()
9394
->end()
9495
->arrayNode('cerebras')

src/ai-bundle/src/AiBundle.php

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,17 +402,23 @@ private function processPlatformConfig(string $type, array $platform, ContainerB
402402
}
403403

404404
if ('ollama' === $type) {
405+
$arguments = [
406+
$platform['host_url'],
407+
new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE),
408+
new Reference('ai.platform.contract.ollama'),
409+
];
410+
411+
if (\array_key_exists('cache', $platform)) {
412+
$arguments[] = new Reference($platform['cache'], ContainerInterface::NULL_ON_INVALID_REFERENCE);
413+
}
414+
405415
$platformId = 'ai.platform.ollama';
406416
$definition = (new Definition(Platform::class))
407417
->setFactory(MistralPlatformFactory::class.'::create')
408418
->setFactory(OllamaPlatformFactory::class.'::create')
409419
->setLazy(true)
410420
->addTag('proxy', ['interface' => PlatformInterface::class])
411-
->setArguments([
412-
$platform['host_url'],
413-
new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE),
414-
new Reference('ai.platform.contract.ollama'),
415-
])
421+
->setArguments($arguments)
416422
->addTag('ai.platform');
417423

418424
$container->setDefinition($platformId, $definition);

src/ai-bundle/tests/DependencyInjection/AiBundleTest.php

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,39 @@ public function testIndexerWithSourceAndTransformers()
11851185
$this->assertSame(TextTrimTransformer::class, (string) $arguments[4][0]);
11861186
}
11871187

1188+
public function testPromptCachingCanBeUsedWithOllama()
{
    // Build a container whose Ollama platform is configured with a cache service id.
    $ollamaConfig = [
        'host_url' => 'http://127.0.0.1:11434',
        'cache' => 'cache.app',
    ];

    $container = $this->buildContainer([
        'ai' => [
            'platform' => [
                'ollama' => $ollamaConfig,
            ],
        ],
    ]);

    $this->assertTrue($container->hasDefinition('ai.platform.ollama'));

    $platformDefinition = $container->getDefinition('ai.platform.ollama');
    $this->assertCount(4, $platformDefinition->getArguments());

    // Argument 0 is the plain host URL string.
    $this->assertSame('http://127.0.0.1:11434', $platformDefinition->getArgument(0));

    // Arguments 1..3 are service references: HTTP client, contract and the configured cache.
    $expectedReferences = [
        1 => 'http_client',
        2 => 'ai.platform.contract.ollama',
        3 => 'cache.app',
    ];

    foreach ($expectedReferences as $position => $serviceId) {
        $this->assertInstanceOf(Reference::class, $platformDefinition->getArgument($position));
        $argument = $platformDefinition->getArgument($position);
        $this->assertSame($serviceId, (string) $argument);
    }
}
1220+
11881221
private function buildContainer(array $configuration): ContainerBuilder
11891222
{
11901223
$container = new ContainerBuilder();

src/platform/composer.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"phpstan/phpstan-symfony": "^2.0.6",
6060
"phpunit/phpunit": "^11.5",
6161
"symfony/ai-agent": "@dev",
62+
"symfony/cache": "^6.4 || ^7.1",
6263
"symfony/console": "^6.4 || ^7.1",
6364
"symfony/dotenv": "^6.4 || ^7.1",
6465
"symfony/event-dispatcher": "^6.4 || ^7.1",

src/platform/src/Bridge/Ollama/OllamaClient.php

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
use Symfony\AI\Platform\Model;
1616
use Symfony\AI\Platform\ModelClientInterface;
1717
use Symfony\AI\Platform\Result\RawHttpResult;
18+
use Symfony\Component\HttpClient\Response\MockResponse;
19+
use Symfony\Contracts\Cache\CacheInterface;
1820
use Symfony\Contracts\HttpClient\HttpClientInterface;
21+
use Symfony\Contracts\HttpClient\ResponseInterface;
1922

2023
/**
2124
* @author Christopher Hertel <[email protected]>
@@ -25,6 +28,7 @@
2528
public function __construct(
2629
private HttpClientInterface $httpClient,
2730
private string $hostUrl,
31+
private ?CacheInterface $cache = null,
2832
) {
2933
}
3034

@@ -68,10 +72,40 @@ private function doCompletionRequest(array|string $payload, array $options = [])
6872
unset($options['response_format']);
6973
}
7074

71-
return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
72-
'headers' => ['Content-Type' => 'application/json'],
73-
'json' => array_merge($options, $payload),
74-
]));
75+
// Performs the HTTP call to the chat endpoint. It is wrapped in a closure so it
// can be executed either directly or lazily from the cache callback below.
$requestCallback = fn ($options, $payload): ResponseInterface => $this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
    'headers' => [
        'Content-Type' => 'application/json',
    ],
    'json' => [
        ...$options,
        ...$payload,
    ],
]);

if ($this->cache instanceof CacheInterface && (\array_key_exists('prompt_cache_key', $options) && '' !== $options['prompt_cache_key'])) {
    // md5() only accepts a string, so the payload is normalised to an array first
    // (a non-array payload is wrapped under "context") and then JSON-encoded as a
    // whole. JSON_THROW_ON_ERROR prevents a failed encoding from silently producing
    // the same hash for every unencodable payload.
    $cacheKey = \sprintf('%s_%s', $options['prompt_cache_key'], md5(json_encode(\is_array($payload) ? $payload : ['context' => $payload], \JSON_THROW_ON_ERROR)));

    // The cache key is a client-side option; it is removed before the request is built.
    unset($options['prompt_cache_key']);

    // On a cache miss the request is sent and its body, headers and status code
    // are stored; on a hit the stored array is returned without any HTTP call.
    $cachedResponse = $this->cache->get($cacheKey, static function () use ($requestCallback, $options, $payload): array {
        $response = $requestCallback($options, $payload);

        return [
            'content' => $response->getContent(),
            'headers' => $response->getHeaders(),
            'http_code' => $response->getStatusCode(),
        ];
    });

    // Rebuild a response object from the stored data so callers receive the same
    // RawHttpResult wrapper as for a live request.
    $mockedResponse = new MockResponse($cachedResponse['content'], [
        'http_code' => $cachedResponse['http_code'],
        'response_headers' => $cachedResponse['headers'],
    ]);

    return new RawHttpResult(MockResponse::fromRequest('POST', \sprintf('%s/api/chat', $this->hostUrl), $options, $mockedResponse));
}

// No cache configured or no (non-empty) cache key given: send the request directly.
return new RawHttpResult($requestCallback($options, $payload));
75109
}
76110

77111
/**

src/platform/src/Bridge/Ollama/OllamaResultConverter.php

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,14 @@ public function convert(RawResultInterface $result, array $options = []): Result
4747

4848
return \array_key_exists('embeddings', $data)
4949
? $this->doConvertEmbeddings($data)
50-
: $this->doConvertCompletion($data);
50+
: $this->doConvertCompletion($data, $options);
5151
}
5252

5353
/**
5454
* @param array<string, mixed> $data
55+
* @param array<string, mixed> $options
5556
*/
56-
public function doConvertCompletion(array $data): ResultInterface
57+
public function doConvertCompletion(array $data, array $options): ResultInterface
5758
{
5859
if (!isset($data['message'])) {
5960
throw new RuntimeException('Response does not contain message.');
@@ -73,7 +74,18 @@ public function doConvertCompletion(array $data): ResultInterface
7374
return new ToolCallResult(...$toolCalls);
7475
}
7576

76-
return new TextResult($data['message']['content']);
77+
$result = new TextResult($data['message']['content']);

// When the caller supplied a prompt cache key, expose it together with the
// token counters reported in the response as result metadata.
if (\array_key_exists('prompt_cache_key', $options)) {
    $metadata = $result->getMetadata();

    $metadata->add('cached', true);
    $metadata->add('prompt_cache_key', $options['prompt_cache_key']);
    // Read the counters defensively: a response without these keys yields null
    // metadata values instead of raising an "undefined array key" warning.
    $metadata->add('cached_prompt_count', $data['prompt_eval_count'] ?? null);
    $metadata->add('cached_completion_count', $data['eval_count'] ?? null);
}

return $result;
7789
}
7890

7991
/**

src/platform/src/Bridge/Ollama/PlatformFactory.php

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use Symfony\AI\Platform\Contract;
1616
use Symfony\AI\Platform\Platform;
1717
use Symfony\Component\HttpClient\EventSourceHttpClient;
18+
use Symfony\Contracts\Cache\CacheInterface;
1819
use Symfony\Contracts\HttpClient\HttpClientInterface;
1920

2021
/**
@@ -26,11 +27,12 @@ public static function create(
2627
string $hostUrl = 'http://localhost:11434',
2728
?HttpClientInterface $httpClient = null,
2829
?Contract $contract = null,
30+
?CacheInterface $cache = null,
2931
): Platform {
3032
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
3133

3234
return new Platform(
33-
[new OllamaClient($httpClient, $hostUrl)],
35+
[new OllamaClient($httpClient, $hostUrl, $cache)],
3436
[new OllamaResultConverter()],
3537
$contract ?? OllamaContract::create()
3638
);

src/platform/tests/Bridge/Ollama/OllamaClientTest.php

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
2020
use Symfony\AI\Platform\Model;
2121
use Symfony\AI\Platform\Result\StreamResult;
22+
use Symfony\Component\Cache\Adapter\ArrayAdapter;
2223
use Symfony\Component\HttpClient\MockHttpClient;
2324
use Symfony\Component\HttpClient\Response\JsonMockResponse;
2425
use Symfony\Component\HttpClient\Response\MockResponse;
@@ -175,4 +176,64 @@ public function testStreamingConverterWithDirectResponse()
175176

176177
$this->assertNotInstanceOf(StreamResult::class, $regularResult);
177178
}
179+
180+
public function testPromptCachingIsSupported()
{
    // Three responses are queued; the final assertion checks that exactly three
    // requests were sent for the two identical invocations below.
    $capabilities = ['capabilities' => ['completion']];
    $completion = [
        'model' => 'llama3.2',
        'created_at' => '2025-08-23T10:00:01Z',
        'message' => ['role' => 'assistant', 'content' => 'Hello world'],
        'prompt_eval_count' => 10,
        'eval_count' => 10,
        'done' => true,
    ];

    $httpClient = new MockHttpClient([
        new JsonMockResponse($capabilities),
        new JsonMockResponse($completion),
        new JsonMockResponse($capabilities),
    ]);

    $platform = PlatformFactory::create('http://127.0.0.1:1234', $httpClient, cache: new ArrayAdapter());

    $input = [
        'messages' => [
            [
                'role' => 'user',
                'content' => 'Say hello world',
            ],
        ],
        'model' => 'llama3.2',
    ];
    $invokeOptions = [
        'prompt_cache_key' => 'foo',
    ];

    // Invoke twice with the same payload and cache key; both results must expose
    // the same content and token counters.
    for ($attempt = 0; $attempt < 2; ++$attempt) {
        $deferred = $platform->invoke(new Ollama(Ollama::LLAMA_3_2), $input, $invokeOptions);

        $currentResult = $deferred->getResult();

        $this->assertSame('Hello world', $currentResult->getContent());
        $this->assertSame(10, $currentResult->getMetadata()->get('cached_prompt_count'));
        $this->assertSame(10, $currentResult->getMetadata()->get('cached_completion_count'));
    }

    $this->assertSame(3, $httpClient->getRequestsCount());
}
178239
}

0 commit comments

Comments
 (0)