diff --git a/src/platform/composer.json b/src/platform/composer.json index fc7da2e3e..ce482e5ed 100644 --- a/src/platform/composer.json +++ b/src/platform/composer.json @@ -49,7 +49,8 @@ "symfony/property-info": "^6.4 || ^7.1", "symfony/serializer": "^6.4 || ^7.1", "symfony/type-info": "^7.2.3", - "symfony/uid": "^6.4 || ^7.1" + "symfony/uid": "^6.4 || ^7.1", + "symfony/event-dispatcher": "^6.4 || ^7.1" }, "require-dev": { "async-aws/bedrock-runtime": "^0.1.0", @@ -61,7 +62,6 @@ "symfony/ai-agent": "@dev", "symfony/console": "^6.4 || ^7.1", "symfony/dotenv": "^6.4 || ^7.1", - "symfony/event-dispatcher": "^6.4 || ^7.1", "symfony/finder": "^6.4 || ^7.1", "symfony/process": "^6.4 || ^7.1", "symfony/var-dumper": "^6.4 || ^7.1" diff --git a/src/platform/src/Event/PlatformInvokationEvent.php b/src/platform/src/Event/PlatformInvokationEvent.php new file mode 100644 index 000000000..423f1d45e --- /dev/null +++ b/src/platform/src/Event/PlatformInvokationEvent.php @@ -0,0 +1,34 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Event; + +use Symfony\AI\Platform\Model; +use Symfony\Contracts\EventDispatcher\Event; + +/** + * Event dispatched before platform invocation to allow modification of input data. + * + * @author Ramy Hakam + */ +final class PlatformInvokationEvent extends Event +{ + /** + * @param array|string|object $input + * @param array $options + */ + public function __construct( + public readonly Model $model, + public array|string|object $input, + public readonly array $options = [], + ) { + } +} diff --git a/src/platform/src/EventListener/StringToMessageBagListener.php b/src/platform/src/EventListener/StringToMessageBagListener.php new file mode 100644 index 000000000..d6b948505 --- /dev/null +++ b/src/platform/src/EventListener/StringToMessageBagListener.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\EventListener; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Event\PlatformInvokationEvent; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; + +/** + * Converts string inputs to MessageBag for models that support INPUT_MESSAGES capability. + * + * @author Ramy Hakam + */ +final class StringToMessageBagListener +{ + public function __invoke(PlatformInvokationEvent $event): void + { + // Only process string inputs + if (!\is_string($event->input)) { + return; + } + + // Only process models that support INPUT_MESSAGES capability + if (!$event->model->supports(Capability::INPUT_MESSAGES)) { + return; + } + + // Convert string to MessageBag with a user message + $event->input = new MessageBag(Message::ofUser($event->input)); + } +} diff --git a/src/platform/src/Platform.php b/src/platform/src/Platform.php index de5a16dc9..0918664fb 100644 --- a/src/platform/src/Platform.php +++ b/src/platform/src/Platform.php @@ -11,6 +11,8 @@ namespace Symfony\AI\Platform; +use Psr\EventDispatcher\EventDispatcherInterface; +use Symfony\AI\Platform\Event\PlatformInvokationEvent; use Symfony\AI\Platform\Exception\RuntimeException; use Symfony\AI\Platform\Result\RawResultInterface; use Symfony\AI\Platform\Result\ResultPromise; @@ -38,6 +40,7 @@ public function __construct( iterable $modelClients, iterable $resultConverters, private ?Contract $contract = null, + private ?EventDispatcherInterface $eventDispatcher = null, ) { $this->contract = $contract ?? Contract::create(); $this->modelClients = $modelClients instanceof \Traversable ? iterator_to_array($modelClients) : $modelClients; @@ -46,6 +49,13 @@ public function __construct( public function invoke(Model $model, array|string|object $input, array $options = []): ResultPromise { + // Dispatch event to allow input modification + if ($this->eventDispatcher) { + $event = new PlatformInvokationEvent($model, $input, $options); + $this->eventDispatcher->dispatch($event); + $input = $event->input; + } + $payload = $this->contract->createRequestPayload($model, $input); $options = array_merge($model->getOptions(), $options); diff --git a/src/platform/tests/Event/PlatformInvokationEventTest.php b/src/platform/tests/Event/PlatformInvokationEventTest.php new file mode 100644 index 000000000..79615e47c --- /dev/null +++ b/src/platform/tests/Event/PlatformInvokationEventTest.php @@ -0,0 +1,72 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Event; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Event\PlatformInvokationEvent; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Model; + +#[CoversClass(PlatformInvokationEvent::class)] +final class PlatformInvokationEventTest extends TestCase +{ + public function testGettersReturnCorrectValues() + { + $model = new class('test-model', [Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT]) extends Model { + }; + + $input = 'Hello, world!'; + $options = ['temperature' => 0.7]; + + $event = new PlatformInvokationEvent($model, $input, $options); + + $this->assertSame($model, $event->model); + $this->assertSame($input, $event->input); + $this->assertSame($options, $event->options); + } + + public function testSetInputChangesInput() + { + $model = new class('test-model', [Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT]) extends Model { + }; + + $originalInput = 'Hello, world!'; + $newInput = new MessageBag(Message::ofUser('Hello, world!')); + + $event = new PlatformInvokationEvent($model, $originalInput); + $event->input = $newInput; + + $this->assertSame($newInput, $event->input); + } + + public function testWorksWithDifferentInputTypes() + { + $model = new class('test-model', [Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT]) extends Model { + }; + + // Test with string + $stringEvent = new PlatformInvokationEvent($model, 'string input'); + $this->assertIsString($stringEvent->input); + + // Test with array + $arrayEvent = new PlatformInvokationEvent($model, ['key' => 'value']); + $this->assertIsArray($arrayEvent->input); + + // Test with object + $objectInput = new MessageBag(); + $objectEvent = new PlatformInvokationEvent($model, $objectInput); + $this->assertSame($objectInput, $objectEvent->input); + } +} diff --git a/src/platform/tests/EventListener/StringToMessageBagListenerTest.php b/src/platform/tests/EventListener/StringToMessageBagListenerTest.php new file mode 100644 index 000000000..a46b53a94 --- /dev/null +++ b/src/platform/tests/EventListener/StringToMessageBagListenerTest.php @@ -0,0 +1,89 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\EventListener; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Event\PlatformInvokationEvent; +use Symfony\AI\Platform\EventListener\StringToMessageBagListener; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Model; + +#[CoversClass(StringToMessageBagListener::class)] +final class StringToMessageBagListenerTest extends TestCase +{ + public function testConvertsStringInputToMessageBagForMessagesCapableModel() + { + $model = new class('test-model', [Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT]) extends Model { + }; + + $event = new PlatformInvokationEvent($model, 'Hello, world!'); + $listener = new StringToMessageBagListener(); + + $listener($event); + + $this->assertInstanceOf(MessageBag::class, $event->input); + $this->assertCount(1, $event->input->getMessages()); + $message = $event->input->getMessages()[0]; + $this->assertInstanceOf(UserMessage::class, $message); + $this->assertCount(1, $message->content); + $content = $message->content[0]; + $this->assertInstanceOf(Text::class, $content); + $this->assertSame('Hello, world!', $content->text); + } + + public function testDoesNotConvertStringInputForNonMessagesCapableModel() + { + $model = new class('test-model', [Capability::INPUT_TEXT, Capability::OUTPUT_TEXT]) extends Model { + }; + + $originalInput = 'Hello, world!'; + $event = new PlatformInvokationEvent($model, $originalInput); + $listener = new StringToMessageBagListener(); + + $listener($event); + + $this->assertSame($originalInput, $event->input); + } + + public function testDoesNotConvertNonStringInput() + { + $model = new class('test-model', [Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT]) extends Model { + }; + + $originalInput = new MessageBag(Message::ofUser('Hello')); + $event = new PlatformInvokationEvent($model, $originalInput); + $listener = new StringToMessageBagListener(); + + $listener($event); + + $this->assertSame($originalInput, $event->input); + } + + public function testDoesNotConvertArrayInput() + { + $model = new class('test-model', [Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT]) extends Model { + }; + + $originalInput = ['key' => 'value']; + $event = new PlatformInvokationEvent($model, $originalInput); + $listener = new StringToMessageBagListener(); + + $listener($event); + + $this->assertSame($originalInput, $event->input); + } +}