diff --git a/src/ai-bundle/config/options.php b/src/ai-bundle/config/options.php index 91ab8bbe..5e079933 100644 --- a/src/ai-bundle/config/options.php +++ b/src/ai-bundle/config/options.php @@ -167,6 +167,8 @@ ->arrayPrototype() ->children() ->scalarNode('service')->cannotBeEmpty()->defaultValue('cache.app')->end() + ->scalarNode('cache_key')->end() + ->scalarNode('strategy')->end() ->end() ->end() ->end() @@ -215,7 +217,7 @@ ->useAttributeAsKey('name') ->arrayPrototype() ->children() - ->scalarNode('distance')->cannotBeEmpty()->end() + ->scalarNode('strategy')->cannotBeEmpty()->end() ->end() ->end() ->end() diff --git a/src/ai-bundle/doc/index.rst b/src/ai-bundle/doc/index.rst index fe409dff..711e1b75 100644 --- a/src/ai-bundle/doc/index.rst +++ b/src/ai-bundle/doc/index.rst @@ -91,6 +91,14 @@ Configuration # multiple collections possible per type default: collection: 'my_collection' + cache: + research: + service: 'cache.app' + cache_key: 'research' + strategy: 'chebyshev' + memory: + ollama: + strategy: 'manhattan' indexer: default: # platform: 'ai.platform.mistral' diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index 04c391f7..0bb6e3f2 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -27,13 +27,13 @@ use Symfony\AI\AiBundle\Security\Attribute\IsGrantedTool; use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory as AnthropicPlatformFactory; use Symfony\AI\Platform\Bridge\Azure\OpenAi\PlatformFactory as AzureOpenAiPlatformFactory; +use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory; use Symfony\AI\Platform\Bridge\Gemini\PlatformFactory as GeminiPlatformFactory; use Symfony\AI\Platform\Bridge\LmStudio\PlatformFactory as LmStudioPlatformFactory; use Symfony\AI\Platform\Bridge\Mistral\PlatformFactory as MistralPlatformFactory; use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory as OllamaPlatformFactory; use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory as OpenAiPlatformFactory; use Symfony\AI\Platform\Bridge\OpenRouter\PlatformFactory as OpenRouterPlatformFactory; -use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory; use Symfony\AI\Platform\Model; use Symfony\AI\Platform\ModelClientInterface; use Symfony\AI\Platform\Platform; @@ -50,6 +50,8 @@ use Symfony\AI\Store\Bridge\SurrealDb\Store as SurrealDbStore; use Symfony\AI\Store\Bridge\Typesense\Store as TypesenseStore; use Symfony\AI\Store\CacheStore; +use Symfony\AI\Store\DistanceCalculator; +use Symfony\AI\Store\DistanceStrategy; use Symfony\AI\Store\Document\Vectorizer; use Symfony\AI\Store\Indexer; use Symfony\AI\Store\InMemoryStore; @@ -494,8 +496,24 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde foreach ($stores as $name => $store) { $arguments = [ new Reference($store['service']), + new Definition(DistanceCalculator::class), ]; + if (\array_key_exists('cache_key', $store) && null !== $store['cache_key']) { + $arguments[2] = $store['cache_key']; + } + + if (\array_key_exists('strategy', $store) && null !== $store['strategy']) { + if (!$container->hasDefinition('ai.store.distance_calculator.'.$name)) { + $distanceCalculatorDefinition = new Definition(DistanceCalculator::class); + $distanceCalculatorDefinition->setArgument(0, DistanceStrategy::from($store['strategy'])); + + $container->setDefinition('ai.store.distance_calculator.'.$name, $distanceCalculatorDefinition); + } + + $arguments[1] = new Reference('ai.store.distance_calculator.'.$name); + } + $definition = new Definition(CacheStore::class); $definition ->addTag('ai.store') @@ -577,9 +595,18 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde if ('memory' === $type) { foreach ($stores as $name => $store) { - $arguments = [ - $store['distance'], - ]; + $arguments = []; + + if (\array_key_exists('strategy', $store) && null !== $store['strategy']) { + if (!$container->hasDefinition('ai.store.distance_calculator.'.$name)) { + $distanceCalculatorDefinition = new Definition(DistanceCalculator::class); + $distanceCalculatorDefinition->setArgument(0, DistanceStrategy::from($store['strategy'])); + + $container->setDefinition('ai.store.distance_calculator.'.$name, $distanceCalculatorDefinition); + } + + $arguments[0] = new Reference('ai.store.distance_calculator.'.$name); + } $definition = new Definition(InMemoryStore::class); $definition diff --git a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php index ba0b1993..4cc22894 100644 --- a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php +++ b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php @@ -19,9 +19,13 @@ use Symfony\AI\AiBundle\AiBundle; use Symfony\Component\Config\Definition\Exception\InvalidConfigurationException; use Symfony\Component\DependencyInjection\ContainerBuilder; +use Symfony\Component\DependencyInjection\Definition; +use Symfony\Component\DependencyInjection\Reference; #[CoversClass(AiBundle::class)] #[UsesClass(ContainerBuilder::class)] +#[UsesClass(Definition::class)] +#[UsesClass(Reference::class)] class AiBundleTest extends TestCase { #[DoesNotPerformAssertions] @@ -105,6 +109,130 @@ public function testAgentsAsToolsCannotDefineService() ]); } + public function testCacheStoreWithCustomKeyCanBeConfigured() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'cache' => [ + 'my_cache_store_with_custom_strategy' => [ + 'service' => 'cache.system', + 'cache_key' => 'random', + ], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.store.cache.my_cache_store_with_custom_strategy')); + $this->assertFalse($container->hasDefinition('ai.store.distance_calculator.my_cache_store_with_custom_strategy')); + + $definition = $container->getDefinition('ai.store.cache.my_cache_store_with_custom_strategy'); + + $this->assertCount(3, $definition->getArguments()); + $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); + $this->assertSame('cache.system', (string) $definition->getArgument(0)); + $this->assertSame('random', $definition->getArgument(2)); + } + + public function testCacheStoreWithCustomStrategyCanBeConfigured() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'cache' => [ + 'my_cache_store_with_custom_strategy' => [ + 'service' => 'cache.system', + 'strategy' => 'chebyshev', + ], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.store.cache.my_cache_store_with_custom_strategy')); + $this->assertTrue($container->hasDefinition('ai.store.distance_calculator.my_cache_store_with_custom_strategy')); + + $definition = $container->getDefinition('ai.store.cache.my_cache_store_with_custom_strategy'); + + $this->assertCount(2, $definition->getArguments()); + $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); + $this->assertSame('cache.system', (string) $definition->getArgument(0)); + $this->assertInstanceOf(Reference::class, $definition->getArgument(1)); + $this->assertSame('ai.store.distance_calculator.my_cache_store_with_custom_strategy', (string) $definition->getArgument(1)); + } + + public function testCacheStoreWithCustomStrategyAndKeyCanBeConfigured() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'cache' => [ + 'my_cache_store_with_custom_strategy' => [ + 'service' => 'cache.system', + 'cache_key' => 'random', + 'strategy' => 'chebyshev', + ], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.store.cache.my_cache_store_with_custom_strategy')); + $this->assertTrue($container->hasDefinition('ai.store.distance_calculator.my_cache_store_with_custom_strategy')); + + $definition = $container->getDefinition('ai.store.cache.my_cache_store_with_custom_strategy'); + + $this->assertCount(3, $definition->getArguments()); + $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); + $this->assertSame('cache.system', (string) $definition->getArgument(0)); + $this->assertSame('random', $definition->getArgument(2)); + $this->assertInstanceOf(Reference::class, $definition->getArgument(1)); + $this->assertSame('ai.store.distance_calculator.my_cache_store_with_custom_strategy', (string) $definition->getArgument(1)); + } + + public function testInMemoryStoreWithoutCustomStrategyCanBeConfigured() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'memory' => [ + 'my_memory_store_with_custom_strategy' => [], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.store.memory.my_memory_store_with_custom_strategy')); + + $definition = $container->getDefinition('ai.store.memory.my_memory_store_with_custom_strategy'); + $this->assertCount(0, $definition->getArguments()); + } + + public function testInMemoryStoreWithCustomStrategyCanBeConfigured() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'memory' => [ + 'my_memory_store_with_custom_strategy' => [ + 'strategy' => 'chebyshev', + ], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.store.memory.my_memory_store_with_custom_strategy')); + $this->assertTrue($container->hasDefinition('ai.store.distance_calculator.my_memory_store_with_custom_strategy')); + + $definition = $container->getDefinition('ai.store.memory.my_memory_store_with_custom_strategy'); + + $this->assertCount(1, $definition->getArguments()); + $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); + $this->assertSame('ai.store.distance_calculator.my_memory_store_with_custom_strategy', (string) $definition->getArgument(0)); + } + private function buildContainer(array $configuration): ContainerBuilder { $container = new ContainerBuilder(); @@ -205,6 +333,19 @@ private function getFullConfig(): array 'my_cache_store' => [ 'service' => 'cache.system', ], + 'my_cache_store_with_custom_key' => [ + 'service' => 'cache.system', + 'cache_key' => 'bar', + ], + 'my_cache_store_with_custom_strategy' => [ + 'service' => 'cache.system', + 'strategy' => 'chebyshev', + ], + 'my_cache_store_with_custom_strategy_and_custom_key' => [ + 'service' => 'cache.system', + 'cache_key' => 'bar', + 'strategy' => 'chebyshev', + ], ], 'chroma_db' => [ 'my_chroma_store' => [ @@ -230,7 +371,7 @@ private function getFullConfig(): array ], 'memory' => [ 'my_memory_store' => [ - 'distance' => 'cosine', + 'strategy' => 'cosine', ], ], 'mongodb' => [ diff --git a/src/store/src/CacheStore.php b/src/store/src/CacheStore.php index 1223eda2..8d0aa781 100644 --- a/src/store/src/CacheStore.php +++ b/src/store/src/CacheStore.php @@ -32,6 +32,10 @@ public function __construct( if (!interface_exists(CacheItemPoolInterface::class)) { throw new RuntimeException('For using the CacheStore as vector store, a PSR-6 cache implementation is required. Try running "composer require symfony/cache" or another PSR-6 compatible cache.'); } + + if (!interface_exists(CacheInterface::class)) { + throw new RuntimeException('For using the CacheStore as vector store, a symfony/contracts cache implementation is required. Try running "composer require symfony/cache" or another symfony/contracts compatible cache.'); + } } public function add(VectorDocument ...$documents): void