Skip to content

Commit 2368215

Browse files
committed
fix(aibundle): cache store configuration
1 parent a50d26e commit 2368215

File tree

5 files changed

+188
-6
lines changed

5 files changed

+188
-6
lines changed

src/ai-bundle/config/options.php

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@
167167
->arrayPrototype()
168168
->children()
169169
->scalarNode('service')->cannotBeEmpty()->defaultValue('cache.app')->end()
170+
->scalarNode('cache_key')->end()
171+
->scalarNode('strategy')->end()
170172
->end()
171173
->end()
172174
->end()
@@ -215,7 +217,7 @@
215217
->useAttributeAsKey('name')
216218
->arrayPrototype()
217219
->children()
218-
->scalarNode('distance')->cannotBeEmpty()->end()
220+
->scalarNode('strategy')->cannotBeEmpty()->end()
219221
->end()
220222
->end()
221223
->end()

src/ai-bundle/doc/index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ Configuration
9191
# multiple collections possible per type
9292
default:
9393
collection: 'my_collection'
94+
cache:
95+
ollama:
96+
cache:
97+
service: 'cache.app'
98+
cache_key: 'ollama'
99+
strategy: 'chebyshev'
100+
memory:
101+
strategy: 'chebyshev'
94102
indexer:
95103
default:
96104
# platform: 'ai.platform.mistral'

