-
-
Notifications
You must be signed in to change notification settings - Fork 102
[Platform][Ollama] Add prompt cache #416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
<?php | ||
|
||
/* | ||
* This file is part of the Symfony package. | ||
* | ||
* (c) Fabien Potencier <[email protected]> | ||
* | ||
* For the full copyright and license information, please view the LICENSE | ||
* file that was distributed with this source code. | ||
*/ | ||
|
||
use Symfony\AI\Agent\Agent; | ||
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory; | ||
use Symfony\AI\Platform\CachedPlatform; | ||
use Symfony\AI\Platform\Message\Message; | ||
use Symfony\AI\Platform\Message\MessageBag; | ||
use Symfony\Component\Cache\Adapter\ArrayAdapter; | ||
|
||
require_once dirname(__DIR__).'/bootstrap.php'; | ||
|
||
$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client()); | ||
$cachedPlatform = new CachedPlatform($platform, new ArrayAdapter()); | ||
|
||
$agent = new Agent($cachedPlatform, 'gemma3n', logger: logger()); | ||
$messages = new MessageBag( | ||
Message::forSystem('You are a helpful assistant.'), | ||
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), | ||
); | ||
$result = $agent->call($messages, [ | ||
'prompt_cache_key' => 'chat', | ||
]); | ||
|
||
echo $result->getContent().\PHP_EOL; | ||
|
||
$secondResult = $agent->call($messages, [ | ||
'prompt_cache_key' => 'chat', | ||
]); | ||
|
||
echo $secondResult->getContent().\PHP_EOL; |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -415,6 +415,16 @@ private function processPlatformConfig(string $type, array $platform, ContainerB | |||||||
} | ||||||||
|
||||||||
if ('ollama' === $type) { | ||||||||
$arguments = [ | ||||||||
$platform['host_url'], | ||||||||
new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE), | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
new Reference('ai.platform.contract.ollama'), | ||||||||
]; | ||||||||
|
||||||||
if (\array_key_exists('cache', $platform)) { | ||||||||
$arguments[] = new Reference($platform['cache'], ContainerInterface::NULL_ON_INVALID_REFERENCE); | ||||||||
} | ||||||||
|
||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. $arguments is not used |
||||||||
$platformId = 'ai.platform.ollama'; | ||||||||
$definition = (new Definition(Platform::class)) | ||||||||
->setFactory(OllamaPlatformFactory::class.'::create') | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changes here would also belong into |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,13 +47,14 @@ public function convert(RawResultInterface $result, array $options = []): Result | |
|
||
return \array_key_exists('embeddings', $data) | ||
? $this->doConvertEmbeddings($data) | ||
: $this->doConvertCompletion($data); | ||
: $this->doConvertCompletion($data, $options); | ||
} | ||
|
||
/** | ||
* @param array<string, mixed> $data | ||
* @param array<string, mixed> $options | ||
*/ | ||
public function doConvertCompletion(array $data): ResultInterface | ||
public function doConvertCompletion(array $data, array $options): ResultInterface | ||
{ | ||
if (!isset($data['message'])) { | ||
throw new RuntimeException('Response does not contain message.'); | ||
|
@@ -73,7 +74,19 @@ public function doConvertCompletion(array $data): ResultInterface | |
return new ToolCallResult(...$toolCalls); | ||
} | ||
|
||
return new TextResult($data['message']['content']); | ||
$result = new TextResult($data['message']['content']); | ||
|
||
if (\array_key_exists('prompt_cache_key', $options)) { | ||
$metadata = $result->getMetadata(); | ||
|
||
$metadata->add('cached', true); | ||
$metadata->add('prompt_cache_key', $options['prompt_cache_key']); | ||
$metadata->add('cached_prompt_count', $data['prompt_eval_count']); | ||
$metadata->add('cached_completion_count', $data['eval_count']); | ||
Comment on lines
+82
to
+85
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't it make sense to group this data into a DTO, like there is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not convinced about the benefits of using an object here, we're only storing an integer, I don't see the benefits to be honest 🤔 @OskarStark @chr-hertel Any thoughts? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i agree, it would be great to have an object like |
||
$metadata->add('cached_time', (new \DateTimeImmutable())->getTimestamp()); | ||
} | ||
|
||
return $result; | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
<?php | ||
|
||
/* | ||
* This file is part of the Symfony package. | ||
* | ||
* (c) Fabien Potencier <[email protected]> | ||
* | ||
* For the full copyright and license information, please view the LICENSE | ||
* file that was distributed with this source code. | ||
*/ | ||
|
||
namespace Symfony\AI\Platform; | ||
|
||
use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface; | ||
use Symfony\AI\Platform\Result\ResultPromise; | ||
use Symfony\Contracts\Cache\CacheInterface; | ||
|
||
/** | ||
* @author Guillaume Loulier <[email protected]> | ||
*/ | ||
final readonly class CachedPlatform implements PlatformInterface | ||
{ | ||
public function __construct( | ||
private PlatformInterface $platform, | ||
private CacheInterface $cache, | ||
) { | ||
} | ||
|
||
public function invoke(string $model, object|array|string $input, array $options = []): ResultPromise | ||
{ | ||
$invokeCall = fn (string $model, object|array|string $input, array $options = []): ResultPromise => $this->platform->invoke($model, $input, $options); | ||
|
||
if ($this->cache instanceof CacheInterface && (\array_key_exists('prompt_cache_key', $options) && '' !== $options['prompt_cache_key'])) { | ||
$cacheKey = \sprintf('%s_%s', $options['prompt_cache_key'], md5($model)); | ||
|
||
unset($options['prompt_cache_key']); | ||
|
||
return $this->cache->get($cacheKey, static fn (): ResultPromise => $invokeCall($model, $input, $options)); | ||
} | ||
|
||
return $invokeCall($model, $input, $options); | ||
} | ||
|
||
public function getModelCatalog(): ModelCatalogInterface | ||
{ | ||
return $this->platform->getModelCatalog(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,6 +19,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
use Symfony\AI\Platform\Model; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
use Symfony\AI\Platform\Result\RawHttpResult; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
use Symfony\AI\Platform\Result\StreamResult; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
use Symfony\Component\Cache\Adapter\ArrayAdapter; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
use Symfony\Component\HttpClient\MockHttpClient; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
use Symfony\Component\HttpClient\Response\JsonMockResponse; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
use Symfony\Component\HttpClient\Response\MockResponse; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -172,4 +173,64 @@ public function testStreamingConverterWithDirectResponse() | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$this->assertNotInstanceOf(StreamResult::class, $regularResult); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
public function testPromptCachingIsSupported() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$httpClient = new MockHttpClient([ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new JsonMockResponse([ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'capabilities' => ['completion'], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new JsonMockResponse([ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'model' => 'llama3.2', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'created_at' => '2025-08-23T10:00:01Z', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'message' => ['role' => 'assistant', 'content' => 'Hello world'], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'prompt_eval_count' => 10, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'eval_count' => 10, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'done' => true, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new JsonMockResponse([ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'capabilities' => ['completion'], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$platform = PlatformFactory::create('http://127.0.0.1:1234', $httpClient, cache: new ArrayAdapter()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$firstCall = $platform->invoke(new Ollama(Ollama::LLAMA_3_2), [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'messages' => [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'role' => 'user', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'content' => 'Say hello world', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'model' => 'llama3.2', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'prompt_cache_key' => 'foo', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$result = $firstCall->getResult(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$this->assertSame('Hello world', $result->getContent()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$this->assertSame(10, $result->getMetadata()->get('cached_prompt_count')); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$this->assertSame(10, $result->getMetadata()->get('cached_completion_count')); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$secondCall = $platform->invoke(new Ollama(Ollama::LLAMA_3_2), [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'messages' => [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'role' => 'user', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'content' => 'Say hello world', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'model' => 'llama3.2', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'prompt_cache_key' => 'foo', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$secondResult = $secondCall->getResult(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$this->assertSame('Hello world', $secondResult->getContent()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+198
to
+230
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$this->assertSame(10, $secondResult->getMetadata()->get('cached_prompt_count')); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$this->assertSame(10, $secondResult->getMetadata()->get('cached_completion_count')); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
$this->assertSame(3, $httpClient->getRequestsCount()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
<?php | ||
|
||
/* | ||
* This file is part of the Symfony package. | ||
* | ||
* (c) Fabien Potencier <[email protected]> | ||
* | ||
* For the full copyright and license information, please view the LICENSE | ||
* file that was distributed with this source code. | ||
*/ | ||
|
||
namespace Symfony\AI\Platform\Tests; | ||
|
||
use PHPUnit\Framework\TestCase; | ||
|
||
final class CachedPlatformTest extends TestCase | ||
{ | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might end with the same cache repeated again and again in every platform.
Should we introduce a cache config key at a higher level ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good question, IMHO, we should introduce a "root" key for it and allowing the override "per platform", @OskarStark @chr-hertel, any idea?