diff --git a/src/ai-bundle/config/services.php b/src/ai-bundle/config/services.php index bb5289196..f57cd478a 100644 --- a/src/ai-bundle/config/services.php +++ b/src/ai-bundle/config/services.php @@ -11,6 +11,7 @@ namespace Symfony\Component\DependencyInjection\Loader\Configurator; +use Symfony\AI\Agent\AgentInterface; use Symfony\AI\Agent\StructuredOutput\AgentProcessor as StructureOutputProcessor; use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactory; use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactoryInterface; @@ -23,6 +24,7 @@ use Symfony\AI\Agent\Toolbox\ToolResultConverter; use Symfony\AI\AiBundle\Command\AgentCallCommand; use Symfony\AI\AiBundle\Profiler\DataCollector; +use Symfony\AI\AiBundle\Profiler\TraceableAgent; use Symfony\AI\AiBundle\Profiler\TraceableToolbox; use Symfony\AI\AiBundle\Security\EventListener\IsGrantedToolAttributeListener; use Symfony\AI\Platform\Bridge\AiMlApi\ModelCatalog as AiMlApiModelCatalog; @@ -170,6 +172,12 @@ ->tag('kernel.event_listener') // profiler + ->set('ai.traceable_agent', TraceableAgent::class) + ->decorate(AgentInterface::class, priority: 5) + ->args([ + service('.inner'), + service('ai.data_collector'), + ]) ->set('ai.data_collector', DataCollector::class) ->args([ tagged_iterator('ai.traceable_platform'), diff --git a/src/ai-bundle/src/Profiler/DataCollector.php b/src/ai-bundle/src/Profiler/DataCollector.php index 546ecdabd..13a538d3d 100644 --- a/src/ai-bundle/src/Profiler/DataCollector.php +++ b/src/ai-bundle/src/Profiler/DataCollector.php @@ -18,11 +18,11 @@ use Symfony\Component\HttpFoundation\Request; use Symfony\Component\HttpFoundation\Response; use Symfony\Component\HttpKernel\DataCollector\LateDataCollectorInterface; +use Symfony\Component\VarDumper\Cloner\Data; /** * @author Christopher Hertel * - * @phpstan-import-type PlatformCallData from TraceablePlatform * @phpstan-import-type ToolCallData from TraceableToolbox */ final class DataCollector extends AbstractDataCollector implements LateDataCollectorInterface @@ -37,6 +37,11 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle */ private readonly array $toolboxes; + /** + * @var list + */ + private array $collectedCalls = []; + /** * @param TraceablePlatform[] $platforms * @param TraceableToolbox[] $toolboxes @@ -52,15 +57,55 @@ public function __construct( public function collect(Request $request, Response $response, ?\Throwable $exception = null): void { - $this->lateCollect(); } public function lateCollect(): void { + $platformCalls = []; + foreach ($this->platforms as $platform) { + $calls = $platform->calls; + foreach ($calls as $call) { + $result = $call['result']->await(); + if (isset($platform->resultCache[$result])) { + $call['result'] = $platform->resultCache[$result]; + } else { + $call['result'] = $result->getContent(); + } + + $call['model'] = $this->cloneVar($call['model']); + $call['input'] = $this->cloneVar($call['input']); + $call['options'] = $this->cloneVar($call['options']); + $call['result'] = $this->cloneVar($call['result']); + + $platformCalls[] = $call; + } + } + + $toolCalls = []; + foreach ($this->toolboxes as $toolbox) { + foreach ($toolbox->calls as $call) { + $call['call'] = $this->cloneVar($call['call']); + $call['result'] = $this->cloneVar($call['result']); + $toolCalls[] = $call; + } + } + $this->data = [ 'tools' => $this->defaultToolBox->getTools(), - 'platform_calls' => array_merge(...array_map($this->awaitCallResults(...), $this->platforms)), - 'tool_calls' => array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->calls, $this->toolboxes)), + 'platform_calls' => $platformCalls, + 'tool_calls' => $toolCalls, + 'agent_calls' => $this->collectedCalls, + ]; + } + + public function collectAgentCall(string $method, float $duration, mixed $input, mixed $result, ?\Throwable $error): void + { + $this->collectedCalls[] = [ + 'method' => $method, + 'duration' => $duration, + 'input' => $this->cloneVar($input), + 'result' => $this->cloneVar($result), + 'error' => $this->cloneVar($error), ]; } @@ -70,7 +115,12 @@ public static function getTemplate(): string } /** - * @return PlatformCallData[] + * @return array{ + * model: Data, + * input: Data, + * options: Data, + * result: Data + * }[] */ public function getPlatformCalls(): array { @@ -94,28 +144,16 @@ public function getToolCalls(): array } /** - * @return array{ - * model: string, - * input: array|string|object, - * options: array, - * result: string|iterable|object|null - * }[] + * @return list */ - private function awaitCallResults(TraceablePlatform $platform): array + public function getAgentCalls(): array { - $calls = $platform->calls; - foreach ($calls as $key => $call) { - $result = $call['result']->await(); - - if (isset($platform->resultCache[$result])) { - $call['result'] = $platform->resultCache[$result]; - } else { - $call['result'] = $result->getContent(); - } - - $calls[$key] = $call; - } + return $this->data['agent_calls'] ?? []; + } - return $calls; + public function reset(): void + { + $this->data = []; + $this->collectedCalls = []; } } diff --git a/src/ai-bundle/src/Profiler/TraceableAgent.php b/src/ai-bundle/src/Profiler/TraceableAgent.php new file mode 100644 index 000000000..a24061b63 --- /dev/null +++ b/src/ai-bundle/src/Profiler/TraceableAgent.php @@ -0,0 +1,60 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\AiBundle\Profiler; + +use Symfony\AI\Agent\AgentInterface; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Result\ResultInterface; +use Symfony\Contracts\Service\ResetInterface; + +final class TraceableAgent implements AgentInterface, ResetInterface +{ + public function __construct( + private readonly AgentInterface $decorated, + private readonly DataCollector $collector, + ) { + } + + public function call(MessageBag $messages, array $options = []): ResultInterface + { + $startTime = microtime(true); + $error = null; + $response = null; + + try { + return $response = $this->decorated->call($messages, $options); + } catch (\Throwable $e) { + $error = $e; + throw $e; + } finally { + $this->collector->collectAgentCall( + 'call', + microtime(true) - $startTime, + $messages, + $response, + $error + ); + } + } + + public function reset(): void + { + if ($this->decorated instanceof ResetInterface) { + $this->decorated->reset(); + } + } + + public function getName(): string + { + return 'TraceableAgent'; + } +} diff --git a/src/ai-bundle/templates/data_collector.html.twig b/src/ai-bundle/templates/data_collector.html.twig index ec064e634..4e18cc30c 100644 --- a/src/ai-bundle/templates/data_collector.html.twig +++ b/src/ai-bundle/templates/data_collector.html.twig @@ -1,5 +1,16 @@ {% extends '@WebProfiler/Profiler/layout.html.twig' %} +{% block head %} + {{ parent() }} + +{% endblock %} + {% block toolbar %} {% if collector.platformCalls|length > 0 %} {% set icon %} @@ -12,6 +23,10 @@ {% set text %}
+
+ Agent Calls + {{ collector.agentCalls|length }} +
Configured Platforms 1 @@ -43,22 +58,45 @@ {% endblock %} -{% macro tool_calls(toolCalls) %} - Tool call{{ toolCalls|length > 1 ? 's' }}: -
    - {% for toolCall in toolCalls %} -
  1. - {{ toolCall.name }}({{ toolCall.arguments|map((value, key) => "#{key}: #{value|json_encode}")|join(', ') }}) - (ID: {{ toolCall.id }}) -
  2. - {% endfor %} -
