diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index 04c391f7..580a8668 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -27,13 +27,13 @@ use Symfony\AI\AiBundle\Security\Attribute\IsGrantedTool; use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory as AnthropicPlatformFactory; use Symfony\AI\Platform\Bridge\Azure\OpenAi\PlatformFactory as AzureOpenAiPlatformFactory; +use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory; use Symfony\AI\Platform\Bridge\Gemini\PlatformFactory as GeminiPlatformFactory; use Symfony\AI\Platform\Bridge\LmStudio\PlatformFactory as LmStudioPlatformFactory; use Symfony\AI\Platform\Bridge\Mistral\PlatformFactory as MistralPlatformFactory; use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory as OllamaPlatformFactory; use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory as OpenAiPlatformFactory; use Symfony\AI\Platform\Bridge\OpenRouter\PlatformFactory as OpenRouterPlatformFactory; -use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory; use Symfony\AI\Platform\Model; use Symfony\AI\Platform\ModelClientInterface; use Symfony\AI\Platform\Platform; @@ -56,6 +56,7 @@ use Symfony\AI\Store\StoreInterface; use Symfony\AI\Store\VectorStoreInterface; use Symfony\Component\Config\Definition\Configurator\DefinitionConfigurator; +use Symfony\Component\DependencyInjection\Attribute\Target; use Symfony\Component\DependencyInjection\ChildDefinition; use Symfony\Component\DependencyInjection\ContainerBuilder; use Symfony\Component\DependencyInjection\ContainerInterface; @@ -107,9 +108,6 @@ public function loadExtension(array $config, ContainerConfigurator $container, C foreach ($config['agent'] as $agentName => $agent) { $this->processAgentConfig($agentName, $agent, $builder); } - if (1 === \count($config['agent']) && isset($agentName)) { - $builder->setAlias(AgentInterface::class, 'ai.agent.'.$agentName); - } foreach ($config['store'] ?? [] as $type => $store) { $this->processStoreConfig($type, $store, $builder); @@ -460,6 +458,7 @@ private function processAgentConfig(string $name, array $config, ContainerBuilde ; $container->setDefinition('ai.agent.'.$name, $agentDefinition); + $container->registerAliasForArgument('ai.agent.'.$name, AgentInterface::class, (new Target($name.'Agent'))->getParsedName()); } /** diff --git a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php index ba0b1993..69305d77 100644 --- a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php +++ b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php @@ -30,6 +30,21 @@ public function testExtensionLoadDoesNotThrow() $this->buildContainer($this->getFullConfig()); } + public function testInjectionAgentAliasIsRegistered() + { + $container = $this->buildContainer([ + 'ai' => [ + 'agent' => [ + 'my_agent' => [ + 'model' => ['class' => 'Symfony\AI\Platform\Bridge\OpenAi\Gpt'], + ], + ], + ], + ]); + + $this->assertTrue($container->hasAlias('Symfony\AI\Agent\AgentInterface $myAgentAgent')); + } + #[TestWith([true], 'enabled')] #[TestWith([false], 'disabled')] public function testFaultTolerantAgentSpecificToolbox(bool $enabled)