Skip to content

Commit 1c0df52

Browse files
VincentLangletchr-hertel
authored andcommitted
Use custom serializer for PlatformSubscriber
1 parent d06d642 commit 1c0df52

File tree

6 files changed

+107
-39
lines changed

6 files changed

+107
-39
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Fixtures\StructuredOutput;
13+
14+
use Symfony\Component\Serializer\Attribute\Ignore;
15+
use Symfony\Component\Serializer\Attribute\SerializedName;
16+
17+
final class MathReasoningWithAttributes
18+
{
19+
/**
20+
* @param Step[] $steps
21+
*/
22+
public function __construct(
23+
public array $steps,
24+
#[SerializedName('foo')]
25+
public string $finalAnswer,
26+
#[Ignore]
27+
public float $result,
28+
) {
29+
}
30+
}

src/ai-bundle/config/services.php

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
use Symfony\AI\Platform\Contract;
6363
use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser;
6464
use Symfony\AI\Platform\Contract\JsonSchema\Factory as SchemaFactory;
65+
use Symfony\AI\Platform\Serializer\StructuredOutputSerializer;
6566
use Symfony\AI\Platform\StructuredOutput\PlatformSubscriber;
6667
use Symfony\AI\Platform\StructuredOutput\ResponseFormatFactory;
6768
use Symfony\AI\Platform\StructuredOutput\ResponseFormatFactoryInterface;
@@ -122,10 +123,11 @@
122123
service('type_info.resolver')->nullOnInvalid(),
123124
])
124125
->alias(ResponseFormatFactoryInterface::class, 'ai.platform.response_format_factory')
126+
->set('ai.platform.structured_output_serializer', StructuredOutputSerializer::class)
125127
->set('ai.platform.structured_output_subscriber', PlatformSubscriber::class)
126128
->args([
127129
service('ai.agent.response_format_factory'),
128-
service('serializer'),
130+
service('ai.platform.structured_output_serializer'),
129131
])
130132
->tag('kernel.event_subscriber')
131133

src/platform/src/Contract/JsonSchema/Factory.php

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,6 @@ private function findDiscriminatorMapping(string $className): ?array
300300
* @see https://github.com/symfony/ai/pull/585#issuecomment-3303631346
301301
*/
302302
$reflectionProperty = new \ReflectionProperty($result, 'mapping');
303-
$reflectionProperty->setAccessible(true);
304303

305304
return $reflectionProperty->getValue($result);
306305
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Platform\Serializer;
13+
14+
use Symfony\Component\PropertyInfo\Extractor\PhpDocExtractor;
15+
use Symfony\Component\PropertyInfo\Extractor\ReflectionExtractor;
16+
use Symfony\Component\PropertyInfo\PropertyInfoExtractor;
17+
use Symfony\Component\Serializer\Encoder\JsonEncoder;
18+
use Symfony\Component\Serializer\Mapping\ClassDiscriminatorFromClassMetadata;
19+
use Symfony\Component\Serializer\Mapping\Factory\ClassMetadataFactory;
20+
use Symfony\Component\Serializer\Mapping\Loader\AttributeLoader;
21+
use Symfony\Component\Serializer\Normalizer\ArrayDenormalizer;
22+
use Symfony\Component\Serializer\Normalizer\BackedEnumNormalizer;
23+
use Symfony\Component\Serializer\Normalizer\ObjectNormalizer;
24+
use Symfony\Component\Serializer\Serializer;
25+
26+
class StructuredOutputSerializer extends Serializer
27+
{
28+
/*
29+
* Custom serializer made to deserialize StructuredOutput.
30+
*
31+
* Since field name are generated by the `Symfony\AI\Platform\Contract\JsonSchema\Factory`
32+
* without using the serializer (and the serializer metadata/attributes), we have to ignore them
33+
* again when deserializing the data by not passing `classMetadataFactory` to ObjectNormalizer.
34+
*/
35+
public function __construct()
36+
{
37+
$classMetadataFactory = new ClassMetadataFactory(new AttributeLoader());
38+
$discriminator = new ClassDiscriminatorFromClassMetadata($classMetadataFactory);
39+
$propertyInfo = new PropertyInfoExtractor([], [new PhpDocExtractor(), new ReflectionExtractor()]);
40+
41+
$normalizers = [
42+
new BackedEnumNormalizer(),
43+
new ObjectNormalizer(
44+
propertyTypeExtractor: $propertyInfo,
45+
classDiscriminatorResolver: $discriminator,
46+
),
47+
new ArrayDenormalizer(),
48+
];
49+
50+
parent::__construct($normalizers, [new JsonEncoder()]);
51+
}
52+
}

src/platform/src/StructuredOutput/PlatformSubscriber.php

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,8 @@
1717
use Symfony\AI\Platform\Exception\InvalidArgumentException;
1818
use Symfony\AI\Platform\Exception\MissingModelSupportException;
1919
use Symfony\AI\Platform\Result\DeferredResult;
20+
use Symfony\AI\Platform\Serializer\StructuredOutputSerializer;
2021
use Symfony\Component\EventDispatcher\EventSubscriberInterface;
21-
use Symfony\Component\PropertyInfo\Extractor\PhpDocExtractor;
22-
use Symfony\Component\PropertyInfo\Extractor\ReflectionExtractor;
23-
use Symfony\Component\PropertyInfo\PropertyInfoExtractor;
24-
use Symfony\Component\Serializer\Encoder\JsonEncoder;
25-
use Symfony\Component\Serializer\Mapping\ClassDiscriminatorFromClassMetadata;
26-
use Symfony\Component\Serializer\Mapping\Factory\ClassMetadataFactory;
27-
use Symfony\Component\Serializer\Mapping\Loader\AttributeLoader;
28-
use Symfony\Component\Serializer\Normalizer\ArrayDenormalizer;
29-
use Symfony\Component\Serializer\Normalizer\BackedEnumNormalizer;
30-
use Symfony\Component\Serializer\Normalizer\ObjectNormalizer;
31-
use Symfony\Component\Serializer\Serializer;
3222
use Symfony\Component\Serializer\SerializerInterface;
3323

