Skip to content

Commit e038555

Browse files
committed
feat(platform): Ollama prompt cache
1 parent 498745c commit e038555

File tree

9 files changed

+214
-13
lines changed

9 files changed

+214
-13
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
use Symfony\AI\Agent\Agent;
13+
use Symfony\AI\Platform\Bridge\Ollama\Ollama;
14+
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
15+
use Symfony\AI\Platform\Message\Message;
16+
use Symfony\AI\Platform\Message\MessageBag;
17+
use Symfony\Component\Cache\Adapter\ArrayAdapter;
18+
19+
require_once dirname(__DIR__).'/bootstrap.php';

// The platform is created with an in-memory cache adapter as its cache.
$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client(), cache: new ArrayAdapter());
$agent = new Agent($platform, new Ollama('llama'), logger: logger());

$messages = new MessageBag(
    Message::forSystem('You are a helpful assistant.'),
    Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
);

// Both calls use the same messages and the same cache key.
$callOptions = ['prompt_cache_key' => 'chat'];

// First call.
$firstResult = $agent->call($messages, $callOptions);
echo $firstResult->getContent().\PHP_EOL;

$firstMetadata = $firstResult->getMetadata();
assert($firstMetadata->get('cached'));
assert('chat' === $firstMetadata->get('prompt_cache_key'));
assert(0 !== $firstMetadata->get('cached_prompt_count'));
assert(0 !== $firstMetadata->get('cached_completion_count'));

// Second call: its metadata is expected to match the first call's.
$repeatedResult = $agent->call($messages, $callOptions);
echo $repeatedResult->getContent().\PHP_EOL;

$repeatedMetadata = $repeatedResult->getMetadata();
assert($repeatedMetadata->get('cached'));
assert('chat' === $repeatedMetadata->get('prompt_cache_key'));
assert($firstMetadata->get('cached_prompt_count') === $repeatedMetadata->get('cached_prompt_count'));
assert($firstMetadata->get('cached_completion_count') === $repeatedMetadata->get('cached_completion_count'));
assert($firstMetadata->get('cached_time') === $repeatedMetadata->get('cached_time'));

src/ai-bundle/config/options.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
->defaultValue('http_client')
128128
->info('Service ID of the HTTP client to use')
129129
->end()
130+
->scalarNode('cache')->end()
130131
->end()
131132
->end()
132133
->arrayNode('cerebras')

src/ai-bundle/src/AiBundle.php

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,16 +405,22 @@ private function processPlatformConfig(string $type, array $platform, ContainerB
405405
}
406406

407407
if ('ollama' === $type) {
408+
// Factory arguments for the Ollama platform, in positional order:
// host URL, HTTP client, contract, and (optionally) cache.
$arguments = [
    $platform['host_url'],
    // Use the configured HTTP client service ID; hard-coding 'http_client'
    // here would silently ignore a custom "http_client" option.
    new Reference($platform['http_client'], ContainerInterface::NULL_ON_INVALID_REFERENCE),
    new Reference('ai.platform.contract.ollama'),
];

// The cache is an optional trailing argument: it is only appended when the
// "cache" option was set, so the argument list otherwise keeps three entries.
if (\array_key_exists('cache', $platform)) {
    $arguments[] = new Reference($platform['cache'], ContainerInterface::NULL_ON_INVALID_REFERENCE);
}
417+
408418
$platformId = 'ai.platform.ollama';
409419
$definition = (new Definition(Platform::class))
410420
->setFactory(OllamaPlatformFactory::class.'::create')
411421
->setLazy(true)
412422
->addTag('proxy', ['interface' => PlatformInterface::class])
413-
->setArguments([
414-
$platform['host_url'],
415-
new Reference($platform['http_client'], ContainerInterface::NULL_ON_INVALID_REFERENCE),
416-
new Reference('ai.platform.contract.ollama'),
417-
])
423+
->setArguments($arguments)
418424
->addTag('ai.platform');
419425

420426
$container->setDefinition($platformId, $definition);

src/ai-bundle/tests/DependencyInjection/AiBundleTest.php

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,6 +2117,39 @@ public function testIndexerWithSourceFiltersAndTransformers()
21172117
$this->assertSame('logger', (string) $arguments[6]);
21182118
}
21192119

2120+
/**
 * Checks that the "cache" option of the Ollama platform config ends up as
 * the fourth argument of the "ai.platform.ollama" service definition.
 */
