Skip to content

Commit d4a2fba

Browse files
committed
fix
1 parent 3736faf commit d4a2fba

File tree

4 files changed

+189
-6
lines changed

4 files changed

+189
-6
lines changed

src/agent/src/Bridge/Meilisearch/MessageStore.php

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use Symfony\AI\Agent\Chat\InitializableMessageStoreInterface;
1515
use Symfony\AI\Agent\Chat\MessageStoreInterface;
1616
use Symfony\AI\Agent\Exception\InvalidArgumentException;
17+
use Symfony\AI\Agent\Exception\LogicException;
1718
use Symfony\AI\Platform\Message\AssistantMessage;
1819
use Symfony\AI\Platform\Message\Content\Audio;
1920
use Symfony\AI\Platform\Message\Content\ContentInterface;
@@ -26,7 +27,9 @@
2627
use Symfony\AI\Platform\Message\MessageBagInterface;
2728
use Symfony\AI\Platform\Message\MessageInterface;
2829
use Symfony\AI\Platform\Message\SystemMessage;
30+
use Symfony\AI\Platform\Message\ToolCallMessage;
2931
use Symfony\AI\Platform\Message\UserMessage;
32+
use Symfony\AI\Platform\Result\ToolCall;
3033
use Symfony\Contracts\HttpClient\HttpClientInterface;
3134

3235
/**
@@ -77,7 +80,7 @@ public function initialize(array $options = []): void
7780
}
7881

7982
/**
80-
* @param array<string, mixed> $payload
83+
* @param array<string, mixed>|list<array<string, mixed>> $payload
8184
*
8285
* @return array<string, mixed>
8386
*/
@@ -94,12 +97,28 @@ private function request(string $method, string $endpoint, array $payload = []):
9497
return $result->toArray();
9598
}
9699

100+
/**
101+
* @return array<string, mixed>
102+
*/
97103
private function convertToIndexableArray(MessageInterface $message): array
98104
{
105+
$toolsCalls = [];
106+
107+
if ($message instanceof AssistantMessage && $message->hasToolCalls()) {
108+
$toolsCalls = array_map(
109+
static fn (ToolCall $toolCall): array => $toolCall->jsonSerialize(),
110+
$message->toolCalls,
111+
);
112+
}
113+
114+
if ($message instanceof ToolCallMessage) {
115+
$toolsCalls = $message->toolCall->jsonSerialize();
116+
}
117+
99118
return [
100119
'id' => $message->getId()->toRfc4122(),
101120
'type' => $message::class,
102-
'content' => ($message instanceof SystemMessage || $message instanceof AssistantMessage) ? $message->content : '',
121+
'content' => ($message instanceof SystemMessage || $message instanceof AssistantMessage || $message instanceof ToolCallMessage) ? $message->content : '',
103122
'contentAsBase64' => ($message instanceof UserMessage && [] !== $message->content) ? array_map(
104123
static fn (ContentInterface $content) => [
105124
'type' => $content::class,
@@ -110,15 +129,18 @@ private function convertToIndexableArray(MessageInterface $message): array
110129
Audio::class => $content->asBase64(),
111130
ImageUrl::class,
112131
DocumentUrl::class => $content->url,
113-
default => throw new \LogicException(\sprintf('Unknown content type "%s".', $content::class)),
132+
default => throw new LogicException(\sprintf('Unknown content type "%s".', $content::class)),
114133
},
115134
],
116135
$message->content,
117136
) : [],
118-
'toolsCalls' => ($message instanceof AssistantMessage && $message->hasToolCalls()) ? $message->toolCalls : [],
137+
'toolsCalls' => $toolsCalls,
119138
];
120139
}
121140

141+
/**
142+
* @param array<string, mixed> $payload
143+
*/
122144
private function convertToMessage(array $payload): MessageInterface
123145
{
124146
$type = $payload['type'];
@@ -127,14 +149,29 @@ private function convertToMessage(array $payload): MessageInterface
127149

128150
return match ($type) {
129151
SystemMessage::class => new SystemMessage($content),
130-
AssistantMessage::class => new AssistantMessage($content, $payload['toolsCalls'] ?? []),
152+
AssistantMessage::class => new AssistantMessage($content, array_map(
153+
static fn (array $toolsCall): ToolCall => new ToolCall(
154+
$toolsCall['id'],
155+
$toolsCall['function']['name'],
156+
json_decode($toolsCall['function']['arguments'], true)
157+
),
158+
$payload['toolsCalls'],
159+
)),
131160
UserMessage::class => new UserMessage(...array_map(
132161
static fn (array $contentAsBase64) => \in_array($contentAsBase64['type'], [File::class, Image::class, Audio::class], true)
133162
? $contentAsBase64['type']::fromDataUrl($contentAsBase64['content'])
134163
: new $contentAsBase64['type']($contentAsBase64['content']),
135164
$contentAsBase64,
136165
)),
137-
default => throw new \LogicException(\sprintf('Unknown message type "%s".', $type)),
166+
ToolCallMessage::class => new ToolCallMessage(
167+
new ToolCall(
168+
$payload['toolsCalls']['id'],
169+
$payload['toolsCalls']['function']['name'],
170+
json_decode($payload['toolsCalls']['function']['arguments'], true)
171+
),
172+
$content
173+
),
174+
default => throw new LogicException(\sprintf('Unknown message type "%s".', $type)),
138175
};
139176
}
140177
}

src/agent/tests/Bridge/Meilisearch/MessageStoreTest.php

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,20 @@
1212
namespace Symfony\AI\Agent\Tests\Bridge\Meilisearch;
1313

