diff --git a/examples/.env b/examples/.env index 64bde52b1..dcd2b45dd 100644 --- a/examples/.env +++ b/examples/.env @@ -16,7 +16,7 @@ VOYAGE_API_KEY= REPLICATE_API_KEY= # For using Ollama -OLLAMA_HOST_URL=http://localhost:11434 +OLLAMA_HOST_URL=http://127.0.0.1:11434 OLLAMA_LLM=llama3.2 OLLAMA_EMBEDDINGS=nomic-embed-text diff --git a/examples/misc/ollama-chat-with-cache.php b/examples/misc/ollama-chat-with-cache.php new file mode 100644 index 000000000..b79f2bc58 --- /dev/null +++ b/examples/misc/ollama-chat-with-cache.php @@ -0,0 +1,39 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory; +use Symfony\AI\Platform\CachedPlatform; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\Component\Cache\Adapter\ArrayAdapter; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client()); +$cachedPlatform = new CachedPlatform($platform, new ArrayAdapter()); + +$agent = new Agent($cachedPlatform, 'gemma3n', logger: logger()); +$messages = new MessageBag( + Message::forSystem('You are a helpful assistant.'), + Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), +); +$result = $agent->call($messages, [ + 'prompt_cache_key' => 'chat', +]); + +echo $result->getContent().\PHP_EOL; + +$secondResult = $agent->call($messages, [ + 'prompt_cache_key' => 'chat', +]); + +echo $secondResult->getContent().\PHP_EOL; diff --git a/src/ai-bundle/config/options.php b/src/ai-bundle/config/options.php index 172b887cb..4f0e40c41 100644 --- a/src/ai-bundle/config/options.php +++ b/src/ai-bundle/config/options.php @@ -128,6 +128,7 @@ ->defaultValue('http_client') ->info('Service ID of the HTTP client to use') ->end() + ->scalarNode('cache')->end() ->end() ->end() ->arrayNode('cerebras') diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index 1eff064cc..c03b322b3 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -415,6 +415,16 @@ private function processPlatformConfig(string $type, array $platform, ContainerB } if ('ollama' === $type) { + $arguments = [ + $platform['host_url'], + new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE), + new Reference('ai.platform.contract.ollama'), + ]; + + if (\array_key_exists('cache', $platform)) { + $arguments[] = new Reference($platform['cache'], ContainerInterface::NULL_ON_INVALID_REFERENCE); + } + $platformId = 'ai.platform.ollama'; $definition = (new Definition(Platform::class)) ->setFactory(OllamaPlatformFactory::class.'::create') diff --git a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php index efc53c8f8..1c66b3ea6 100644 --- a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php +++ b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php @@ -2087,6 +2087,39 @@ public function testIndexerWithSourceFiltersAndTransformers() $this->assertSame('logger', (string) $arguments[6]); } + public function testPromptCachingCanBeUsedWithOllama() + { + $container = $this->buildContainer([ + 'ai' => [ + 'platform' => [ + 'ollama' => [ + 'host_url' => 'http://127.0.0.1:11434', + 'cache' => 'cache.app', + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.platform.ollama')); + + $definition = $container->getDefinition('ai.platform.ollama'); + $this->assertCount(4, $definition->getArguments()); + + $this->assertSame('http://127.0.0.1:11434', $definition->getArgument(0)); + + $this->assertInstanceOf(Reference::class, $definition->getArgument(1)); + $httpClientArgument = $definition->getArgument(1); + $this->assertSame('http_client', (string) $httpClientArgument); + + $this->assertInstanceOf(Reference::class, $definition->getArgument(2)); + $contractArgument = $definition->getArgument(2); + $this->assertSame('ai.platform.contract.ollama', (string) $contractArgument); + + $this->assertInstanceOf(Reference::class, $definition->getArgument(3)); + $cacheArgument = $definition->getArgument(3); + $this->assertSame('cache.app', (string) $cacheArgument); + } + private function buildContainer(array $configuration): ContainerBuilder { $container = new ContainerBuilder(); diff --git a/src/platform/composer.json b/src/platform/composer.json index 6142164dd..6afa92e91 100644 --- a/src/platform/composer.json +++ b/src/platform/composer.json @@ -64,6 +64,7 @@ "phpstan/phpstan-symfony": "^2.0.6", "phpunit/phpunit": "^11.5", "symfony/ai-agent": "@dev", + "symfony/cache": "^7.3|^8.0", "symfony/console": "^7.3|^8.0", "symfony/dotenv": "^7.3|^8.0", "symfony/event-dispatcher": "^7.3|^8.0", diff --git a/src/platform/src/Bridge/Ollama/OllamaResultConverter.php b/src/platform/src/Bridge/Ollama/OllamaResultConverter.php index ba5b8f896..ea4b249b8 100644 --- a/src/platform/src/Bridge/Ollama/OllamaResultConverter.php +++ b/src/platform/src/Bridge/Ollama/OllamaResultConverter.php @@ -47,13 +47,14 @@ public function convert(RawResultInterface $result, array $options = []): Result return \array_key_exists('embeddings', $data) ? $this->doConvertEmbeddings($data) - : $this->doConvertCompletion($data); + : $this->doConvertCompletion($data, $options); } /** * @param array $data + * @param array $options */ - public function doConvertCompletion(array $data): ResultInterface + public function doConvertCompletion(array $data, array $options): ResultInterface { if (!isset($data['message'])) { throw new RuntimeException('Response does not contain message.'); @@ -73,7 +74,19 @@ public function doConvertCompletion(array $data): ResultInterface return new ToolCallResult(...$toolCalls); } - return new TextResult($data['message']['content']); + $result = new TextResult($data['message']['content']); + + if (\array_key_exists('prompt_cache_key', $options)) { + $metadata = $result->getMetadata(); + + $metadata->add('cached', true); + $metadata->add('prompt_cache_key', $options['prompt_cache_key']); + $metadata->add('cached_prompt_count', $data['prompt_eval_count']); + $metadata->add('cached_completion_count', $data['eval_count']); + $metadata->add('cached_time', (new \DateTimeImmutable())->getTimestamp()); + } + + return $result; } /** diff --git a/src/platform/src/CachedPlatform.php b/src/platform/src/CachedPlatform.php new file mode 100644 index 000000000..1e499b490 --- /dev/null +++ b/src/platform/src/CachedPlatform.php @@ -0,0 +1,48 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform; + +use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface; +use Symfony\AI\Platform\Result\ResultPromise; +use Symfony\Contracts\Cache\CacheInterface; + +/** + * @author Guillaume Loulier + */ +final readonly class CachedPlatform implements PlatformInterface +{ + public function __construct( + private PlatformInterface $platform, + private CacheInterface $cache, + ) { + } + + public function invoke(string $model, object|array|string $input, array $options = []): ResultPromise + { + $invokeCall = fn (string $model, object|array|string $input, array $options = []): ResultPromise => $this->platform->invoke($model, $input, $options); + + if ($this->cache instanceof CacheInterface && (\array_key_exists('prompt_cache_key', $options) && '' !== $options['prompt_cache_key'])) { + $cacheKey = \sprintf('%s_%s', $options['prompt_cache_key'], md5($model)); + + unset($options['prompt_cache_key']); + + return $this->cache->get($cacheKey, static fn (): ResultPromise => $invokeCall($model, $input, $options)); + } + + return $invokeCall($model, $input, $options); + } + + public function getModelCatalog(): ModelCatalogInterface + { + return $this->platform->getModelCatalog(); + } +} diff --git a/src/platform/tests/Bridge/Ollama/OllamaClientTest.php b/src/platform/tests/Bridge/Ollama/OllamaClientTest.php index 83e252259..6633eac59 100644 --- a/src/platform/tests/Bridge/Ollama/OllamaClientTest.php +++ b/src/platform/tests/Bridge/Ollama/OllamaClientTest.php @@ -19,6 +19,7 @@ use Symfony\AI\Platform\Model; use Symfony\AI\Platform\Result\RawHttpResult; use Symfony\AI\Platform\Result\StreamResult; +use Symfony\Component\Cache\Adapter\ArrayAdapter; use Symfony\Component\HttpClient\MockHttpClient; use Symfony\Component\HttpClient\Response\JsonMockResponse; use Symfony\Component\HttpClient\Response\MockResponse; @@ -172,4 +173,64 @@ public function testStreamingConverterWithDirectResponse() $this->assertNotInstanceOf(StreamResult::class, $regularResult); } + + public function testPromptCachingIsSupported() + { + $httpClient = new MockHttpClient([ + new JsonMockResponse([ + 'capabilities' => ['completion'], + ]), + new JsonMockResponse([ + 'model' => 'llama3.2', + 'created_at' => '2025-08-23T10:00:01Z', + 'message' => ['role' => 'assistant', 'content' => 'Hello world'], + 'prompt_eval_count' => 10, + 'eval_count' => 10, + 'done' => true, + ]), + new JsonMockResponse([ + 'capabilities' => ['completion'], + ]), + ]); + + $platform = PlatformFactory::create('http://127.0.0.1:1234', $httpClient, cache: new ArrayAdapter()); + + $firstCall = $platform->invoke(new Ollama(Ollama::LLAMA_3_2), [ + 'messages' => [ + [ + 'role' => 'user', + 'content' => 'Say hello world', + ], + ], + 'model' => 'llama3.2', + ], [ + 'prompt_cache_key' => 'foo', + ]); + + $result = $firstCall->getResult(); + + $this->assertSame('Hello world', $result->getContent()); + $this->assertSame(10, $result->getMetadata()->get('cached_prompt_count')); + $this->assertSame(10, $result->getMetadata()->get('cached_completion_count')); + + $secondCall = $platform->invoke(new Ollama(Ollama::LLAMA_3_2), [ + 'messages' => [ + [ + 'role' => 'user', + 'content' => 'Say hello world', + ], + ], + 'model' => 'llama3.2', + ], [ + 'prompt_cache_key' => 'foo', + ]); + + $secondResult = $secondCall->getResult(); + + $this->assertSame('Hello world', $secondResult->getContent()); + $this->assertSame(10, $secondResult->getMetadata()->get('cached_prompt_count')); + $this->assertSame(10, $secondResult->getMetadata()->get('cached_completion_count')); + + $this->assertSame(3, $httpClient->getRequestsCount()); + } } diff --git a/src/platform/tests/CachedPlatformTest.php b/src/platform/tests/CachedPlatformTest.php new file mode 100644 index 000000000..e48491e67 --- /dev/null +++ b/src/platform/tests/CachedPlatformTest.php @@ -0,0 +1,18 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests; + +use PHPUnit\Framework\TestCase; + +final class CachedPlatformTest extends TestCase +{ +}