3424
/**
@@ -40,29 +30,13 @@ final class PlatformSubscriber implements EventSubscriberInterface
4030

4131
private string $outputType;
4232

33+
private SerializerInterface $serializer;
34+
4335
public function __construct(
4436
private readonly ResponseFormatFactoryInterface $responseFormatFactory = new ResponseFormatFactory(),
45-
private ?SerializerInterface $serializer = null,
37+
?SerializerInterface $serializer = null,
4638
) {
47-
if (null !== $this->serializer) {
48-
return;
49-
}
50-
51-
$classMetadataFactory = new ClassMetadataFactory(new AttributeLoader());
52-
$discriminator = new ClassDiscriminatorFromClassMetadata($classMetadataFactory);
53-
$propertyInfo = new PropertyInfoExtractor([], [new PhpDocExtractor(), new ReflectionExtractor()]);
54-
55-
$normalizers = [
56-
new BackedEnumNormalizer(),
57-
new ObjectNormalizer(
58-
classMetadataFactory: $classMetadataFactory,
59-
propertyTypeExtractor: $propertyInfo,
60-
classDiscriminatorResolver: $discriminator,
61-
),
62-
new ArrayDenormalizer(),
63-
];
64-
65-
$this->serializer = new Serializer($normalizers, [new JsonEncoder()]);
39+
$this->serializer = $serializer ?? new StructuredOutputSerializer();
6640
}
6741

6842
public static function getSubscribedEvents(): array

src/platform/tests/StructuredOutput/PlatformSubscriberTest.php

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use PHPUnit\Framework\TestCase;
1616
use Symfony\AI\Fixtures\SomeStructure;
1717
use Symfony\AI\Fixtures\StructuredOutput\MathReasoning;
18+
use Symfony\AI\Fixtures\StructuredOutput\MathReasoningWithAttributes;
1819
use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListItemAge;
1920
use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListItemName;
2021
use Symfony\AI\Fixtures\StructuredOutput\PolymorphicType\ListOfPolymorphicTypesDto;
@@ -35,7 +36,6 @@
3536
use Symfony\AI\Platform\Result\TextResult;
3637
use Symfony\AI\Platform\StructuredOutput\PlatformSubscriber;
3738
use Symfony\AI\Platform\Test\PlainConverter;
38-
use Symfony\Component\Serializer\SerializerInterface;
3939

4040
final class PlatformSubscriberTest extends TestCase
4141
{
@@ -95,12 +95,16 @@ public function testProcessOutputWithResponseFormat()
9595
$this->assertSame('data', $deferredResult->asObject()->some);
9696
}
9797

98-
public function testProcessOutputWithComplexResponseFormat()
98+
/**
99+
* @param class-string $class
100+
*/
101+
#[DataProvider('complexFormatDataProvider')]
102+
public function testProcessOutputWithComplexResponseFormat(string $class)
99103
{
100104
$processor = new PlatformSubscriber(new ConfigurableResponseFormatFactory(['some' => 'format']));
101105

102106
$model = new Model('gpt-4', [Capability::OUTPUT_STRUCTURED]);
103-
$options = ['response_format' => MathReasoning::class];
107+
$options = ['response_format' => $class];
104108
$invocationEvent = new InvocationEvent($model, new MessageBag(), $options);
105109
$processor->processInput($invocationEvent);
106110

@@ -139,7 +143,7 @@ public function testProcessOutputWithComplexResponseFormat()
139143

140144
$deferredResult = $resultEvent->getDeferredResult();
141145
$this->assertInstanceOf(ObjectResult::class, $result = $deferredResult->getResult());
142-
$this->assertInstanceOf(MathReasoning::class, $structure = $deferredResult->asObject());
146+
$this->assertInstanceOf($class, $structure = $deferredResult->asObject());
143147
$this->assertInstanceOf(Metadata::class, $result->getMetadata());
144148
$this->assertCount(5, $structure->steps);
145149
$this->assertInstanceOf(Step::class, $structure->steps[0]);
@@ -151,6 +155,14 @@ public function testProcessOutputWithComplexResponseFormat()
151155
$this->assertSame(-3.75, $structure->result);
152156
}
153157

158+
public static function complexFormatDataProvider(): iterable
159+
{
160+
return [
161+
[MathReasoning::class],
162+
[MathReasoningWithAttributes::class],
163+
];
164+
}
165+
154166
/**
155167
* @param class-string $expectedTimeStructure
156168
*/
@@ -254,8 +266,7 @@ public function testProcessOutputWithCorrectPolymorphicTypesResponseFormat()
254266
public function testProcessOutputWithoutResponseFormat()
255267
{
256268
$resultFormatFactory = new ConfigurableResponseFormatFactory();
257-
$serializer = self::createMock(SerializerInterface::class);
258-
$processor = new PlatformSubscriber($resultFormatFactory, $serializer);
269+
$processor = new PlatformSubscriber($resultFormatFactory);
259270

260271
$converter = new PlainConverter($result = new TextResult('{"some": "data"}'));
261272
$deferred = new DeferredResult($converter, new InMemoryRawResult());

0 commit comments

Comments
 (0)