src/ai-bundle/src/AiBundle.php

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
use Symfony\AI\AiBundle\Security\Attribute\IsGrantedTool;
2828
use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory as AnthropicPlatformFactory;
2929
use Symfony\AI\Platform\Bridge\Azure\OpenAi\PlatformFactory as AzureOpenAiPlatformFactory;
30+
use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory;
3031
use Symfony\AI\Platform\Bridge\Gemini\PlatformFactory as GeminiPlatformFactory;
3132
use Symfony\AI\Platform\Bridge\LmStudio\PlatformFactory as LmStudioPlatformFactory;
3233
use Symfony\AI\Platform\Bridge\Mistral\PlatformFactory as MistralPlatformFactory;
3334
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory as OllamaPlatformFactory;
3435
use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory as OpenAiPlatformFactory;
3536
use Symfony\AI\Platform\Bridge\OpenRouter\PlatformFactory as OpenRouterPlatformFactory;
36-
use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory;
3737
use Symfony\AI\Platform\Model;
3838
use Symfony\AI\Platform\ModelClientInterface;
3939
use Symfony\AI\Platform\Platform;
@@ -50,6 +50,8 @@
5050
use Symfony\AI\Store\Bridge\SurrealDb\Store as SurrealDbStore;
5151
use Symfony\AI\Store\Bridge\Typesense\Store as TypesenseStore;
5252
use Symfony\AI\Store\CacheStore;
53+
use Symfony\AI\Store\DistanceCalculator;
54+
use Symfony\AI\Store\DistanceStrategy;
5355
use Symfony\AI\Store\Document\Vectorizer;
5456
use Symfony\AI\Store\Indexer;
5557
use Symfony\AI\Store\InMemoryStore;
@@ -494,8 +496,24 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde
494496
foreach ($stores as $name => $store) {
495497
$arguments = [
496498
new Reference($store['service']),
499+
new Definition(DistanceCalculator::class),
497500
];
498501

502+
if (\array_key_exists('cache_key', $store) && null !== $store['cache_key']) {
503+
$arguments[2] = $store['cache_key'];
504+
}
505+
506+
if (\array_key_exists('strategy', $store) && null !== $store['strategy']) {
507+
if (!$container->hasDefinition('ai.store.distance_calculator.'.$name)) {
508+
$distanceCalculatorDefinition = new Definition(DistanceCalculator::class);
509+
$distanceCalculatorDefinition->setArgument(0, DistanceStrategy::from($store['strategy']));
510+
511+
$container->setDefinition('ai.store.distance_calculator.'.$name, $distanceCalculatorDefinition);
512+
}
513+
514+
$arguments[1] = new Reference('ai.store.distance_calculator.'.$name);
515+
}
516+
499517
$definition = new Definition(CacheStore::class);
500518
$definition
501519
->addTag('ai.store')
@@ -577,9 +595,18 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde
577595

578596
if ('memory' === $type) {
579597
foreach ($stores as $name => $store) {
580-
$arguments = [
581-
$store['distance'],
582-
];
598+
$arguments = [];
599+
600+
if (\array_key_exists('strategy', $store) && null !== $store['strategy']) {
601+
if (!$container->hasDefinition('ai.store.distance_calculator.'.$name)) {
602+
$distanceCalculatorDefinition = new Definition(DistanceCalculator::class);
603+
$distanceCalculatorDefinition->setArgument(0, DistanceStrategy::from($store['strategy']));
604+
605+
$container->setDefinition('ai.store.distance_calculator.'.$name, $distanceCalculatorDefinition);
606+
}
607+
608+
$arguments[0] = new Reference('ai.store.distance_calculator.'.$name);
609+
}
583610

584611
$definition = new Definition(InMemoryStore::class);
585612
$definition

src/ai-bundle/tests/DependencyInjection/AiBundleTest.php

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@
1919
use Symfony\AI\AiBundle\AiBundle;
2020
use Symfony\Component\Config\Definition\Exception\InvalidConfigurationException;
2121
use Symfony\Component\DependencyInjection\ContainerBuilder;
22+
use Symfony\Component\DependencyInjection\Definition;
23+
use Symfony\Component\DependencyInjection\Reference;
2224

2325
#[CoversClass(AiBundle::class)]
2426
#[UsesClass(ContainerBuilder::class)]
27+
#[UsesClass(Definition::class)]
28+
#[UsesClass(Reference::class)]
2529
class AiBundleTest extends TestCase
2630
{
2731
#[DoesNotPerformAssertions]
@@ -105,6 +109,130 @@ public function testAgentsAsToolsCannotDefineService()
105109
]);
106110
}
107111

112+
public function testCacheStoreWithCustomKeyCanBeConfigured()
113+
{
114+
$container = $this->buildContainer([
115+
'ai' => [
116+
'store' => [
117+
'cache' => [
118+
'my_cache_store_with_custom_strategy' => [
119+
'service' => 'cache.system',
120+
'cache_key' => 'random',
121+
],
122+
],
123+
],
124+
],
125+
]);
126+
127+
$this->assertTrue($container->hasDefinition('ai.store.cache.my_cache_store_with_custom_strategy'));
128+
$this->assertFalse($container->hasDefinition('ai.store.distance_calculator.my_cache_store_with_custom_strategy'));
129+
130+
$definition = $container->getDefinition('ai.store.cache.my_cache_store_with_custom_strategy');
131+
132+
$this->assertCount(3, $definition->getArguments());
133+
$this->assertInstanceOf(Reference::class, $definition->getArgument(0));
134+
$this->assertSame('cache.system', (string) $definition->getArgument(0));
135+
$this->assertSame('random', $definition->getArgument(2));
136+
}
137+
138+
public function testCacheStoreWithCustomStrategyCanBeConfigured()
139+
{
140+
$container = $this->buildContainer([
141+
'ai' => [
142+
'store' => [
143+
'cache' => [
144+
'my_cache_store_with_custom_strategy' => [
145+
'service' => 'cache.system',
146+
'strategy' => 'chebyshev',
147+
],
148+
],
149+
],
150+
],
151+
]);
152+
153+
$this->assertTrue($container->hasDefinition('ai.store.cache.my_cache_store_with_custom_strategy'));
154+
$this->assertTrue($container->hasDefinition('ai.store.distance_calculator.my_cache_store_with_custom_strategy'));
155+
156+
$definition = $container->getDefinition('ai.store.cache.my_cache_store_with_custom_strategy');
157+
158+
$this->assertCount(2, $definition->getArguments());
159+
$this->assertInstanceOf(Reference::class, $definition->getArgument(0));
160+
$this->assertSame('cache.system', (string) $definition->getArgument(0));
161+
$this->assertInstanceOf(Reference::class, $definition->getArgument(1));
162+
$this->assertSame('ai.store.distance_calculator.my_cache_store_with_custom_strategy', (string) $definition->getArgument(1));
163+
}
164+
165+
public function testCacheStoreWithCustomStrategyAndKeyCanBeConfigured()
166+
{
167+
$container = $this->buildContainer([
168+
'ai' => [
169+
'store' => [
170+
'cache' => [
171+
'my_cache_store_with_custom_strategy' => [
172+
'service' => 'cache.system',
173+
'cache_key' => 'random',
174+
'strategy' => 'chebyshev',
175+
],
176+
],
177+
],
178+
],
179+
]);
180+
181+
$this->assertTrue($container->hasDefinition('ai.store.cache.my_cache_store_with_custom_strategy'));
182+
$this->assertTrue($container->hasDefinition('ai.store.distance_calculator.my_cache_store_with_custom_strategy'));
183+
184+
$definition = $container->getDefinition('ai.store.cache.my_cache_store_with_custom_strategy');
185+
186+
$this->assertCount(3, $definition->getArguments());
187+
$this->assertInstanceOf(Reference::class, $definition->getArgument(0));
188+
$this->assertSame('cache.system', (string) $definition->getArgument(0));
189+
$this->assertSame('random', $definition->getArgument(2));
190+
$this->assertInstanceOf(Reference::class, $definition->getArgument(1));
191+
$this->assertSame('ai.store.distance_calculator.my_cache_store_with_custom_strategy', (string) $definition->getArgument(1));
192+
}
193+
194+
public function testInMemoryStoreWithoutCustomStrategyCanBeConfigured()
195+
{
196+
$container = $this->buildContainer([
197+
'ai' => [
198+
'store' => [
199+
'memory' => [
200+
'my_memory_store_with_custom_strategy' => [],
201+
],
202+
],
203+
],
204+
]);
205+
206+
$this->assertTrue($container->hasDefinition('ai.store.memory.my_memory_store_with_custom_strategy'));
207+
208+
$definition = $container->getDefinition('ai.store.memory.my_memory_store_with_custom_strategy');
209+
$this->assertCount(0, $definition->getArguments());
210+
}
211+
212+
public function testInMemoryStoreWithCustomStrategyCanBeConfigured()
213+
{
214+
$container = $this->buildContainer([
215+
'ai' => [
216+
'store' => [
217+
'memory' => [
218+
'my_memory_store_with_custom_strategy' => [
219+
'strategy' => 'chebyshev',
220+
],
221+
],
222+
],
223+
],
224+
]);
225+
226+
$this->assertTrue($container->hasDefinition('ai.store.memory.my_memory_store_with_custom_strategy'));
227+
$this->assertTrue($container->hasDefinition('ai.store.distance_calculator.my_memory_store_with_custom_strategy'));
228+
229+
$definition = $container->getDefinition('ai.store.memory.my_memory_store_with_custom_strategy');
230+
231+
$this->assertCount(1, $definition->getArguments());
232+
$this->assertInstanceOf(Reference::class, $definition->getArgument(0));
233+
$this->assertSame('ai.store.distance_calculator.my_memory_store_with_custom_strategy', (string) $definition->getArgument(0));
234+
}
235+
108236
private function buildContainer(array $configuration): ContainerBuilder
109237
{
110238
$container = new ContainerBuilder();
@@ -205,6 +333,19 @@ private function getFullConfig(): array
205333
'my_cache_store' => [
206334
'service' => 'cache.system',
207335
],
336+
'my_cache_store_with_custom_key' => [
337+
'service' => 'cache.system',
338+
'cache_key' => 'bar',
339+
],
340+
'my_cache_store_with_custom_strategy' => [
341+
'service' => 'cache.system',
342+
'strategy' => 'chebyshev',
343+
],
344+
'my_cache_store_with_custom_strategy_and_custom_key' => [
345+
'service' => 'cache.system',
346+
'cache_key' => 'bar',
347+
'strategy' => 'chebyshev',
348+
],
208349
],
209350
'chroma_db' => [
210351
'my_chroma_store' => [
@@ -230,7 +371,7 @@ private function getFullConfig(): array
230371
],
231372
'memory' => [
232373
'my_memory_store' => [
233-
'distance' => 'cosine',
374+
'strategy' => 'cosine',
234375
],
235376
],
236377
'mongodb' => [

src/store/src/CacheStore.php

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ public function __construct(
3232
if (!interface_exists(CacheItemPoolInterface::class)) {
3333
throw new RuntimeException('For using the CacheStore as vector store, a PSR-6 cache implementation is required. Try running "composer require symfony/cache" or another PSR-6 compatible cache.');
3434
}
35+
36+
if (!interface_exists(CacheInterface::class)) {
37+
throw new RuntimeException('For using the CacheStore as vector store, a symfony/contracts cache implementation is required. Try running "composer require symfony/cache" or another symfony/contracts compatible cache.');
38+
}
3539
}
3640

3741
public function add(VectorDocument ...$documents): void

0 commit comments

Comments
 (0)