Skip to content

Commit ca8f9be

Browse files
HaKIMusOskarStark
authored andcommitted
[Agent][Platform] Add support for native union types and list of polymporphic* types by using DiscriminatorMap
1 parent 3901999 commit ca8f9be

File tree

2 files changed

+166
-1
lines changed

2 files changed

+166
-1
lines changed

src/Contract/JsonSchema/Factory.php

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313

1414
use Symfony\AI\Platform\Contract\JsonSchema\Attribute\With;
1515
use Symfony\AI\Platform\Exception\InvalidArgumentException;
16+
use Symfony\Component\Serializer\Attribute\DiscriminatorMap;
1617
use Symfony\Component\TypeInfo\Type;
1718
use Symfony\Component\TypeInfo\Type\BackedEnumType;
1819
use Symfony\Component\TypeInfo\Type\BuiltinType;
1920
use Symfony\Component\TypeInfo\Type\CollectionType;
2021
use Symfony\Component\TypeInfo\Type\NullableType;
2122
use Symfony\Component\TypeInfo\Type\ObjectType;
23+
use Symfony\Component\TypeInfo\Type\UnionType;
2224
use Symfony\Component\TypeInfo\TypeIdentifier;
2325
use Symfony\Component\TypeInfo\TypeResolver\TypeResolver;
2426

@@ -47,6 +49,7 @@
4749
* minProperties?: int,
4850
* maxProperties?: int,
4951
* dependentRequired?: bool,
52+
* anyOf?: list<mixed>,
5053
* }>,
5154
* required: list<string>,
5255
* additionalProperties: false,
@@ -110,7 +113,10 @@ private function convertTypes(array $elements): ?array
110113
$schema = $this->getTypeSchema($type);
111114

112115
if ($type->isNullable()) {
113-
$schema['type'] = [$schema['type'], 'null'];
116+
// anyOf already contains the null variant when applicable; do nothing
117+
if (!isset($schema['anyOf'])) {
118+
$schema['type'] = [$schema['type'], 'null'];
119+
}
114120
} elseif (!($element instanceof \ReflectionParameter && $element->isOptional())) {
115121
$result['required'][] = $name;
116122
}
@@ -151,6 +157,21 @@ private function getTypeSchema(Type $type): array
151157
}
152158
}
153159

160+
if ($type instanceof UnionType) {
161+
// Do not handle nullables as a union but directly return the wrapped type schema
162+
if (2 === \count($type->getTypes()) && $type->isNullable() && $type instanceof NullableType) {
163+
return $this->getTypeSchema($type->getWrappedType());
164+
}
165+
166+
$variants = [];
167+
168+
foreach ($type->getTypes() as $variant) {
169+
$variants[] = $this->getTypeSchema($variant);
170+
}
171+
172+
return ['anyOf' => $variants];
173+
}
174+
154175
switch (true) {
155176
case $type->isIdentifiedBy(TypeIdentifier::INT):
156177
return ['type' => 'integer'];
@@ -168,6 +189,22 @@ private function getTypeSchema(Type $type): array
168189
if ($collectionValueType->isIdentifiedBy(TypeIdentifier::OBJECT)) {
169190
\assert($collectionValueType instanceof ObjectType);
170191

192+
// Check for the DiscriminatorMap attribute to handle polymorphic arrays
193+
$discriminatorMapping = $this->findDiscriminatorMapping($collectionValueType->getClassName());
194+
if ($discriminatorMapping) {
195+
$discriminators = [];
196+
foreach ($discriminatorMapping as $_ => $discriminator) {
197+
$discriminators[] = $this->buildProperties($discriminator);
198+
}
199+
200+
return [
201+
'type' => 'array',
202+
'items' => [
203+
'anyOf' => $discriminators,
204+
],
205+
];
206+
}
207+
171208
return [
172209
'type' => 'array',
173210
'items' => $this->buildProperties($collectionValueType->getClassName()),
@@ -195,6 +232,8 @@ private function getTypeSchema(Type $type): array
195232
}
196233

197234
// no break
235+
case $type->isIdentifiedBy(TypeIdentifier::NULL):
236+
return ['type' => 'null'];
198237
case $type->isIdentifiedBy(TypeIdentifier::STRING):
199238
default:
200239
// Fallback to string for any unhandled types
@@ -233,4 +272,34 @@ private function buildEnumSchema(string $enumClassName): array
233272
'enum' => $values,
234273
];
235274
}
275+
276+
/**
277+
* @param class-string $className
278+
*
279+
* @return array<string, class-string>|null
280+
*
281+
* @throws \ReflectionException
282+
*/
283+
private function findDiscriminatorMapping(string $className): ?array
284+
{
285+
/** @var \ReflectionAttribute<DiscriminatorMap>[] $attributes */
286+
$attributes = (new \ReflectionClass($className))->getAttributes(DiscriminatorMap::class);
287+
$result = \count($attributes) > 0 ? $attributes[array_key_first($attributes)]->newInstance() : null;
288+
289+
if (!$result) {
290+
return null;
291+
}
292+
293+
/**
294+
* In the 8.* release of symfony/serializer DiscriminatorMap removes the getMapping() method in favor of property access.
295+
* This satisfies the project's pipeline that builds against both < and >= 8.* release.
296+
* This logic can be removed once the project builds against >= 8.* only.
297+
*
298+
* @see https://github.com/symfony/ai/pull/585#issuecomment-3303631346
299+
*/
300+
$reflectionProperty = new \ReflectionProperty($result, 'mapping');
301+
$reflectionProperty->setAccessible(true);
302+
303+
return $reflectionProperty->getValue($result);
304+
}
236305
}