-{% endmacro %} - {% block panel %}

Symfony AI

+ +

Agent Calls

+ {% if collector.agentCalls|length %} + + + + + + + + + + + + {% for call in collector.agentCalls %} + + + + + + + + {% endfor %} + +
MethodDurationInputResultError
{{ call.method }}{{ (call.duration * 1000)|round(2) }} ms{{ profiler_dump(call.input) }}{{ profiler_dump(call.result) }}{{ profiler_dump(call.error) }}
+ {% else %} +
+

No agent calls were made.

+
+ {% endif %} +
+
+ {{ collector.agentCalls|length }} + Agent Calls +
1 Platforms @@ -80,6 +118,7 @@
+

Platform Calls

{% if collector.platformCalls|length %}
@@ -94,78 +133,22 @@ - - Model - {{ constant('class', call.model) }} (Version: {{ call.model.name }}) - - - Input - - {% if call.input.messages is defined %}{# expect MessageBag #} -
    - {% for message in call.input.messages %} -
  1. - {{ message.role.value|title }}: - {% if 'assistant' == message.role.value and message.hasToolCalls%} - {{ _self.tool_calls(message.toolCalls) }} - {% elseif 'tool' == message.role.value %} - Result of tool call with ID {{ message.toolCall.id }}
    - {{ message.content|nl2br }} - {% elseif 'user' == message.role.value %} - {% for item in message.content %} - {% if item.text is defined %} - {{ item.text|nl2br }} - {% else %} - - {% endif %} - {% endfor %} - {% else %} - {{ message.content|nl2br }} - {% endif %} -
  2. - {% endfor %} -
- {% else %} - {{ dump(call.input) }} - {% endif %} - - - - Options - -
    - {% for key, value in call.options %} - {% if key == 'tools' %} -
  • {{ key }}: -
      - {% for tool in value %} -
    • {{ tool.name }}
    • - {% endfor %} -
    -
  • - {% else %} -
  • {{ key }}: {{ dump(value) }}
  • - {% endif %} - {% endfor %} -
- - - - Result - - {% if call.input.messages is defined and call.result is iterable %}{# expect array of ToolCall #} - {{ _self.tool_calls(call.result) }} - {% elseif call.result is iterable %}{# expect array of Vectors #} -
    - {% for vector in call.result %} -
  1. Vector with {{ vector.dimensions }} dimensions
  2. - {% endfor %} -
- {% else %} - {{ call.result }} - {% endif %} - - + + Model + {{ profiler_dump(call.model) }} + + + Input + {{ profiler_dump(call.input) }} + + + Options + {{ profiler_dump(call.options) }} + + + Result + {{ profiler_dump(call.result) }} + {% endfor %} @@ -190,27 +173,14 @@ - {% for tool in collector.tools %} - - {{ tool.name }} - {{ tool.description }} - {{ tool.reference.class }}::{{ tool.reference.method }} - - {% if tool.parameters %} -
    - {% for name, parameter in tool.parameters.properties %} -
  • - {{ name }} ({{ parameter.type is iterable ? parameter.type|join(', ') : parameter.type }})
    - {{ parameter.description|default() }} -
  • - {% endfor %} -
- {% else %} - none - {% endif %} - - - {% endfor %} + {% for tool in collector.tools %} + + {{ tool.name }} + {{ tool.description }} + {{ profiler_dump(tool.reference) }} + {{ profiler_dump(tool.parameters) }} + + {% endfor %} {% else %} @@ -235,11 +205,15 @@ Arguments - {{ dump(call.call.arguments) }} +
{{ profiler_dump(call.call.arguments) }}
+ + + Call + {{ profiler_dump(call.call) }} Result - {{ dump(call.result) }} + {{ profiler_dump(call.result) }} diff --git a/src/ai-bundle/tests/Profiler/DataCollectorTest.php b/src/ai-bundle/tests/Profiler/DataCollectorTest.php index bed7ded34..c4ea8b562 100644 --- a/src/ai-bundle/tests/Profiler/DataCollectorTest.php +++ b/src/ai-bundle/tests/Profiler/DataCollectorTest.php @@ -42,7 +42,7 @@ public function testCollectsDataForNonStreamingResponse() $dataCollector->lateCollect(); $this->assertCount(1, $dataCollector->getPlatformCalls()); - $this->assertSame('Assistant response', $dataCollector->getPlatformCalls()[0]['result']); + $this->assertSame('Assistant response', $dataCollector->getPlatformCalls()[0]['result']->getValue(true)); } public function testCollectsDataForStreamingResponse() @@ -66,6 +66,6 @@ public function testCollectsDataForStreamingResponse() $dataCollector->lateCollect(); $this->assertCount(1, $dataCollector->getPlatformCalls()); - $this->assertSame('Assistant response', $dataCollector->getPlatformCalls()[0]['result']); + $this->assertSame('Assistant response', $dataCollector->getPlatformCalls()[0]['result']->getValue(true)); } }