Skip to content

Commit b3ebfca

Browse files
feat(platform): Add support for Google vertex AI
- Adds tests to verify the behavior
1 parent f6ad490 commit b3ebfca

17 files changed

+793
-19
lines changed

src/platform/src/Bridge/VertexAi/Contract/AssistantMessageNormalizer.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public function normalize(mixed $data, ?string $format = null, array $context =
3939
$normalized = [];
4040

4141
if (isset($data->content)) {
42-
$normalized['text'] = $data->content;
42+
$normalized[] = ['text' => $data->content];
4343
}
4444

4545
if (isset($data->toolCalls[0])) {
@@ -52,7 +52,7 @@ public function normalize(mixed $data, ?string $format = null, array $context =
5252
}
5353
}
5454

55-
return [$normalized];
55+
return $normalized;
5656
}
5757

5858
protected function supportedDataClass(): string

src/platform/src/Bridge/VertexAi/Contract/MessageBagNormalizer.php

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
use Symfony\AI\Platform\Bridge\VertexAi\Gemini\Model;
1515
use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer;
16+
use Symfony\AI\Platform\Message\MessageBag;
1617
use Symfony\AI\Platform\Message\MessageBagInterface;
1718
use Symfony\AI\Platform\Message\Role;
1819
use Symfony\AI\Platform\Model as BaseModel;
@@ -53,7 +54,7 @@ public function normalize(mixed $data, ?string $format = null, array $context =
5354
foreach ($data->withoutSystemMessage()->getMessages() as $message) {
5455
$requestData['contents'][] = [
5556
'role' => $message->getRole()->equals(Role::Assistant) ? 'model' : 'user',
56-
'parts' => [['text' => $this->normalizer->normalize($message, $format, $context)]],
57+
'parts' => $this->normalizer->normalize($message, $format, $context),
5758
];
5859
}
5960

@@ -62,7 +63,7 @@ public function normalize(mixed $data, ?string $format = null, array $context =
6263

6364
protected function supportedDataClass(): string
6465
{
65-
return MessageBagInterface::class;
66+
return MessageBag::class;
6667
}
6768

6869
protected function supportsModel(BaseModel $model): bool

src/platform/src/Bridge/VertexAi/Contract/ToolNormalizer.php

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ public function normalize(mixed $data, ?string $format = null, array $context =
4141
'name' => $data->name,
4242
'description' => $data->description,
4343
'parameters' => $parameters,
44-
'response' => $parameters,
4544
];
4645
}
4746

src/platform/src/Bridge/VertexAi/Contract/UserMessageNormalizer.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ final class UserMessageNormalizer extends ModelContractNormalizer
2626
/**
2727
* @param UserMessage $data
2828
*
29-
* @return list<array{inline_data?: array{mime_type: string, data: string}}>
29+
* @return list<array{inlineData?: array{mimeType: string, data: string}}>
3030
*/
3131
public function normalize(mixed $data, ?string $format = null, array $context = []): array
3232
{

src/platform/src/Bridge/VertexAi/Embeddings/Model.php

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class Model extends BaseModel
2727
public const TEXT_MULTILINGUAL_EMBEDDING_002 = 'text-multilingual-embedding-002';
2828

2929
/**
30-
* @param array{task_type?: TaskType} $options
3130
* @see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api for various options
3231
*/
3332
public function __construct(string $name = self::GEMINI_EMBEDDING_001, array $options = [])

src/platform/src/Bridge/VertexAi/Embeddings/ModelClient.php

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,12 @@ public function request(BaseModel $model, array|string $payload, array $options
5050

5151
$modelOptions = $model->getOptions();
5252

53-
if (isset($modelOptions['task_type']) && $modelOptions['task_type'] instanceof TaskType) {
54-
$modelOptions['task_type'] = $modelOptions['task_type']->value;
55-
}
56-
5753
$payload = [
5854
'instances' => array_map(
5955
static fn (string $text) => [
6056
'content' => ['parts' => [['text' => $text]]],
6157
'title' => $options['title'] ?? null,
62-
'task_type' => isset($modelOptions['task_type']) ? $modelOptions['task_type']->value : null,
58+
'task_type' => $modelOptions['task_type'] ?? null,
6359
],
6460
\is_array($payload) ? $payload : [$payload],
6561
),

src/platform/src/Bridge/VertexAi/Gemini/Model.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ final class Model extends BaseModel
2727

2828
/**
2929
* @param array<string, mixed> $options The default options for the model usage
30+
*
3031
* @see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference for more details
3132
*/
3233
public function __construct(string $name = self::GEMINI_2_5_PRO, array $options = [])

src/platform/src/Bridge/VertexAi/Gemini/ResultConverter.php

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,21 +115,21 @@ private function convertStream(HttpResponse $result): \Generator
115115
* text?: string
116116
* }[]
117117
* }
118-
* }[] $choices
118+
* } $choices
119119
*/
120120
private function convertChoice(array $choices): ToolCallResult|TextResult
121121
{
122-
$contentPart = $choices[0]['content']['parts'][0] ?? [];
122+
$content = $choices['content']['parts'][0] ?? [];
123123

124-
if (isset($contentPart['functionCall'])) {
125-
return new ToolCallResult($this->convertToolCall($contentPart['functionCall']));
124+
if (isset($content['functionCall'])) {
125+
return new ToolCallResult($this->convertToolCall($content['functionCall']));
126126
}
127127

128-
if (isset($contentPart['text'])) {
129-
return new TextResult($contentPart['text']);
128+
if (isset($content['text'])) {
129+
return new TextResult($content['text']);
130130
}
131131

132-
throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choices[0]['finishReason']));
132+
throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choices['finishReason']));
133133
}
134134

