Skip to content

Commit 0d2b2e1

Browse files
Renrhafchr-hertel
authored andcommitted
[Platform] Add TokenOutputProcessor for Mistral
1 parent 1089e85 commit 0d2b2e1

File tree

2 files changed

+219
-0
lines changed

2 files changed

+219
-0
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Platform\Bridge\Mistral;
13+
14+
use Symfony\AI\Agent\Output;
15+
use Symfony\AI\Agent\OutputProcessorInterface;
16+
use Symfony\AI\Platform\Response\StreamResponse;
17+
use Symfony\Contracts\HttpClient\ResponseInterface;
18+
19+
/**
20+
* @author Quentin Fahrner <[email protected]>
21+
*/
22+
final class TokenOutputProcessor implements OutputProcessorInterface
23+
{
24+
public function processOutput(Output $output): void
25+
{
26+
if ($output->response instanceof StreamResponse) {
27+
// Streams have to be handled manually as the tokens are part of the streamed chunks
28+
return;
29+
}
30+
31+
$rawResponse = $output->response->getRawResponse()?->getRawObject();
32+
if (!$rawResponse instanceof ResponseInterface) {
33+
return;
34+
}
35+
36+
$metadata = $output->response->getMetadata();
37+
38+
$metadata->add(
39+
'remaining_tokens_minute',
40+
(int) $rawResponse->getHeaders(false)['x-ratelimit-limit-tokens-minute'][0],
41+
);
42+
43+
$metadata->add(
44+
'remaining_tokens_month',
45+
(int) $rawResponse->getHeaders(false)['x-ratelimit-limit-tokens-month'][0],
46+
);
47+
48+
$content = $rawResponse->toArray(false);
49+
50+
if (!\array_key_exists('usage', $content)) {
51+
return;
52+
}
53+
54+
$metadata->add('prompt_tokens', $content['usage']['prompt_tokens'] ?? null);
55+
$metadata->add('completion_tokens', $content['usage']['completion_tokens'] ?? null);
56+
$metadata->add('total_tokens', $content['usage']['total_tokens'] ?? null);
57+
}
58+
}
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Bridge\Mistral;
13+
14+
use PHPUnit\Framework\Attributes\CoversClass;
15+
use PHPUnit\Framework\Attributes\Small;
16+
use PHPUnit\Framework\Attributes\Test;
17+
use PHPUnit\Framework\Attributes\UsesClass;
18+
use PHPUnit\Framework\TestCase;
19+
use Symfony\AI\Agent\Output;
20+
use Symfony\AI\Platform\Bridge\Mistral\TokenOutputProcessor;
21+
use Symfony\AI\Platform\Message\MessageBagInterface;
22+
use Symfony\AI\Platform\Model;
23+
use Symfony\AI\Platform\Response\Metadata\Metadata;
24+
use Symfony\AI\Platform\Response\RawHttpResponse;
25+
use Symfony\AI\Platform\Response\ResponseInterface;
26+
use Symfony\AI\Platform\Response\StreamResponse;
27+
use Symfony\AI\Platform\Response\TextResponse;
28+
use Symfony\Contracts\HttpClient\ResponseInterface as SymfonyHttpResponse;
29+
30+
#[CoversClass(TokenOutputProcessor::class)]
31+
#[UsesClass(Output::class)]
32+
#[UsesClass(TextResponse::class)]
33+
#[UsesClass(StreamResponse::class)]
34+
#[UsesClass(Metadata::class)]
35+
#[Small]
36+
final class TokenOutputProcessorTest extends TestCase
37+
{
38+
#[Test]
39+
public function itHandlesStreamResponsesWithoutProcessing(): void
40+
{
41+
$processor = new TokenOutputProcessor();
42+
$streamResponse = new StreamResponse((static function () { yield 'test'; })());
43+
$output = $this->createOutput($streamResponse);
44+
45+
$processor->processOutput($output);
46+
47+
$metadata = $output->response->getMetadata();
48+
self::assertCount(0, $metadata);
49+
}
50+
51+
#[Test]
52+
public function itDoesNothingWithoutRawResponse(): void
53+
{
54+
$processor = new TokenOutputProcessor();
55+
$textResponse = new TextResponse('test');
56+
$output = $this->createOutput($textResponse);
57+
58+
$processor->processOutput($output);
59+
60+
$metadata = $output->response->getMetadata();
61+
self::assertCount(0, $metadata);
62+
}
63+
64+
#[Test]
65+
public function itAddsRemainingTokensToMetadata(): void
66+
{
67+
$processor = new TokenOutputProcessor();
68+
$textResponse = new TextResponse('test');
69+
70+
$textResponse->setRawResponse($this->createRawResponse());
71+
72+
$output = $this->createOutput($textResponse);
73+
74+
$processor->processOutput($output);
75+
76+
$metadata = $output->response->getMetadata();
77+
self::assertCount(2, $metadata);
78+
self::assertSame(1000, $metadata->get('remaining_tokens_minute'));
79+
self::assertSame(1000000, $metadata->get('remaining_tokens_month'));
80+
}
81+
82+
#[Test]
83+
public function itAddsUsageTokensToMetadata(): void
84+
{
85+
$processor = new TokenOutputProcessor();
86+
$textResponse = new TextResponse('test');
87+
88+
$rawResponse = $this->createRawResponse([
89+
'usage' => [
90+
'prompt_tokens' => 10,
91+
'completion_tokens' => 20,
92+
'total_tokens' => 30,
93+
],
94+
]);
95+
96+
$textResponse->setRawResponse($rawResponse);
97+
98+
$output = $this->createOutput($textResponse);
99+
100+
$processor->processOutput($output);
101+
102+
$metadata = $output->response->getMetadata();
103+
self::assertCount(5, $metadata);
104+
self::assertSame(1000, $metadata->get('remaining_tokens_minute'));
105+
self::assertSame(1000000, $metadata->get('remaining_tokens_month'));
106+
self::assertSame(10, $metadata->get('prompt_tokens'));
107+
self::assertSame(20, $metadata->get('completion_tokens'));
108+
self::assertSame(30, $metadata->get('total_tokens'));
109+
}
110+
111+
#[Test]
112+
public function itHandlesMissingUsageFields(): void
113+
{
114+
$processor = new TokenOutputProcessor();
115+
$textResponse = new TextResponse('test');
116+
117+
$rawResponse = $this->createRawResponse([
118+
'usage' => [
119+
// Missing some fields
120+
'prompt_tokens' => 10,
121+
],
122+
]);
123+
124+
$textResponse->setRawResponse($rawResponse);
125+
126+
$output = $this->createOutput($textResponse);
127+
128+
$processor->processOutput($output);
129+
130+
$metadata = $output->response->getMetadata();
131+
self::assertCount(5, $metadata);
132+
self::assertSame(1000, $metadata->get('remaining_tokens_minute'));
133+
self::assertSame(1000000, $metadata->get('remaining_tokens_month'));
134+
self::assertSame(10, $metadata->get('prompt_tokens'));
135+
self::assertNull($metadata->get('completion_tokens'));
136+
self::assertNull($metadata->get('total_tokens'));
137+
}
138+
139+
private function createRawResponse(array $data = []): RawHttpResponse
140+
{
141+
$rawResponse = self::createStub(SymfonyHttpResponse::class);
142+
$rawResponse->method('getHeaders')->willReturn([
143+
'x-ratelimit-limit-tokens-minute' => ['1000'],
144+
'x-ratelimit-limit-tokens-month' => ['1000000'],
145+
]);
146+
147+
$rawResponse->method('toArray')->willReturn($data);
148+
149+
return new RawHttpResponse($rawResponse);
150+
}
151+
152+
private function createOutput(ResponseInterface $response): Output
153+
{
154+
return new Output(
155+
self::createStub(Model::class),
156+
$response,
157+
self::createStub(MessageBagInterface::class),
158+
[],
159+
);
160+
}
161+
}

0 commit comments

Comments
 (0)