Skip to content

Commit 55d8d64

Browse files
committed
feat(platform): Ollama prompt cache
1 parent 0bd8d96 commit 55d8d64

File tree

8 files changed

+180
-10
lines changed

8 files changed

+180
-10
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
use Symfony\AI\Agent\Agent;
13+
use Symfony\AI\Platform\Bridge\Ollama\Ollama;
14+
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
15+
use Symfony\AI\Platform\Message\Message;
16+
use Symfony\AI\Platform\Message\MessageBag;
17+
use Symfony\Component\Cache\Adapter\ArrayAdapter;
18+
19+
require_once dirname(__DIR__).'/bootstrap.php';
20+
21+
$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client(), cache: new ArrayAdapter());
22+
$model = new Ollama();
23+
24+
$agent = new Agent($platform, $model, logger: logger());
25+
$messages = new MessageBag(
26+
Message::forSystem('You are a helpful assistant.'),
27+
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
28+
);
29+
$result = $agent->call($messages, [
30+
'prompt_cache_key' => 'chat',
31+
]);
32+
33+
echo $result->getContent().\PHP_EOL;
34+
35+
$secondResult = $agent->call($messages, [
36+
'prompt_cache_key' => 'chat',
37+
]);
38+
39+
echo $result->getContent().\PHP_EOL;

src/ai-bundle/config/options.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
->arrayNode('ollama')
8989
->children()
9090
->scalarNode('host_url')->defaultValue('http://127.0.0.1:11434')->end()
91+
->scalarNode('cache')->end()
9192
->end()
9293
->end()
9394
->arrayNode('cerebras')

src/ai-bundle/src/AiBundle.php

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,17 +397,23 @@ private function processPlatformConfig(string $type, array $platform, ContainerB
397397
}
398398

399399
if ('ollama' === $type) {
400+
$arguments = [
401+
$platform['host_url'],
402+
new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE),
403+
new Reference('ai.platform.contract.ollama'),
404+
];
405+
406+
if (\array_key_exists('cache', $platform)) {
407+
$arguments[] = new Reference($platform['cache'], ContainerInterface::NULL_ON_INVALID_REFERENCE);
408+
}
409+
400410
$platformId = 'ai.platform.ollama';
401411
$definition = (new Definition(Platform::class))
402412
->setFactory(MistralPlatformFactory::class.'::create')
403413
->setFactory(OllamaPlatformFactory::class.'::create')
404414
->setLazy(true)
405415
->addTag('proxy', ['interface' => PlatformInterface::class])
406-
->setArguments([
407-
$platform['host_url'],
408-
new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE),
409-
new Reference('ai.platform.contract.ollama'),
410-
])
416+
->setArguments($arguments)
411417
->addTag('ai.platform');
412418

413419
$container->setDefinition($platformId, $definition);

src/ai-bundle/tests/DependencyInjection/AiBundleTest.php

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,39 @@ public function testOpenAiPlatformWithInvalidRegion()
591591
]);
592592
}
593593