tests/Contract/JsonSchema/FactoryTest.php

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
use PHPUnit\Framework\TestCase;
1717
use Symfony\AI\Fixtures\StructuredOutput\ExampleDto;
1818
use Symfony\AI\Fixtures\StructuredOutput\MathReasoning;
19+
use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListOfPolymorphicTypesDto;
1920
use Symfony\AI\Fixtures\StructuredOutput\Step;
21+
use Symfony\AI\Fixtures\StructuredOutput\UnionType\UnionTypeDto;
2022
use Symfony\AI\Fixtures\StructuredOutput\User;
2123
use Symfony\AI\Fixtures\Tool\ToolNoParams;
2224
use Symfony\AI\Fixtures\Tool\ToolOptionalParam;
@@ -226,6 +228,100 @@ public function testBuildPropertiesForMathReasoningClass()
226228
$this->assertSame($expected, $actual);
227229
}
228230

231+
public function testBuildPropertiesForListOfPolymorphicTypesDto()
232+
{
233+
$expected = [
234+
'type' => 'object',
235+
'properties' => [
236+
'items' => [
237+
'type' => 'array',
238+
'items' => [
239+
'anyOf' => [
240+
[
241+
'type' => 'object',
242+
'properties' => [
243+
'name' => ['type' => 'string'],
244+
'type' => [
245+
'type' => 'string',
246+
'pattern' => '^name$',
247+
],
248+
],
249+
'required' => [
250+
'name',
251+
'type',
252+
],
253+
'additionalProperties' => false,
254+
],
255+
[
256+
'type' => 'object',
257+
'properties' => [
258+
'age' => ['type' => 'integer'],
259+
'type' => [
260+
'type' => 'string',
261+
'pattern' => '^age$',
262+
],
263+
],
264+
'required' => [
265+
'age',
266+
'type',
267+
],
268+
'additionalProperties' => false,
269+
],
270+
],
271+
],
272+
],
273+
],
274+
'required' => ['items'],
275+
'additionalProperties' => false,
276+
];
277+
278+
$actual = $this->factory->buildProperties(ListOfPolymorphicTypesDto::class);
279+
280+
$this->assertSame($expected, $actual);
281+
$this->assertSame($expected['type'], $actual['type']);
282+
$this->assertSame($expected['required'], $actual['required']);
283+
}
284+
285+
public function testBuildPropertiesForUnionTypeDto()
286+
{
287+
$expected = [
288+
'type' => 'object',
289+
'properties' => [
290+
'time' => [
291+
'anyOf' => [
292+
[
293+
'type' => 'object',
294+
'properties' => [
295+
'readableTime' => ['type' => 'string'],
296+
],
297+
'required' => ['readableTime'],
298+
'additionalProperties' => false,
299+
],
300+
[
301+
'type' => 'object',
302+
'properties' => [
303+
'timestamp' => ['type' => 'integer'],
304+
],
305+
'required' => ['timestamp'],
306+
'additionalProperties' => false,
307+
],
308+
[
309+
'type' => 'null',
310+
],
311+
],
312+
],
313+
],
314+
'required' => [],
315+
'additionalProperties' => false,
316+
];
317+
318+
$actual = $this->factory->buildProperties(UnionTypeDto::class);
319+
320+
$this->assertSame($expected, $actual);
321+
$this->assertSame($expected['type'], $actual['type']);
322+
$this->assertSame($expected['required'], $actual['required']);
323+
}
324+
229325
public function testBuildPropertiesForStepClass()
230326
{
231327
$expected = [

0 commit comments

Comments
 (0)