1414
use PHPUnit\Framework\Attributes\CoversClass;
15+
use PHPUnit\Framework\Attributes\DataProvider;
1516
use PHPUnit\Framework\TestCase;
1617
use Symfony\AI\Agent\Bridge\Meilisearch\MessageStore;
18+
use Symfony\AI\Platform\Message\AssistantMessage;
19+
use Symfony\AI\Platform\Message\Content\Text;
1720
use Symfony\AI\Platform\Message\Message;
1821
use Symfony\AI\Platform\Message\MessageBag;
22+
use Symfony\AI\Platform\Message\SystemMessage;
23+
use Symfony\AI\Platform\Message\ToolCallMessage;
24+
use Symfony\AI\Platform\Message\UserMessage;
1925
use Symfony\Component\HttpClient\Exception\ClientException;
2026
use Symfony\Component\HttpClient\MockHttpClient;
2127
use Symfony\Component\HttpClient\Response\JsonMockResponse;
28+
use Symfony\Component\Uid\Uuid;
2229

2330
#[CoversClass(MessageStore::class)]
2431
final class MessageStoreTest extends TestCase
@@ -161,4 +168,143 @@ public function testStoreCannotRetrieveMessagesOnInvalidResponse()
161168
self::expectExceptionCode(400);
162169
$store->load();
163170
}
171+
172+
#[DataProvider('provideMessages')]
173+
public function testStoreCanRetrieveMessages(array $payload)
174+
{
175+
$httpClient = new MockHttpClient([
176+
new JsonMockResponse([
177+
'results' => [
178+
$payload,
179+
],
180+
], [
181+
'http_code' => 200,
182+
]),
183+
], 'http://localhost:7700');
184+
185+
$store = new MessageStore(
186+
$httpClient,
187+
'http://localhost:7700',
188+
'test',
189+
'test',
190+
);
191+
192+
$messageBag = $store->load();
193+
194+
$this->assertCount(1, $messageBag);
195+
$this->assertSame(1, $httpClient->getRequestsCount());
196+
}
197+
198+
public function testStoreCannotDeleteMessagesOnInvalidResponse()
199+
{
200+
$httpClient = new MockHttpClient([
201+
new JsonMockResponse([
202+
'message' => 'error',
203+
'code' => 'index_not_found',
204+
'type' => 'invalid_request',
205+
'link' => 'https://docs.meilisearch.com/errors#index_not_found',
206+
], [
207+
'http_code' => 400,
208+
]),
209+
], 'http://localhost:7700');
210+
211+
$store = new MessageStore(
212+
$httpClient,
213+
'http://localhost:7700',
214+
'test',
215+
'test',
216+
);
217+
218+
self::expectException(ClientException::class);
219+
self::expectExceptionMessage('HTTP 400 returned for "http://localhost:7700/indexes/test/documents".');
220+
self::expectExceptionCode(400);
221+
$store->clear();
222+
}
223+
224+
public function testStoreCanDelete()
225+
{
226+
$httpClient = new MockHttpClient([
227+
new JsonMockResponse([
228+
'taskUid' => 1,
229+
'indexUid' => 'test',
230+
'status' => 'enqueued',
231+
'type' => 'indexDeletion',
232+
'enqueuedAt' => '2025-01-01T00:00:00Z',
233+
], [
234+
'http_code' => 200,
235+
]),
236+
], 'http://localhost:7700');
237+
238+
$store = new MessageStore(
239+
$httpClient,
240+
'http://localhost:7700',
241+
'test',
242+
'test',
243+
);
244+
245+
$store->clear();
246+
247+
$this->assertSame(1, $httpClient->getRequestsCount());
248+
}
249+
250+
public static function provideMessages(): \Generator
251+
{
252+
yield UserMessage::class => [
253+
[
254+
'id' => Uuid::v7()->toRfc4122(),
255+
'type' => UserMessage::class,
256+
'content' => '',
257+
'contentAsBase64' => [
258+
[
259+
'type' => Text::class,
260+
'content' => 'What is the Symfony framework?',
261+
],
262+
],
263+
'toolsCalls' => [],
264+
],
265+
];
266+
yield SystemMessage::class => [
267+
[
268+
'id' => Uuid::v7()->toRfc4122(),
269+
'type' => SystemMessage::class,
270+
'content' => 'Hello there',
271+
'contentAsBase64' => [],
272+
'toolsCalls' => [],
273+
],
274+
];
275+
yield AssistantMessage::class => [
276+
[
277+
'id' => Uuid::v7()->toRfc4122(),
278+
'type' => AssistantMessage::class,
279+
'content' => 'Hello there',
280+
'contentAsBase64' => [],
281+
'toolsCalls' => [
282+
[
283+
'id' => '1',
284+
'name' => 'foo',
285+
'function' => [
286+
'name' => 'foo',
287+
'arguments' => '{}',
288+
],
289+
],
290+
],
291+
],
292+
];
293+
yield ToolCallMessage::class => [
294+
[
295+
'id' => Uuid::v7()->toRfc4122(),
296+
'type' => ToolCallMessage::class,
297+
'content' => 'Hello there',
298+
'contentAsBase64' => [],
299+
'toolsCalls' => [
300+
'id' => '1',
301+
'name' => 'foo',
302+
'function' => [
303+
'name' => 'foo',
304+
'arguments' => '{}',
305+
],
306+
],
307+
],
308+
];
309+
}
164310
}

0 commit comments

Comments
 (0)