594+
public function testPromptCachingCanBeUsedWithOllama()
595+
{
596+
$container = $this->buildContainer([
597+
'ai' => [
598+
'platform' => [
599+
'ollama' => [
600+
'host_url' => 'http://127.0.0.1:11434',
601+
'cache' => 'cache.app',
602+
],
603+
],
604+
],
605+
]);
606+
607+
$this->assertTrue($container->hasDefinition('ai.platform.ollama'));
608+
609+
$definition = $container->getDefinition('ai.platform.ollama');
610+
$this->assertCount(4, $definition->getArguments());
611+
612+
$this->assertSame('http://127.0.0.1:11434', $definition->getArgument(0));
613+
614+
$this->assertInstanceOf(Reference::class, $definition->getArgument(1));
615+
$httpClientArgument = $definition->getArgument(1);
616+
$this->assertSame('http_client', (string) $httpClientArgument);
617+
618+
$this->assertInstanceOf(Reference::class, $definition->getArgument(2));
619+
$contractArgument = $definition->getArgument(2);
620+
$this->assertSame('ai.platform.contract.ollama', (string) $contractArgument);
621+
622+
$this->assertInstanceOf(Reference::class, $definition->getArgument(3));
623+
$cacheArgument = $definition->getArgument(3);
624+
$this->assertSame('cache.app', (string) $cacheArgument);
625+
}
626+
594627
private function buildContainer(array $configuration): ContainerBuilder
595628
{
596629
$container = new ContainerBuilder();

src/platform/composer.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"phpstan/phpstan-symfony": "^2.0.6",
6060
"phpunit/phpunit": "^11.5",
6161
"symfony/ai-agent": "@dev",
62+
"symfony/cache": "^6.4 || ^7.1",
6263
"symfony/console": "^6.4 || ^7.1",
6364
"symfony/dotenv": "^6.4 || ^7.1",
6465
"symfony/event-dispatcher": "^6.4 || ^7.1",

src/platform/src/Bridge/Ollama/OllamaClient.php

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
use Symfony\AI\Platform\Model;
1616
use Symfony\AI\Platform\ModelClientInterface;
1717
use Symfony\AI\Platform\Result\RawHttpResult;
18+
use Symfony\Component\HttpClient\Response\MockResponse;
19+
use Symfony\Contracts\Cache\CacheInterface;
1820
use Symfony\Contracts\HttpClient\HttpClientInterface;
21+
use Symfony\Contracts\HttpClient\ResponseInterface;
1922

2023
/**
2124
* @author Christopher Hertel <[email protected]>
@@ -25,6 +28,7 @@
2528
public function __construct(
2629
private HttpClientInterface $httpClient,
2730
private string $hostUrl,
31+
private ?CacheInterface $cache = null,
2832
) {
2933
}
3034

@@ -68,10 +72,40 @@ private function doCompletionRequest(array|string $payload, array $options = [])
6872
unset($options['response_format']);
6973
}
7074

71-
return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
72-
'headers' => ['Content-Type' => 'application/json'],
73-
'json' => array_merge($options, $payload),
74-
]));
75+
$requestCallback = fn ($options, $payload): ResponseInterface => $this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [
76+
'headers' => [
77+
'Content-Type' => 'application/json',
78+
],
79+
'json' => [
80+
...$options,
81+
...$payload,
82+
],
83+
]);
84+
85+
if ($this->cache instanceof CacheInterface && (\array_key_exists('prompt_cache_key', $options) && $options['prompt_cache_key'])) {
86+
$cacheKey = \sprintf('%s_%s', $options['prompt_cache_key'], md5(\is_array($payload) ? json_encode($payload) : ['context' => $payload]));
87+
88+
unset($options['prompt_cache_key']);
89+
90+
$cachedResponse = $this->cache->get($cacheKey, static function () use ($requestCallback, $options, $payload): array {
91+
$response = $requestCallback($options, $payload);
92+
93+
return [
94+
'content' => $response->getContent(),
95+
'headers' => $response->getHeaders(),
96+
'http_code' => $response->getStatusCode(),
97+
];
98+
});
99+
100+
$mockedResponse = new MockResponse($cachedResponse['content'], [
101+
'http_code' => $cachedResponse['http_code'],
102+
'response_headers' => $cachedResponse['headers'],
103+
]);
104+
105+
return new RawHttpResult(MockResponse::fromRequest('POST', \sprintf('%s/api/chat', $this->hostUrl), $options, $mockedResponse));
106+
}
107+
108+
return new RawHttpResult($requestCallback($options, $payload));
75109
}
76110

77111
/**

src/platform/src/Bridge/Ollama/PlatformFactory.php

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use Symfony\AI\Platform\Contract;
1616
use Symfony\AI\Platform\Platform;
1717
use Symfony\Component\HttpClient\EventSourceHttpClient;
18+
use Symfony\Contracts\Cache\CacheInterface;
1819
use Symfony\Contracts\HttpClient\HttpClientInterface;
1920

2021
/**
@@ -26,11 +27,12 @@ public static function create(
2627
string $hostUrl = 'http://localhost:11434',
2728
?HttpClientInterface $httpClient = null,
2829
?Contract $contract = null,
30+
?CacheInterface $cache = null,
2931
): Platform {
3032
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
3133

3234
return new Platform(
33-
[new OllamaClient($httpClient, $hostUrl)],
35+
[new OllamaClient($httpClient, $hostUrl, $cache)],
3436
[new OllamaResultConverter()],
3537
$contract ?? OllamaContract::create()
3638
);

src/platform/tests/Bridge/Ollama/OllamaClientTest.php

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
2020
use Symfony\AI\Platform\Model;
2121
use Symfony\AI\Platform\Result\StreamResult;
22+
use Symfony\Component\Cache\Adapter\ArrayAdapter;
2223
use Symfony\Component\HttpClient\MockHttpClient;
2324
use Symfony\Component\HttpClient\Response\JsonMockResponse;
2425
use Symfony\Component\HttpClient\Response\MockResponse;
@@ -175,4 +176,57 @@ public function testStreamingConverterWithDirectResponse()
175176

176177
$this->assertNotInstanceOf(StreamResult::class, $regularResult);
177178
}
179+
180+
public function testPromptCachingIsSupported()
181+
{
182+
$httpClient = new MockHttpClient([
183+
new JsonMockResponse([
184+
'capabilities' => ['completion'],
185+
]),
186+
new JsonMockResponse([
187+
'model' => 'llama3.2',
188+
'created_at' => '2025-08-23T10:00:01Z',
189+
'message' => ['role' => 'assistant', 'content' => 'Hello world'],
190+
'done' => true,
191+
]),
192+
new JsonMockResponse([
193+
'capabilities' => ['completion'],
194+
]),
195+
]);
196+
197+
$platform = PlatformFactory::create('http://127.0.0.1:1234', $httpClient, cache: new ArrayAdapter());
198+
199+
$firstCall = $platform->invoke(new Ollama(), [
200+
'messages' => [
201+
[
202+
'role' => 'user',
203+
'content' => 'Say hello world',
204+
],
205+
],
206+
'model' => 'llama3.2',
207+
], [
208+
'prompt_cache_key' => 'foo',
209+
]);
210+
211+
$result = $firstCall->getResult();
212+
213+
$this->assertSame('Hello world', $result->getContent());
214+
215+
$secondCall = $platform->invoke(new Ollama(), [
216+
'messages' => [
217+
[
218+
'role' => 'user',
219+
'content' => 'Say hello world',
220+
],
221+
],
222+
'model' => 'llama3.2',
223+
], [
224+
'prompt_cache_key' => 'foo',
225+
]);
226+
227+
$result = $secondCall->getResult();
228+
229+
$this->assertSame('Hello world', $result->getContent());
230+
$this->assertSame(3, $httpClient->getRequestsCount());
231+
}
178232
}

0 commit comments

Comments
 (0)