public function testPromptCachingCanBeUsedWithOllama()
{
    $ollamaConfig = [
        'host_url' => 'http://127.0.0.1:11434',
        'cache' => 'cache.app',
    ];

    $container = $this->buildContainer([
        'ai' => [
            'platform' => [
                'ollama' => $ollamaConfig,
            ],
        ],
    ]);

    $this->assertTrue($container->hasDefinition('ai.platform.ollama'));

    $definition = $container->getDefinition('ai.platform.ollama');
    $this->assertCount(4, $definition->getArguments());

    $this->assertSame('http://127.0.0.1:11434', $definition->getArgument(0));

    // Arguments 1..3 must be service references, checked in positional order.
    $expectedServiceIds = [
        1 => 'http_client',
        2 => 'ai.platform.contract.ollama',
        3 => 'cache.app',
    ];

    foreach ($expectedServiceIds as $position => $serviceId) {
        $this->assertInstanceOf(Reference::class, $definition->getArgument($position));
        $reference = $definition->getArgument($position);
        $this->assertSame($serviceId, (string) $reference);
    }
}
2152+
21202153
private function buildContainer(array $configuration): ContainerBuilder
21212154
{
21222155
$container = new ContainerBuilder();

src/platform/composer.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"phpstan/phpstan-symfony": "^2.0.6",
6464
"phpunit/phpunit": "^11.5",
6565
"symfony/ai-agent": "@dev",
66+
"symfony/cache": "^7.3|^8.0",
6667
"symfony/console": "^7.3|^8.0",
6768
"symfony/dotenv": "^7.3|^8.0",
6869
"symfony/event-dispatcher": "^7.3|^8.0",

src/platform/src/Bridge/Ollama/OllamaClient.php

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
use Symfony\AI\Platform\Model;
1616
use Symfony\AI\Platform\ModelClientInterface;
1717
use Symfony\AI\Platform\Result\RawHttpResult;
18+
use Symfony\Component\HttpClient\Response\MockResponse;
19+
use Symfony\Contracts\Cache\CacheInterface;
1820
use Symfony\Contracts\HttpClient\HttpClientInterface;
21+
use Symfony\Contracts\HttpClient\ResponseInterface;
1922

2023
/**
2124
* @author Christopher Hertel <[email protected]>
@@ -25,6 +28,7 @@
2528
public function __construct(
2629
private HttpClientInterface $httpClient,
2730
private string $hostUrl,
31+
private ?CacheInterface $cache = null,
2832
) {
2933
}
3034

@@ -68,10 +72,40 @@ private function doCompletionRequest(array|string $payload, array $options = [])
6872
unset($options['response_format']);
6973
}
7074

71-
return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
72-
'headers' => ['Content-Type' => 'application/json'],
73-
'json' => array_merge($options, $payload),
74-
]));
75+
// Sends the chat request. Options are spread first, so payload keys win
// when both arrays define the same key.
$requestCallback = fn (array $options, array|string $payload): ResponseInterface => $this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
    'headers' => [
        'Content-Type' => 'application/json',
    ],
    'json' => [
        ...$options,
        ...$payload,
    ],
]);

// "prompt_cache_key" is handled here and is not meant for the request body:
// remember it, then strip it so it is not sent on either the cached or the
// uncached path.
$hasPromptCacheKey = \array_key_exists('prompt_cache_key', $options) && '' !== $options['prompt_cache_key'];
$promptCacheKey = $options['prompt_cache_key'] ?? '';
unset($options['prompt_cache_key']);

// No cache configured or no usable key: perform a regular request.
if (!$this->cache instanceof CacheInterface || !$hasPromptCacheKey) {
    return new RawHttpResult($requestCallback($options, $payload));
}

// Hash the JSON form of the payload. A string payload is wrapped in an array
// first, so md5() always receives a string (it throws a TypeError on arrays).
// NOTE(review): $options are not part of the key, so two calls that differ
// only by options share a cache entry - confirm this is intended.
$payloadHash = md5(json_encode(\is_array($payload) ? $payload : ['context' => $payload], \JSON_THROW_ON_ERROR));
$cacheKey = \sprintf('%s_%s', $promptCacheKey, $payloadHash);

// On a cache miss the real request is executed and reduced to plain scalars
// and arrays (body, headers, status code) so the entry can be stored.
$cachedResponse = $this->cache->get($cacheKey, static function () use ($requestCallback, $options, $payload): array {
    $response = $requestCallback($options, $payload);

    return [
        'content' => $response->getContent(),
        'headers' => $response->getHeaders(),
        'http_code' => $response->getStatusCode(),
    ];
});

// Rebuild a response object from the stored parts so the rest of the
// pipeline receives a RawHttpResult just like on the uncached path.
$mockedResponse = new MockResponse($cachedResponse['content'], [
    'http_code' => $cachedResponse['http_code'],
    'response_headers' => $cachedResponse['headers'],
]);

return new RawHttpResult(MockResponse::fromRequest('POST', \sprintf('%s/api/chat', $this->hostUrl), $options, $mockedResponse));
75109
}
76110

77111
/**

src/platform/src/Bridge/Ollama/OllamaResultConverter.php

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,14 @@ public function convert(RawResultInterface $result, array $options = []): Result
4747

4848
return \array_key_exists('embeddings', $data)
4949
? $this->doConvertEmbeddings($data)
50-
: $this->doConvertCompletion($data);
50+
: $this->doConvertCompletion($data, $options);
5151
}
5252

5353
/**
5454
* @param array<string, mixed> $data
55+
* @param array<string, mixed> $options
5556
*/
56-
public function doConvertCompletion(array $data): ResultInterface
57+
public function doConvertCompletion(array $data, array $options): ResultInterface
5758
{
5859
if (!isset($data['message'])) {
5960
throw new RuntimeException('Response does not contain message.');
@@ -73,7 +74,19 @@ public function doConvertCompletion(array $data): ResultInterface
7374
return new ToolCallResult(...$toolCalls);
7475
}
7576

76-
return new TextResult($data['message']['content']);
77+
$result = new TextResult($data['message']['content']);

// Cache metadata is attached whenever the caller passed a "prompt_cache_key" option.
if (\array_key_exists('prompt_cache_key', $options)) {
    $metadata = $result->getMetadata();

    // NOTE(review): this flag is set for every call that carries a cache key,
    // including the first one that populates the cache - confirm intent.
    $metadata->add('cached', true);
    $metadata->add('prompt_cache_key', $options['prompt_cache_key']);

    // Read the counters defensively: a response without them yields null
    // instead of raising an "undefined array key" warning.
    $metadata->add('cached_prompt_count', $data['prompt_eval_count'] ?? null);
    $metadata->add('cached_completion_count', $data['eval_count'] ?? null);

    // The timestamp is taken when the result is converted, not read from the response data.
    $metadata->add('cached_time', (new \DateTimeImmutable())->getTimestamp());
}

return $result;
7790
}
7891

7992
/**

src/platform/src/Bridge/Ollama/PlatformFactory.php

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use Symfony\AI\Platform\Contract;
1616
use Symfony\AI\Platform\Platform;
1717
use Symfony\Component\HttpClient\EventSourceHttpClient;
18+
use Symfony\Contracts\Cache\CacheInterface;
1819
use Symfony\Contracts\HttpClient\HttpClientInterface;
1920

2021
/**
@@ -26,11 +27,12 @@ public static function create(
2627
string $hostUrl = 'http://localhost:11434',
2728
?HttpClientInterface $httpClient = null,
2829
?Contract $contract = null,
30+
?CacheInterface $cache = null,
2931
): Platform {
3032
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
3133

3234
return new Platform(
33-
[new OllamaClient($httpClient, $hostUrl)],
35+
[new OllamaClient($httpClient, $hostUrl, $cache)],
3436
[new OllamaResultConverter()],
3537
$contract ?? OllamaContract::create()
3638
);

src/platform/tests/Bridge/Ollama/OllamaClientTest.php

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
use Symfony\AI\Platform\Model;
2222
use Symfony\AI\Platform\Result\RawHttpResult;
2323
use Symfony\AI\Platform\Result\StreamResult;
24+
use Symfony\Component\Cache\Adapter\ArrayAdapter;
2425
use Symfony\Component\HttpClient\MockHttpClient;
2526
use Symfony\Component\HttpClient\Response\JsonMockResponse;
2627
use Symfony\Component\HttpClient\Response\MockResponse;
@@ -177,4 +178,64 @@ public function testStreamingConverterWithDirectResponse()
177178

178179
$this->assertNotInstanceOf(StreamResult::class, $regularResult);
179180
}
181+
182+
/**
 * Invokes the platform twice with the same payload and cache key and checks
 * that both results are identical while only three HTTP requests are made.
 */
public function testPromptCachingIsSupported()
{
    $capabilities = ['capabilities' => ['completion']];
    $chatCompletion = [
        'model' => 'llama3.2',
        'created_at' => '2025-08-23T10:00:01Z',
        'message' => ['role' => 'assistant', 'content' => 'Hello world'],
        'prompt_eval_count' => 10,
        'eval_count' => 10,
        'done' => true,
    ];

    // Only one chat completion response is queued for the two invocations.
    $httpClient = new MockHttpClient([
        new JsonMockResponse($capabilities),
        new JsonMockResponse($chatCompletion),
        new JsonMockResponse($capabilities),
    ]);

    $platform = PlatformFactory::create('http://127.0.0.1:1234', $httpClient, cache: new ArrayAdapter());

    $payload = [
        'messages' => [
            [
                'role' => 'user',
                'content' => 'Say hello world',
            ],
        ],
        'model' => 'llama3.2',
    ];
    $options = ['prompt_cache_key' => 'foo'];

    // Same invocation twice; the assertions are identical for both rounds.
    for ($round = 1; $round <= 2; ++$round) {
        $deferred = $platform->invoke(new Ollama(Ollama::LLAMA_3_2), $payload, $options);
        $result = $deferred->getResult();

        $this->assertSame('Hello world', $result->getContent());
        $this->assertSame(10, $result->getMetadata()->get('cached_prompt_count'));
        $this->assertSame(10, $result->getMetadata()->get('cached_completion_count'));
    }

    $this->assertSame(3, $httpClient->getRequestsCount());
}
180241
}

0 commit comments

Comments
 (0)