135135
/**
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Platform\Tests\Bridge\VertexAi\Contract;
13+
14+
use PHPUnit\Framework\Attributes\CoversClass;
15+
use PHPUnit\Framework\Attributes\DataProvider;
16+
use PHPUnit\Framework\Attributes\Small;
17+
use PHPUnit\Framework\Attributes\UsesClass;
18+
use PHPUnit\Framework\TestCase;
19+
use Symfony\AI\Platform\Bridge\VertexAi\Contract\AssistantMessageNormalizer;
20+
use Symfony\AI\Platform\Bridge\VertexAi\Gemini\Model;
21+
use Symfony\AI\Platform\Contract;
22+
use Symfony\AI\Platform\Message\AssistantMessage;
23+
use Symfony\AI\Platform\Model as BaseModel;
24+
use Symfony\AI\Platform\Result\ToolCall;
25+
26+
#[Small]
27+
#[CoversClass(AssistantMessageNormalizer::class)]
28+
#[UsesClass(Model::class)]
29+
#[UsesClass(AssistantMessage::class)]
30+
#[UsesClass(BaseModel::class)]
31+
#[UsesClass(ToolCall::class)]
32+
final class AssistantMessageNormalizerTest extends TestCase
33+
{
34+
public function testSupportsNormalization()
35+
{
36+
$normalizer = new AssistantMessageNormalizer();
37+
38+
$this->assertTrue($normalizer->supportsNormalization(new AssistantMessage('Hello'), context: [
39+
Contract::CONTEXT_MODEL => new Model(),
40+
]));
41+
$this->assertFalse($normalizer->supportsNormalization('not an assistant message'));
42+
}
43+
44+
public function testGetSupportedTypes()
45+
{
46+
$normalizer = new AssistantMessageNormalizer();
47+
48+
$this->assertSame([AssistantMessage::class => true], $normalizer->getSupportedTypes(null));
49+
}
50+
51+
#[DataProvider('normalizeDataProvider')]
52+
public function testNormalize(AssistantMessage $message, array $expectedOutput)
53+
{
54+
$normalizer = new AssistantMessageNormalizer();
55+
56+
$normalized = $normalizer->normalize($message);
57+
58+
$this->assertSame($expectedOutput, $normalized);
59+
}
60+
61+
/**
62+
* @return iterable<string, array{
63+
* AssistantMessage,
64+
* array{text?: string, functionCall?: array{name: string, args?: mixed}}[]
65+
* }>
66+
*/
67+
public static function normalizeDataProvider(): iterable
68+
{
69+
yield 'assistant message' => [
70+
new AssistantMessage('Great to meet you. What would you like to know?'),
71+
[['text' => 'Great to meet you. What would you like to know?']],
72+
];
73+
yield 'function call' => [
74+
new AssistantMessage(toolCalls: [new ToolCall('name1', 'name1', ['arg1' => '123'])]),
75+
['functionCall' => ['name' => 'name1', 'args' => ['arg1' => '123']]],
76+
];
77+
yield 'function call without parameters' => [
78+
new AssistantMessage(toolCalls: [new ToolCall('name1', 'name1')]),
79+
['functionCall' => ['name' => 'name1']],
80+
];
81+
}
82+
}
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Platform\Tests\Bridge\VertexAi\Contract;
13+
14+
use PHPUnit\Framework\Attributes\CoversClass;
15+
use PHPUnit\Framework\Attributes\DataProvider;
16+
use PHPUnit\Framework\Attributes\Medium;
17+
use PHPUnit\Framework\Attributes\UsesClass;
18+
use PHPUnit\Framework\TestCase;
19+
use Symfony\AI\Platform\Bridge\VertexAi\Contract\AssistantMessageNormalizer;
20+
use Symfony\AI\Platform\Bridge\VertexAi\Contract\MessageBagNormalizer;
21+
use Symfony\AI\Platform\Bridge\VertexAi\Contract\UserMessageNormalizer;
22+
use Symfony\AI\Platform\Bridge\VertexAi\Gemini\Model;
23+
use Symfony\AI\Platform\Contract;
24+
use Symfony\AI\Platform\Message\AssistantMessage;
25+
use Symfony\AI\Platform\Message\Content\Image;
26+
use Symfony\AI\Platform\Message\Message;
27+
use Symfony\AI\Platform\Message\MessageBag;
28+
use Symfony\AI\Platform\Message\UserMessage;
29+
use Symfony\AI\Platform\Model as BaseModel;
30+
use Symfony\Component\Serializer\Normalizer\NormalizerInterface;
31+
32+
#[Medium]
33+
#[CoversClass(MessageBagNormalizer::class)]
34+
#[CoversClass(UserMessageNormalizer::class)]
35+
#[CoversClass(AssistantMessageNormalizer::class)]
36+
#[UsesClass(BaseModel::class)]
37+
#[UsesClass(Model::class)]
38+
#[UsesClass(MessageBag::class)]
39+
#[UsesClass(UserMessage::class)]
40+
#[UsesClass(AssistantMessage::class)]
41+
final class MessageBagNormalizerTest extends TestCase
42+
{
43+
public function testSupportsNormalization()
44+
{
45+
$normalizer = new MessageBagNormalizer();
46+
47+
$this->assertTrue($normalizer->supportsNormalization(new MessageBag(), context: [
48+
Contract::CONTEXT_MODEL => new Model(),
49+
]));
50+
$this->assertFalse($normalizer->supportsNormalization('not a message bag'));
51+
}
52+
53+
public function testGetSupportedTypes()
54+
{
55+
$normalizer = new MessageBagNormalizer();
56+
57+
$expected = [
58+
MessageBag::class => true,
59+
];
60+
61+
$this->assertSame($expected, $normalizer->getSupportedTypes(null));
62+
}
63+
64+
#[DataProvider('provideMessageBagData')]
65+
public function testNormalize(MessageBag $bag, array $expected)
66+
{
67+
$normalizer = new MessageBagNormalizer();
68+
69+
// Set up the inner normalizers
70+
$userMessageNormalizer = new UserMessageNormalizer();
71+
$assistantMessageNormalizer = new AssistantMessageNormalizer();
72+
73+
// Mock a normalizer that delegates to the appropriate concrete normalizer
74+
$mockNormalizer = $this->createMock(NormalizerInterface::class);
75+
$mockNormalizer->method('normalize')
76+
->willReturnCallback(function ($message) use ($userMessageNormalizer, $assistantMessageNormalizer): ?array {
77+
if ($message instanceof UserMessage) {
78+
return $userMessageNormalizer->normalize($message);
79+
}
80+
if ($message instanceof AssistantMessage) {
81+
return $assistantMessageNormalizer->normalize($message);
82+
}
83+
84+
return null;
85+
});
86+
87+
$normalizer->setNormalizer($mockNormalizer);
88+
89+
$normalized = $normalizer->normalize($bag);
90+
91+
$this->assertEquals($expected, $normalized);
92+
}
93+
94+
/**
95+
* @return iterable<array{0: MessageBag, 1: array}>
96+
*/
97+
public static function provideMessageBagData(): iterable
98+
{
99+
yield 'simple text' => [
100+
new MessageBag(Message::ofUser('Write a story about a magic backpack.')),
101+
[
102+
'contents' => [
103+
['role' => 'user', 'parts' => [['text' => 'Write a story about a magic backpack.']]],
104+
],
105+
],
106+
];
107+
108+
yield 'text with image' => [
109+
new MessageBag(
110+
Message::ofUser('Tell me about this instrument', Image::fromFile(\dirname(__DIR__, 6).'/fixtures/image.jpg'))
111+
),
112+
[
113+
'contents' => [
114+
['role' => 'user', 'parts' => [
115+
['text' => 'Tell me about this instrument'],
116+
['inlineData' => ['mimeType' => 'image/jpeg', 'data' => base64_encode(file_get_contents(\dirname(__DIR__, 6).'/fixtures/image.jpg'))]],
117+
]],
118+
],
119+
],
120+
];
121+
122+
yield 'with assistant message' => [
123+
new MessageBag(
124+
Message::ofUser('Hello'),
125+
Message::ofAssistant('Great to meet you. What would you like to know?'),
126+
Message::ofUser('I have two dogs in my house. How many paws are in my house?'),
127+
),
128+
[
129+
'contents' => [
130+
['role' => 'user', 'parts' => [['text' => 'Hello']]],
131+
['role' => 'model', 'parts' => [['text' => 'Great to meet you. What would you like to know?']]],
132+
['role' => 'user', 'parts' => [['text' => 'I have two dogs in my house. How many paws are in my house?']]],
133+
],
134+
],
135+
];
136+
137+
yield 'with system messages' => [
138+
new MessageBag(
139+
Message::forSystem('You are a cat. Your name is Neko.'),
140+
Message::ofUser('Hello there'),
141+
),
142+
[
143+
'contents' => [
144+
['role' => 'user', 'parts' => [['text' => 'Hello there']]],
145+
],
146+
'systemInstruction' => [
147+
'parts' => [['text' => 'You are a cat. Your name is Neko.']],
148+
],
149+
],
150+
];
151+
}
152+
}

0 commit comments

Comments
 (0)