Skip to content

Commit 2206f1e

Browse files
valtzutryvin
authored andcommitted
feat: add Google Gemini tool support (#331)
Add tool support to Google Gemini. Extracted from #320 with updates after #326 Co-authored-by: Vin Souza <[email protected]>
1 parent 500970c commit 2206f1e

File tree

10 files changed

+532
-15
lines changed

10 files changed

+532
-15
lines changed

examples/google/toolcall.php

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
use Symfony\AI\Agent\Agent;
13+
use Symfony\AI\Agent\Toolbox\AgentProcessor;
14+
use Symfony\AI\Agent\Toolbox\Tool\Clock;
15+
use Symfony\AI\Agent\Toolbox\Toolbox;
16+
use Symfony\AI\Platform\Bridge\Google\Gemini;
17+
use Symfony\AI\Platform\Bridge\Google\PlatformFactory;
18+
use Symfony\AI\Platform\Message\Message;
19+
use Symfony\AI\Platform\Message\MessageBag;
20+
use Symfony\Component\Dotenv\Dotenv;
21+
22+
require_once dirname(__DIR__, 2).'/vendor/autoload.php';
23+
(new Dotenv())->loadEnv(dirname(__DIR__, 2).'/.env');
24+
25+
if (empty($_ENV['GOOGLE_API_KEY'])) {
26+
echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL;
27+
exit(1);
28+
}
29+
30+
$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']);
31+
$llm = new Gemini(Gemini::GEMINI_2_FLASH);
32+
33+
$toolbox = Toolbox::create(new Clock());
34+
$processor = new AgentProcessor($toolbox);
35+
$chain = new Agent($platform, $llm, [$processor], [$processor]);
36+
37+
$messages = new MessageBag(Message::ofUser('What time is it?'));
38+
$response = $chain->call($messages);
39+
40+
echo $response->getContent().\PHP_EOL;

src/platform/src/Bridge/Google/Contract/AssistantMessageNormalizer.php

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,12 @@
1515
use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer;
1616
use Symfony\AI\Platform\Message\AssistantMessage;
1717
use Symfony\AI\Platform\Model;
18-
use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface;
19-
use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait;
2018

2119
/**
2220
* @author Christopher Hertel <[email protected]>
2321
*/
24-
final class AssistantMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface
22+
final class AssistantMessageNormalizer extends ModelContractNormalizer
2523
{
26-
use NormalizerAwareTrait;
27-
2824
protected function supportedDataClass(): string
2925
{
3026
return AssistantMessage::class;
@@ -42,8 +38,23 @@ protected function supportsModel(Model $model): bool
4238
*/
4339
public function normalize(mixed $data, ?string $format = null, array $context = []): array
4440
{
45-
return [
46-
['text' => $data->content],
47-
];
41+
$normalized = [];
42+
43+
if (isset($data->content)) {
44+
$normalized['text'] = $data->content;
45+
}
46+
47+
if (isset($data->toolCalls[0])) {
48+
$normalized['functionCall'] = [
49+
'id' => $data->toolCalls[0]->id,
50+
'name' => $data->toolCalls[0]->name,
51+
];
52+
53+
if ($data->toolCalls[0]->arguments) {
54+
$normalized['functionCall']['args'] = $data->toolCalls[0]->arguments;
55+
}
56+
}
57+
58+
return [$normalized];
4859
}
4960
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Platform\Bridge\Google\Contract;
13+
14+
use Symfony\AI\Platform\Bridge\Google\Gemini;
15+
use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer;
16+
use Symfony\AI\Platform\Message\ToolCallMessage;
17+
use Symfony\AI\Platform\Model;
18+
19+
/**
20+
* @author Valtteri R <[email protected]>
21+
*/
22+
final class ToolCallMessageNormalizer extends ModelContractNormalizer
23+
{
24+
protected function supportedDataClass(): string
25+
{
26+
return ToolCallMessage::class;
27+
}
28+
29+
protected function supportsModel(Model $model): bool
30+
{
31+
return $model instanceof Gemini;
32+
}
33+
34+
/**
35+
* @param ToolCallMessage $data
36+
*
37+
* @return array{
38+
* functionResponse: array{
39+
* id: string,
40+
* name: string,
41+
* response: array<int|string, mixed>
42+
* }
43+
* }[]
44+
*/
45+
public function normalize(mixed $data, ?string $format = null, array $context = []): array
46+
{
47+
$responseContent = json_validate($data->content) ? json_decode($data->content, true) : $data->content;
48+
49+
return [[
50+
'functionResponse' => array_filter([
51+
'id' => $data->toolCall->id,
52+
'name' => $data->toolCall->name,
53+
'response' => \is_array($responseContent) ? $responseContent : [
54+
'rawResponse' => $responseContent, // Gemini expects the response to be an object, but not everyone uses objects as their responses.
55+
],
56+
]),
57+
]];
58+
}
59+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Platform\Bridge\Google\Contract;
13+
14+
use Symfony\AI\Platform\Bridge\Google\Gemini;
15+
use Symfony\AI\Platform\Contract\JsonSchema\Factory;
16+
use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer;
17+
use Symfony\AI\Platform\Model;
18+
use Symfony\AI\Platform\Tool\Tool;
19+
20+
/**
21+
* @author Valtteri R <[email protected]>
22+
*
23+
* @phpstan-import-type JsonSchema from Factory
24+
*/
25+
final class ToolNormalizer extends ModelContractNormalizer
26+
{
27+
protected function supportedDataClass(): string
28+
{
29+
return Tool::class;
30+
}
31+
32+
protected function supportsModel(Model $model): bool
33+
{
34+
return $model instanceof Gemini;
35+
}
36+
37+
/**
38+
* @param Tool $data
39+
*
40+
* @return array{
41+
* functionDeclarations: array{
42+
* name: string,
43+
* description: string,
44+
* parameters: JsonSchema|array{type: 'object'}
45+
* }[]
46+
* }
47+
*/
48+
public function normalize(mixed $data, ?string $format = null, array $context = []): array
49+
{
50+
$parameters = $data->parameters;
51+
unset($parameters['additionalProperties']);
52+
53+
return [
54+
'functionDeclarations' => [
55+
[
56+
'description' => $data->description,
57+
'name' => $data->name,
58+
'parameters' => $parameters,
59+
],
60+
],
61+
];
62+
}
63+
}

src/platform/src/Bridge/Google/Gemini.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public function __construct(string $name = self::GEMINI_2_PRO, array $options =
3434
Capability::INPUT_MESSAGES,
3535
Capability::INPUT_IMAGE,
3636
Capability::OUTPUT_STREAMING,
37+
Capability::TOOL_CALLING,
3738
];
3839

3940
parent::__construct($name, $capabilities, $options);

src/platform/src/Bridge/Google/ModelHandler.php

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@
1414
use Symfony\AI\Platform\Exception\RuntimeException;
1515
use Symfony\AI\Platform\Model;
1616
use Symfony\AI\Platform\ModelClientInterface;
17+
use Symfony\AI\Platform\Response\Choice;
18+
use Symfony\AI\Platform\Response\ChoiceResponse;
1719
use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse;
1820
use Symfony\AI\Platform\Response\StreamResponse;
1921
use Symfony\AI\Platform\Response\TextResponse;
22+
use Symfony\AI\Platform\Response\ToolCall;
23+
use Symfony\AI\Platform\Response\ToolCallResponse;
2024
use Symfony\AI\Platform\ResponseConverterInterface;
2125
use Symfony\Component\HttpClient\EventSourceHttpClient;
2226
use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface;
@@ -59,6 +63,12 @@ public function request(Model $model, array|string $payload, array $options = []
5963

6064
$generationConfig = ['generationConfig' => $options];
6165
unset($generationConfig['generationConfig']['stream']);
66+
unset($generationConfig['generationConfig']['tools']);
67+
68+
if (isset($options['tools'])) {
69+
$generationConfig['tools'] = $options['tools'];
70+
unset($options['tools']);
71+
}
6272

6373
return $this->httpClient->request('POST', $url, [
6474
'headers' => [
@@ -83,11 +93,22 @@ public function convert(ResponseInterface $response, array $options = []): LlmRe
8393

8494
$data = $response->toArray();
8595

86-
if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
96+
if (!isset($data['candidates'][0]['content']['parts'][0])) {
8797
throw new RuntimeException('Response does not contain any content');
8898
}
8999

90-
return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']);
100+
/** @var Choice[] $choices */
101+
$choices = array_map($this->convertChoice(...), $data['candidates']);
102+
103+
if (1 !== \count($choices)) {
104+
return new ChoiceResponse(...$choices);
105+
}
106+
107+
if ($choices[0]->hasToolCall()) {
108+
return new ToolCallResponse(...$choices[0]->getToolCalls());
109+
}
110+
111+
return new TextResponse($choices[0]->getContent());
91112
}
92113

93114
private function convertStream(ResponseInterface $response): \Generator
@@ -121,12 +142,68 @@ private function convertStream(ResponseInterface $response): \Generator
121142
throw new RuntimeException('Failed to decode JSON response', 0, $e);
122143
}
123144

124-
if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
145+
/** @var Choice[] $choices */
146+
$choices = array_map($this->convertChoice(...), $data['candidates'] ?? []);
147+
148+
if (!$choices) {
125149
continue;
126150
}
127151

128-
yield $data['candidates'][0]['content']['parts'][0]['text'];
152+
if (1 !== \count($choices)) {
153+
yield new ChoiceResponse(...$choices);
154+
continue;
155+
}
156+
157+
if ($choices[0]->hasToolCall()) {
158+
yield new ToolCallResponse(...$choices[0]->getToolCalls());
159+
}
160+
161+
if ($choices[0]->hasContent()) {
162+
yield $choices[0]->getContent();
163+
}
129164
}
130165
}
131166
}
167+
168+
/**
169+
* @param array{
170+
* finishReason?: string,
171+
* content: array{
172+
* parts: array{
173+
* functionCall?: array{
174+
* id: string,
175+
* name: string,
176+
* args: mixed[]
177+
* },
178+
* text?: string
179+
* }[]
180+
* }
181+
* } $choice
182+
*/
183+
private function convertChoice(array $choice): Choice
184+
{
185+
$contentPart = $choice['content']['parts'][0] ?? [];
186+
187+
if (isset($contentPart['functionCall'])) {
188+
return new Choice(toolCalls: [$this->convertToolCall($contentPart['functionCall'])]);
189+
}
190+
191+
if (isset($contentPart['text'])) {
192+
return new Choice($contentPart['text']);
193+
}
194+
195+
throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finishReason']));
196+
}
197+
198+
/**
199+
* @param array{
200+
* id: string,
201+
* name: string,
202+
* args: mixed[]
203+
* } $toolCall
204+
*/
205+
private function convertToolCall(array $toolCall): ToolCall
206+
{
207+
return new ToolCall($toolCall['id'], $toolCall['name'], $toolCall['args']);
208+
}
132209
}

src/platform/src/Bridge/Google/PlatformFactory.php

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
use Symfony\AI\Platform\Bridge\Google\Contract\AssistantMessageNormalizer;
1515
use Symfony\AI\Platform\Bridge\Google\Contract\MessageBagNormalizer;
16+
use Symfony\AI\Platform\Bridge\Google\Contract\ToolCallMessageNormalizer;
17+
use Symfony\AI\Platform\Bridge\Google\Contract\ToolNormalizer;
1618
use Symfony\AI\Platform\Bridge\Google\Contract\UserMessageNormalizer;
1719
use Symfony\AI\Platform\Contract;
1820
use Symfony\AI\Platform\Platform;
@@ -35,6 +37,8 @@ public static function create(
3537
return new Platform([$responseHandler], [$responseHandler], Contract::create(
3638
new AssistantMessageNormalizer(),
3739
new MessageBagNormalizer(),
40+
new ToolNormalizer(),
41+
new ToolCallMessageNormalizer(),
3842
new UserMessageNormalizer(),
3943
));
4044
}

0 commit comments

Comments
 (0)