diff --git a/src/mcp-bundle/composer.json b/src/mcp-bundle/composer.json index 59bd7557..156ace2b 100644 --- a/src/mcp-bundle/composer.json +++ b/src/mcp-bundle/composer.json @@ -21,7 +21,8 @@ }, "require-dev": { "phpstan/phpstan": "^2.1", - "phpunit/phpunit": "^11.5" + "phpunit/phpunit": "^11.5", + "symfony/security-bundle": "^7.3" }, "config": { "sort-packages": true diff --git a/src/mcp-bundle/config/options.php b/src/mcp-bundle/config/options.php index 2f06d288..4b86a3b8 100644 --- a/src/mcp-bundle/config/options.php +++ b/src/mcp-bundle/config/options.php @@ -59,6 +59,7 @@ ->children() ->booleanNode('stdio')->defaultFalse()->end() ->booleanNode('sse')->defaultFalse()->end() + ->booleanNode('http_stream')->defaultTrue()->end() // @todo change to default false ->end() ->end() ->end() diff --git a/src/mcp-bundle/config/routes.php b/src/mcp-bundle/config/routes.php index fe499bf9..28426c01 100644 --- a/src/mcp-bundle/config/routes.php +++ b/src/mcp-bundle/config/routes.php @@ -9,16 +9,29 @@ * file that was distributed with this source code. */ -use Symfony\AI\McpBundle\Controller\McpController; +use Symfony\AI\McpBundle\Controller\McpSseController; +use Symfony\AI\McpBundle\Controller\McpHttpStreamController; use Symfony\Component\Routing\Loader\Configurator\RoutingConfigurator; return function (RoutingConfigurator $routes): void { $routes->add('_mcp_sse', '/sse') - ->controller([McpController::class, 'sse']) + ->controller([McpSseController::class, 'sse']) ->methods(['GET']) ; $routes->add('_mcp_messages', '/messages/{id}') - ->controller([McpController::class, 'messages']) + ->controller([McpSseController::class, 'messages']) ->methods(['POST']) ; + $routes->add('_mcp_http', '/http/') + ->controller([McpHttpStreamController::class, 'endpoint']) + ->methods(['POST']) + ; + $routes->add('_mcp_http_initiate_sse', '/http/') + ->controller([McpHttpStreamController::class, 'initiateSseFromStream']) + ->methods(['GET']) + ; + $routes->add('_mcp_http_delete_session', '/http/') + ->controller([McpHttpStreamController::class, 'deleteSession']) + ->methods(['DELETE']) + ; }; diff --git a/src/mcp-bundle/config/services.php b/src/mcp-bundle/config/services.php index 56d2c11e..242b15c9 100644 --- a/src/mcp-bundle/config/services.php +++ b/src/mcp-bundle/config/services.php @@ -11,6 +11,9 @@ namespace Symfony\Component\DependencyInjection\Loader\Configurator; +use Symfony\AI\McpBundle\Session\SessionIdentifierResolver; +use Symfony\AI\McpBundle\Session\SessionResolver; +use Symfony\AI\McpBundle\Session\SessionSubscriber; use Symfony\AI\McpSdk\Capability\ToolChain; use Symfony\AI\McpSdk\Message\Factory; use Symfony\AI\McpSdk\Server; @@ -21,6 +24,8 @@ use Symfony\AI\McpSdk\Server\RequestHandler\ToolCallHandler; use Symfony\AI\McpSdk\Server\RequestHandler\ToolListHandler; use Symfony\AI\McpSdk\Server\Transport\Sse\Store\CachePoolStore; +use Symfony\AI\McpSdk\Server\Transport\StreamableHttp\Session\SessionIdentifierFactory; +use Symfony\AI\McpSdk\Server\Transport\StreamableHttp\Session\SessionPoolStorage; return static function (ContainerConfigurator $container): void { $container->services() @@ -50,6 +55,7 @@ ->set('mcp.message_factory', Factory::class) ->args([]) + ->alias(Factory::class, 'mcp.message_factory') ->set('mcp.server.json_rpc', JsonRpcHandler::class) ->args([ service('mcp.message_factory'), @@ -57,6 +63,7 @@ tagged_iterator('mcp.server.notification_handler'), service('logger')->ignoreOnInvalid(), ]) + ->alias(JsonRpcHandler::class, 'mcp.server.json_rpc') ->set('mcp.server', Server::class) ->args([ service('mcp.server.json_rpc'), @@ -67,6 +74,34 @@ ->args([ service('cache.app'), ]) + ->set('mcp.server.http_stream.session.identifier_factory', SessionIdentifierFactory::class) + ->args([ + service('security')->nullOnInvalid(), + ]) + ->alias(SessionIdentifierFactory::class, 'mcp.server.http_stream.session.identifier_factory') + ->set('mcp.server.http_stream.session.identifier_resolver', SessionIdentifierResolver::class) + ->tag('controller.argument_value_resolver') + ->set('mcp.server.http_stream.session.resolver', SessionResolver::class) + ->tag('controller.argument_value_resolver') + ->set('mcp.server.http_stream.session.pool', SessionPoolStorage::class) + ->args([ + service('cache.app'), + param('mcp.http_stream.session.ttl'), + ]) + ->alias(Server\Transport\StreamableHttp\SessionStorageInterface::class, 'mcp.server.http_stream.session.pool') + ->set('mcp.server.http_stream.session.subscriber', SessionSubscriber::class) + ->args([ + service('mcp.server.http_stream.session.identifier_factory'), + service('mcp.server.http_stream.session.factory'), + ]) + ->tag('kernel.event_subscriber') + ->alias(SessionSubscriber::class, 'mcp.server.http_stream.session.subscriber') + ->set('mcp.server.http_stream.session.factory', Server\Transport\StreamableHttp\Session\SessionFactory::class) + ->args([ + service('mcp.server.http_stream.session.identifier_factory'), + service('mcp.server.http_stream.session.pool'), + ]) + ->alias(Server\Transport\StreamableHttp\Session\SessionFactory::class, 'mcp.server.http_stream.session.factory') ->set('mcp.tool_chain', ToolChain::class) ->args([ tagged_iterator('mcp.tool'), diff --git a/src/mcp-bundle/src/Controller/McpHttpStreamController.php b/src/mcp-bundle/src/Controller/McpHttpStreamController.php new file mode 100644 index 00000000..1c8e84a3 --- /dev/null +++ b/src/mcp-bundle/src/Controller/McpHttpStreamController.php @@ -0,0 +1,142 @@ +messageFactory->create($request->getContent()); + if ($session === null) { + // Must be an "initialize" request. If not ==> 404. + if ($message->method !== 'initialize') { // @todo do better + return new Response(null, Response::HTTP_NOT_FOUND); + } + $session = $this->sessionFactory->get(); + $session->save(); + } + + // Handle the input + // If response is streamable ==> open an SSE Stream and store all responses in session for later replay + // If response is not ==> JSON + + $response = $this->handler->handleMessage($message); + + if ($message instanceof Notification) { + return new Response(null, Response::HTTP_ACCEPTED); + } + if ($response instanceof StreamableResponse) { + //$transport = new Server\Transport\StreamableHttp\StreamTransport($session->addNewStream(), $session, $response->responses); + return new StreamedResponse(function () use ($session, $response) { + $streamId = $session->addNewStream(); + foreach (($response->responses)() as $response) { + $eventId = Uuid::v4()->toString(); + if (is_array($response)) { + $rawResponse = json_encode($response, \JSON_THROW_ON_ERROR); + } else { + $rawResponse = $this->handler->encodeResponse($response); + } + $session->addEventOnStream($streamId, $eventId, $rawResponse); + echo "id: $eventId\n"; + echo "type: notification\n"; + echo "data: " . $rawResponse . "\n\n"; + if (false !== ob_get_length()) { + ob_flush(); + } + flush(); + } + }, headers: [ + 'Content-Type' => 'text/event-stream', + 'Cache-Control' => 'no-cache', + 'X-Accel-Buffering' => 'no', + 'Mcp-Session-Id' => $session->sessionIdentifier->sessionId->toString(), + ]); + } + return new JsonResponse($this->handler->encodeResponse($response), Response::HTTP_OK, [ + 'Content-Type' => 'application/json', + 'Cache-Control' => 'no-cache', + 'Mcp-Session-Id' => $session->sessionIdentifier->sessionId->toString(), + ], true); + } + + /** + * Clients that no longer need a particular session (e.g., because the user is leaving the client application) SHOULD send an HTTP DELETE to the MCP endpoint with the Mcp-Session-Id header, to explicitly terminate the session. + * @see{https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#session-management} + * + * @param Session $session + * @return Response + */ + public function deleteSession(Session $session): Response + { + $session->delete(); + return new Response(null, Response::HTTP_NO_CONTENT); + } + + /** + * @param Request $request + * @param Session $session + * @return Response + */ + public function initiateSseFromStream(Request $request, Session $session): Response + { + if ($request->headers->has('Last-Event-ID')) { + try { + $session->getStreamIdForEvent($request->headers->get('Last-Event-ID')); + } catch (\InvalidArgumentException $e) { + throw new BadRequestHttpException($e->getMessage()); + } + $lastEventId = $request->headers->get('Last-Event-ID'); + return new StreamedResponse(function () use ($session, $lastEventId) { + $i = 0; + do { + $events = $session->getEventsAfterId($lastEventId); + $lastEvent = null; + foreach ($events as $event) { + $lastEventId = $event['id']; + $lastEvent = $event['event']; + echo 'id: ' . $lastEventId . \PHP_EOL; + echo 'data: ' . $lastEvent . \PHP_EOL . \PHP_EOL; + if (false !== ob_get_length()) { + ob_flush(); + } + flush(); + } + if ($events === []) { + usleep(1000); + } + // @todo we should detect here that the "real" response has been sent and close the stream + } while (! ($lastEvent instanceof \Symfony\AI\McpSdk\Message\Response) && $i++ < 50); + }, headers: [ + 'Content-Type' => 'text/event-stream', + 'Cache-Control' => 'no-cache', + 'X-Accel-Buffering' => 'no', + ]); + + + } else { + // At this point server cannot attach to this stream to send request / notifications, so act like we don't support + return new Response(null, Response::HTTP_METHOD_NOT_ALLOWED); + } + } +} diff --git a/src/mcp-bundle/src/Controller/McpController.php b/src/mcp-bundle/src/Controller/McpSseController.php similarity index 97% rename from src/mcp-bundle/src/Controller/McpController.php rename to src/mcp-bundle/src/Controller/McpSseController.php index 3b036aa4..d6efa6b5 100644 --- a/src/mcp-bundle/src/Controller/McpController.php +++ b/src/mcp-bundle/src/Controller/McpSseController.php @@ -20,7 +20,7 @@ use Symfony\Component\Routing\Generator\UrlGeneratorInterface; use Symfony\Component\Uid\Uuid; -final readonly class McpController +final readonly class McpSseController { public function __construct( private Server $server, diff --git a/src/mcp-bundle/src/McpBundle.php b/src/mcp-bundle/src/McpBundle.php index acd2374d..0da54c73 100644 --- a/src/mcp-bundle/src/McpBundle.php +++ b/src/mcp-bundle/src/McpBundle.php @@ -12,7 +12,8 @@ namespace Symfony\AI\McpBundle; use Symfony\AI\McpBundle\Command\McpCommand; -use Symfony\AI\McpBundle\Controller\McpController; +use Symfony\AI\McpBundle\Controller\McpHttpStreamController; +use Symfony\AI\McpBundle\Controller\McpSseController; use Symfony\AI\McpBundle\Routing\RouteLoader; use Symfony\AI\McpSdk\Capability\Tool\IdentifierInterface; use Symfony\AI\McpSdk\Server\NotificationHandlerInterface; @@ -40,6 +41,7 @@ public function loadExtension(array $config, ContainerConfigurator $container, C $builder->setParameter('mcp.app', $config['app']); $builder->setParameter('mcp.version', $config['version']); $builder->setParameter('mcp.page_size', $config['page_size']); + $builder->setParameter('mcp.http_stream.session.ttl', 3600); if (isset($config['client_transports'])) { $this->configureClient($config['client_transports'], $builder); @@ -52,11 +54,11 @@ public function loadExtension(array $config, ContainerConfigurator $container, C } /** - * @param array{stdio: bool, sse: bool} $transports + * @param array{stdio: bool, sse: bool, http_stream: bool} $transports */ private function configureClient(array $transports, ContainerBuilder $container): void { - if (!$transports['stdio'] && !$transports['sse']) { + if (!$transports['stdio'] && !$transports['sse'] && !$transports['http_stream']) { return; } @@ -74,7 +76,7 @@ private function configureClient(array $transports, ContainerBuilder $container) } if ($transports['sse']) { - $container->register('mcp.server.controller', McpController::class) + $container->register('mcp.server.sse.controller', McpSseController::class) ->setArguments([ new Reference('mcp.server'), new Reference('mcp.server.sse.store.cache_pool'), @@ -84,6 +86,19 @@ private function configureClient(array $transports, ContainerBuilder $container) ->addTag('controller.service_arguments'); } + if ($transports['http_stream']) { + $container->register('mcp.server.http_stream.controller', McpHttpStreamController::class) + ->setArguments([ + new Reference('mcp.server.json_rpc'), + new Reference('mcp.message_factory'), + new Reference('mcp.server.http_stream.session.factory'), + ]) + ->setPublic(true) + ->addTag('controller.service_arguments') + ; + $container->setAlias(McpHttpStreamController::class, 'mcp.server.http_stream.controller'); + } + $container->register('mcp.server.route_loader', RouteLoader::class) ->setArgument(0, $transports['sse']) ->addTag('routing.route_loader'); diff --git a/src/mcp-bundle/src/Session/SessionIdentifierResolver.php b/src/mcp-bundle/src/Session/SessionIdentifierResolver.php new file mode 100644 index 00000000..3d2e1da3 --- /dev/null +++ b/src/mcp-bundle/src/Session/SessionIdentifierResolver.php @@ -0,0 +1,33 @@ +getType() !== SessionIdentifier::class) { + return []; + } + + if (!$request->attributes->has('_mcp_session_id')) { + return match($argument->isNullable()) { + true => [null], + false => [] + }; + } + + $sessionIdentifier = $request->attributes->get('_mcp_session_id'); + if (!$sessionIdentifier instanceof SessionIdentifier) { + throw new InvalidSessionIdException(sprintf('Session "%s" not found.', $sessionIdentifier)); + } + + return [$sessionIdentifier]; + } +} diff --git a/src/mcp-bundle/src/Session/SessionResolver.php b/src/mcp-bundle/src/Session/SessionResolver.php new file mode 100644 index 00000000..841a8908 --- /dev/null +++ b/src/mcp-bundle/src/Session/SessionResolver.php @@ -0,0 +1,33 @@ +getType() !== Session::class) { + return []; + } + + if (!$request->attributes->has('_mcp_session')) { + return match($argument->isNullable()) { + true => [null], + false => throw new InvalidSessionIdException('Session not found.') + }; + } + + $session = $request->attributes->get('_mcp_session'); + if (!$session instanceof Session) { + throw new InvalidSessionIdException('Session not found.'); + } + + return [$session]; + } +} diff --git a/src/mcp-bundle/src/Session/SessionSubscriber.php b/src/mcp-bundle/src/Session/SessionSubscriber.php new file mode 100644 index 00000000..81b7635b --- /dev/null +++ b/src/mcp-bundle/src/Session/SessionSubscriber.php @@ -0,0 +1,49 @@ + 'onKernelRequest', + ]; + } + + public function onKernelRequest(RequestEvent $event): void + { + if (!$event->getRequest()->headers->has('Mcp-Session-Id')) { + return; + } + + try { + $uuid = UuidV4::fromString($event->getRequest()->headers->get('Mcp-Session-Id')); + } catch (InvalidArgumentException) { + throw new BadRequestException(sprintf('Mcp-Session-Id "%s" is not a valid uuid.', $event->getRequest()->headers->get('Mcp-Session-Id'))); + } + + $sessionIdentifier = $this->identifierFactory->get($uuid); + $session = $this->sessionFactory->get($sessionIdentifier); + if (!$session->exists()) { + throw new NotFoundHttpException(sprintf('Session "%s" not found.', $sessionIdentifier)); + } + + $event->getRequest()->attributes->set('_mcp_session_id', $sessionIdentifier); + $event->getRequest()->attributes->set('_mcp_session', $session); + } +} diff --git a/src/mcp-sdk/src/Capability/Tool/ToolExecutorInterface.php b/src/mcp-sdk/src/Capability/Tool/ToolExecutorInterface.php index 5e030662..8e29dd90 100644 --- a/src/mcp-sdk/src/Capability/Tool/ToolExecutorInterface.php +++ b/src/mcp-sdk/src/Capability/Tool/ToolExecutorInterface.php @@ -13,12 +13,14 @@ use Symfony\AI\McpSdk\Exception\ToolExecutionException; use Symfony\AI\McpSdk\Exception\ToolNotFoundException; +use Symfony\AI\McpSdk\Message\Notification; interface ToolExecutorInterface { /** + * @return ToolCallResult|\Traversable * @throws ToolExecutionException if the tool execution fails * @throws ToolNotFoundException if the tool is not found */ - public function call(ToolCall $input): ToolCallResult; + public function call(ToolCall $input): ToolCallResult|\Traversable; } diff --git a/src/mcp-sdk/src/Capability/ToolChain.php b/src/mcp-sdk/src/Capability/ToolChain.php index 4b7235e2..b4c2e8b4 100644 --- a/src/mcp-sdk/src/Capability/ToolChain.php +++ b/src/mcp-sdk/src/Capability/ToolChain.php @@ -58,7 +58,7 @@ public function getMetadata(?int $count, ?string $lastIdentifier = null): iterab } } - public function call(ToolCall $input): ToolCallResult + public function call(ToolCall $input): ToolCallResult|\Traversable { foreach ($this->items as $item) { if ($item instanceof ToolExecutorInterface && $input->name === $item->getName()) { diff --git a/src/mcp-sdk/src/Exception/InvalidSessionIdException.php b/src/mcp-sdk/src/Exception/InvalidSessionIdException.php new file mode 100644 index 00000000..57852311 --- /dev/null +++ b/src/mcp-sdk/src/Exception/InvalidSessionIdException.php @@ -0,0 +1,8 @@ + + * @param string $input + * @return Notification|Request|InvalidInputMessageException * * @throws \JsonException When the input string is not valid JSON */ - public function create(string $input): iterable + public function create(string $input): Notification|Request|InvalidInputMessageException { - $data = json_decode($input, true, flags: \JSON_THROW_ON_ERROR); - - if ('{' === $input[0]) { - $data = [$data]; - } - - foreach ($data as $message) { - if (!isset($message['method'])) { - yield new InvalidInputMessageException('Invalid JSON-RPC request, missing "method".'); - } elseif (str_starts_with((string) $message['method'], 'notifications/')) { - yield Notification::from($message); - } else { - yield Request::from($message); - } + $message = json_decode($input, true, flags: \JSON_THROW_ON_ERROR); + if (!isset($message['method'])) { + return new InvalidInputMessageException('Invalid JSON-RPC request, missing "method".'); + } elseif (str_starts_with((string) $message['method'], 'notifications/')) { + return Notification::from($message); + } else { + return Request::from($message); } } } diff --git a/src/mcp-sdk/src/Message/NotificationHandled.php b/src/mcp-sdk/src/Message/NotificationHandled.php new file mode 100644 index 00000000..c5f99434 --- /dev/null +++ b/src/mcp-sdk/src/Message/NotificationHandled.php @@ -0,0 +1,8 @@ + '2.0', 'id' => $this->id, 'result' => $this->result, ]; + if (null !== $this->method) { + $result['method'] = $this->method; + } + return $result; } } diff --git a/src/mcp-sdk/src/Message/StreamableResponse.php b/src/mcp-sdk/src/Message/StreamableResponse.php new file mode 100644 index 00000000..da62fc2a --- /dev/null +++ b/src/mcp-sdk/src/Message/StreamableResponse.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\McpSdk\Message; + +final class StreamableResponse +{ + /** + * @param callable $responses + */ + public function __construct( + public $responses, + ) { + } +} diff --git a/src/mcp-sdk/src/Server.php b/src/mcp-sdk/src/Server.php index 1b5629b4..0822e749 100644 --- a/src/mcp-sdk/src/Server.php +++ b/src/mcp-sdk/src/Server.php @@ -13,6 +13,7 @@ use Psr\Log\LoggerInterface; use Psr\Log\NullLogger; +use Symfony\AI\McpSdk\Message\StreamableResponse; use Symfony\AI\McpSdk\Server\JsonRpcHandler; use Symfony\AI\McpSdk\Server\TransportInterface; @@ -36,11 +37,15 @@ public function connect(TransportInterface $transport): void } try { - foreach ($this->jsonRpcHandler->process($message) as $response) { - if (null === $response) { - continue; + $response = $this->jsonRpcHandler->process($message); + if (null === $response) { + continue; + } + if ($response instanceof StreamableResponse) { + foreach ($response->responses as $response) { + $transport->send($this->jsonRpcHandler->encodeResponse($response)); } - + } else { $transport->send($response); } } catch (\JsonException $e) { diff --git a/src/mcp-sdk/src/Server/JsonRpcHandler.php b/src/mcp-sdk/src/Server/JsonRpcHandler.php index ed925fdf..3d72e84a 100644 --- a/src/mcp-sdk/src/Server/JsonRpcHandler.php +++ b/src/mcp-sdk/src/Server/JsonRpcHandler.php @@ -20,8 +20,11 @@ use Symfony\AI\McpSdk\Message\Error; use Symfony\AI\McpSdk\Message\Factory; use Symfony\AI\McpSdk\Message\Notification; +use Symfony\AI\McpSdk\Message\NotificationHandled; use Symfony\AI\McpSdk\Message\Request; use Symfony\AI\McpSdk\Message\Response; +use Symfony\AI\McpSdk\Message\StreamableResponse; +use Symfony\AI\McpSdk\Server\RequestHandler\InitializeHandler; /** * @final @@ -53,60 +56,74 @@ public function __construct( } /** - * @return iterable + * @return null|Error|Response|StreamableResponse|iterable * * @throws ExceptionInterface When a handler throws an exception during message processing * @throws \JsonException When JSON encoding of the response fails */ - public function process(string $input): iterable + public function process(string $input): null|Error|Response|StreamableResponse|iterable { $this->logger->info('Received message to process', ['message' => $input]); try { - $messages = $this->messageFactory->create($input); + $message = $this->messageFactory->create($input); } catch (\JsonException $e) { $this->logger->warning('Failed to decode json message', ['exception' => $e]); - yield $this->encodeResponse(Error::parseError($e->getMessage())); - - return; + return Error::parseError($e->getMessage()); } - foreach ($messages as $message) { - if ($message instanceof InvalidInputMessageException) { - $this->logger->warning('Failed to create message', ['exception' => $message]); - yield $this->encodeResponse(Error::invalidRequest(0, $message->getMessage())); - continue; + $response = $this->handleMessage($message); + if (null === $response) { + return null; + } + if ($response instanceof StreamableResponse) { + foreach($response->responses as $response) { + yield $response; } + } else { + return $response; + } + } + + /** + * @param Notification|Request|InvalidInputMessageException $message + * @return Error|Response|StreamableResponse|NotificationHandled|null + */ + public function handleMessage(Notification|Request|InvalidInputMessageException $message): Error|Response|StreamableResponse|NotificationHandled|null + { + if ($message instanceof InvalidInputMessageException) { + $this->logger->warning('Failed to create message', ['exception' => $message]); + return Error::invalidRequest(0, $message->getMessage()); + } - $this->logger->info('Decoded incoming message', ['message' => $message]); + $this->logger->info('Decoded incoming message', ['message' => $message]); - try { - yield $message instanceof Notification - ? $this->handleNotification($message) - : $this->encodeResponse($this->handleRequest($message)); - } catch (\DomainException) { - yield null; - } catch (NotFoundExceptionInterface $e) { - $this->logger->warning(\sprintf('Failed to create response: %s', $e->getMessage()), ['exception' => $e]); + try { + return $message instanceof Notification + ? $this->handleNotification($message) + : $this->handleRequest($message); + } catch (\DomainException) { + return null; + } catch (NotFoundExceptionInterface $e) { + $this->logger->warning(\sprintf('Failed to create response: %s', $e->getMessage()), ['exception' => $e]); - yield $this->encodeResponse(Error::methodNotFound($message->id, $e->getMessage())); - } catch (\InvalidArgumentException $e) { - $this->logger->warning(\sprintf('Invalid argument: %s', $e->getMessage()), ['exception' => $e]); + return Error::methodNotFound($message->id, $e->getMessage()); + } catch (\InvalidArgumentException $e) { + $this->logger->warning(\sprintf('Invalid argument: %s', $e->getMessage()), ['exception' => $e]); - yield $this->encodeResponse(Error::invalidParams($message->id, $e->getMessage())); - } catch (\Throwable $e) { - $this->logger->critical(\sprintf('Uncaught exception: %s', $e->getMessage()), ['exception' => $e]); + return Error::invalidParams($message->id, $e->getMessage()); + } catch (\Throwable $e) { + $this->logger->critical(\sprintf('Uncaught exception: %s', $e->getMessage()), ['exception' => $e]); - yield $this->encodeResponse(Error::internalError($message->id, $e->getMessage())); - } + return Error::internalError($message->id, $e->getMessage()); } } /** * @throws \JsonException When JSON encoding fails */ - private function encodeResponse(Response|Error|null $response): ?string + public function encodeResponse(Response|Error|null $response): ?string { if (null === $response) { $this->logger->warning('Response is null'); @@ -126,7 +143,7 @@ private function encodeResponse(Response|Error|null $response): ?string /** * @throws ExceptionInterface When a notification handler throws an exception */ - private function handleNotification(Notification $notification): null + private function handleNotification(Notification $notification): NotificationHandled { $handled = false; foreach ($this->notificationHandlers as $handler) { @@ -140,14 +157,14 @@ private function handleNotification(Notification $notification): null $this->logger->warning(\sprintf('No handler found for "%s".', $notification->method), ['notification' => $notification]); } - return null; + return new NotificationHandled(); } /** * @throws NotFoundExceptionInterface When no handler is found for the request method * @throws ExceptionInterface When a request handler throws an exception */ - private function handleRequest(Request $request): Response|Error + private function handleRequest(Request $request): StreamableResponse|Response|Error { foreach ($this->requestHandlers as $handler) { if ($handler->supports($request)) { diff --git a/src/mcp-sdk/src/Server/NotificationHandler/InitializedHandler.php b/src/mcp-sdk/src/Server/NotificationHandler/InitializedHandler.php index cff8a3f0..13c210d6 100644 --- a/src/mcp-sdk/src/Server/NotificationHandler/InitializedHandler.php +++ b/src/mcp-sdk/src/Server/NotificationHandler/InitializedHandler.php @@ -12,9 +12,12 @@ namespace Symfony\AI\McpSdk\Server\NotificationHandler; use Symfony\AI\McpSdk\Message\Notification; +use Symfony\AI\McpSdk\Server\Transport\StreamableHttp\Session\Session; final class InitializedHandler extends BaseNotificationHandler { + public function __construct(private readonly ?Session $session = null) { } + protected function supportedNotification(): string { return 'initialized'; @@ -22,5 +25,6 @@ protected function supportedNotification(): string public function handle(Notification $notification): void { + $this->session?->setClientNotificationInitializedReceived(); } } diff --git a/src/mcp-sdk/src/Server/RequestHandler/InitializeHandler.php b/src/mcp-sdk/src/Server/RequestHandler/InitializeHandler.php index f04cdb91..adc4c20f 100644 --- a/src/mcp-sdk/src/Server/RequestHandler/InitializeHandler.php +++ b/src/mcp-sdk/src/Server/RequestHandler/InitializeHandler.php @@ -25,7 +25,7 @@ public function __construct( public function createResponse(Request $message): Response { return new Response($message->id, [ - 'protocolVersion' => '2025-03-26', + 'protocolVersion' => '2025-06-18', 'capabilities' => [ 'prompts' => ['listChanged' => false], 'tools' => ['listChanged' => false], diff --git a/src/mcp-sdk/src/Server/RequestHandler/ToolCallHandler.php b/src/mcp-sdk/src/Server/RequestHandler/ToolCallHandler.php index 49c16562..53c1a715 100644 --- a/src/mcp-sdk/src/Server/RequestHandler/ToolCallHandler.php +++ b/src/mcp-sdk/src/Server/RequestHandler/ToolCallHandler.php @@ -12,12 +12,15 @@ namespace Symfony\AI\McpSdk\Server\RequestHandler; use Symfony\AI\McpSdk\Capability\Tool\ToolCall; +use Symfony\AI\McpSdk\Capability\Tool\ToolCallResult; use Symfony\AI\McpSdk\Capability\Tool\ToolExecutorInterface; use Symfony\AI\McpSdk\Exception\ExceptionInterface; use Symfony\AI\McpSdk\Exception\InvalidArgumentException; use Symfony\AI\McpSdk\Message\Error; +use Symfony\AI\McpSdk\Message\Notification; use Symfony\AI\McpSdk\Message\Request; use Symfony\AI\McpSdk\Message\Response; +use Symfony\AI\McpSdk\Message\StreamableResponse; final class ToolCallHandler extends BaseRequestHandler { @@ -26,17 +29,51 @@ public function __construct( ) { } - public function createResponse(Request $message): Response|Error + public function createResponse(Request $message): StreamableResponse|Response|Error { $name = $message->params['name']; $arguments = $message->params['arguments'] ?? []; try { $result = $this->toolExecutor->call(new ToolCall(uniqid('', true), $name, $arguments)); - } catch (ExceptionInterface) { + } catch (ExceptionInterface $e) { return Error::internalError($message->id, 'Error while executing tool'); } + if ($result instanceof \Traversable) { + return new StreamableResponse( + function () use ($message, $result): \Generator { + foreach ($result as $resultDetail) { + if ($resultDetail instanceof Notification) { + yield [ + 'jsonrpc' => '2.0', + 'method' => 'notifications/progress', + 'params' => [ + 'progress' => 10, + 'progressToken' => $message->params['_meta']['progressToken'], + ] + /*'method' => 'notifications/message', + 'params' => [ + "level" => 'info', + 'data' => 'In progress' + ]*/ + ]; + } elseif ($resultDetail instanceof ToolCallResult) { + yield $this->getResponse($message, $resultDetail); + break; + } else { + throw new InvalidArgumentException('Unsupported tool result type: '.\get_class($resultDetail)); + } + } + } + ); + } + + return $this->getResponse($message, $result); + } + + protected function getResponse(Request $message, ToolCallResult $result): Response + { $content = match ($result->type) { 'text' => [ 'type' => 'text', diff --git a/src/mcp-sdk/src/Server/RequestHandlerInterface.php b/src/mcp-sdk/src/Server/RequestHandlerInterface.php index 7dd08b0c..481bb536 100644 --- a/src/mcp-sdk/src/Server/RequestHandlerInterface.php +++ b/src/mcp-sdk/src/Server/RequestHandlerInterface.php @@ -15,6 +15,7 @@ use Symfony\AI\McpSdk\Message\Error; use Symfony\AI\McpSdk\Message\Request; use Symfony\AI\McpSdk\Message\Response; +use Symfony\AI\McpSdk\Message\StreamableResponse; interface RequestHandlerInterface { @@ -23,5 +24,5 @@ public function supports(Request $message): bool; /** * @throws ExceptionInterface When the handler encounters an error processing the request */ - public function createResponse(Request $message): Response|Error; + public function createResponse(Request $message): StreamableResponse|Response|Error; } diff --git a/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/Session.php b/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/Session.php new file mode 100644 index 00000000..492415e5 --- /dev/null +++ b/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/Session.php @@ -0,0 +1,144 @@ +}> + */ + private array $streams; + + private bool $clientNotificationInitializedReceived = false; + + /** + * @var array{string, int} + */ + private array $eventsIdToStreamId; + public function __construct(public readonly SessionIdentifier $sessionIdentifier, private readonly SessionStorageInterface $sessionStorage, array $data = []) { + $this->streams = $data['streams'] ?? []; + $this->eventsIdToStreamId = $data['eventsIdToStreamId'] ?? []; + } + + public function exists(): bool + { + return $this->sessionStorage->exists($this->sessionIdentifier); + } + + public function save(): void + { + $this->sessionStorage->save($this->sessionIdentifier, $this); + } + + public function getData(): array + { + return [ + 'streams' => $this->streams, + 'eventsIdToStreamId' => $this->eventsIdToStreamId, + ]; + } + + public function getStreamUuid(int $streamId): Uuid + { + $this->refreshData(); + if (!isset($this->streams[$streamId])) { + throw new \InvalidArgumentException(sprintf('Stream with id "%s" does not exist', $streamId)); + } + return $this->streams[$streamId]['id']; + } + + public function getEventsOnStream(int $streamId): array + { + $this->refreshData(); + if (!isset($this->streams[$streamId])) { + throw new \InvalidArgumentException(sprintf('Stream with id "%s" does not exist', $streamId)); + } + return $this->streams[$streamId]['messages'] ?? []; + } + + public function addEventOnStream(int $streamId, string $eventId, string $event): void + { + $this->refreshData(); + if (!isset($this->streams[$streamId])) { + throw new \InvalidArgumentException(sprintf('Stream with id "%s" does not exist', $streamId)); + } + $this->streams[$streamId]['events'][] = [ + 'id' => $eventId, + 'event' => $event + ]; + if (count($this->streams[$streamId]['events']) > self::MAX_EVENTS_PER_STREAM) { + array_shift($this->streams[$streamId]['events']); + } + $this->eventsIdToStreamId[$eventId] = $streamId; + $this->save(); + } + + public function getStreamIdForEvent(string $eventId): int + { + $this->refreshData(); + if (!isset($this->eventsIdToStreamId[$eventId])) { + throw new \InvalidArgumentException(sprintf('Event with id "%s" does not exist', $eventId)); + } + return $this->eventsIdToStreamId[$eventId]; + } + + public function getEventsAfterId(string $eventId): array + { + $streamId = $this->getStreamIdForEvent($eventId); + $events = $this->streams[$streamId]['events']; + $eventOffset = array_search($eventId, array_column($events, 'id')); + if ($eventOffset === false) { + return []; + } + return array_slice($events, (int) $eventOffset + 1); + } + + public function addNewStream(bool $clientInitiated = false): int + { + $this->refreshData(); + $streamId = count($this->streams); + $this->streams[$streamId] = [ + 'id' => Uuid::v4(), + 'clientInitiated' => $clientInitiated, + 'events' => [], + ]; + $this->save(); + + return $streamId; + } + + public function isClientNotificationInitializedReceived(): bool + { + $this->refreshData(); + return $this->clientNotificationInitializedReceived; + } + + public function setClientNotificationInitializedReceived(): void + { + $this->refreshData(); + $this->clientNotificationInitializedReceived = true; + $this->save(); + } + + private function refreshData(): void + { + if (!$this->exists()) { + throw new SessionNotFoundException(); + } + $session = $this->sessionStorage->get($this->sessionIdentifier); + $this->streams = $session->streams; + $this->eventsIdToStreamId = $session->eventsIdToStreamId; + } + + public function delete(): void + { + $this->sessionStorage->remove($this->sessionIdentifier); + } +} diff --git a/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionFactory.php b/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionFactory.php new file mode 100644 index 00000000..12572840 --- /dev/null +++ b/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionFactory.php @@ -0,0 +1,18 @@ +identifierFactory->get(), $this->storage); + } +} diff --git a/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionIdentifierFactory.php b/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionIdentifierFactory.php new file mode 100644 index 00000000..48602099 --- /dev/null +++ b/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionIdentifierFactory.php @@ -0,0 +1,17 @@ +security?->getUser()?->getUserIdentifier()); + } +} diff --git a/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionPoolStorage.php b/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionPoolStorage.php new file mode 100644 index 00000000..260eb017 --- /dev/null +++ b/src/mcp-sdk/src/Server/Transport/StreamableHttp/Session/SessionPoolStorage.php @@ -0,0 +1,63 @@ +cachePool->hasItem($this->getCacheKey($sessionIdentifier)); + } catch(InvalidArgumentException) { + throw new InvalidSessionIdException(sprintf('Session identifier (id: "%s", user: "%s" is invalid)', $sessionIdentifier->sessionId, $sessionIdentifier->userIdentifier ?? '')); + } + } + + public function save(SessionIdentifier $sessionIdentifier, Session $session): void + { + try { + $item = $this->cachePool->getItem($this->getCacheKey($sessionIdentifier)); + $item->set($session->getData()); + $item->expiresAfter($this->ttlInSeconds); + } catch(InvalidArgumentException) { + throw new InvalidSessionIdException(sprintf('Session identifier (id: "%s", user: "%s" is invalid)', $sessionIdentifier->sessionId, $sessionIdentifier->userIdentifier ?? '')); + } + $this->cachePool->save($item); + } + + public function remove(SessionIdentifier $sessionIdentifier): void + { + try { + $this->cachePool->deleteItem($this->getCacheKey($sessionIdentifier)); + } catch(InvalidArgumentException) { + throw new InvalidSessionIdException(sprintf('Session identifier (id: "%s", user: "%s" is invalid)', $sessionIdentifier->sessionId, $sessionIdentifier->userIdentifier ?? '')); + } + } + + private function getCacheKey(SessionIdentifier $sessionIdentifier): string + { + return sprintf('session_%s_%s', $sessionIdentifier->sessionId->toRfc4122(), $sessionIdentifier->userIdentifier ?? ''); + } + + public function get(SessionIdentifier $sessionIdentifier): Session + { + try { + $item = $this->cachePool->getItem($this->getCacheKey($sessionIdentifier)); + if (!$item->isHit()) { + throw new InvalidSessionIdException(sprintf('Session identifier (id: "%s", user: "%s" is invalid)', $sessionIdentifier->sessionId, $sessionIdentifier->userIdentifier ?? '')); + } + $item->expiresAfter($this->ttlInSeconds); + return new Session($sessionIdentifier, $this, $item->get()); + } catch(InvalidArgumentException) { + throw new InvalidSessionIdException(sprintf('Session identifier (id: "%s", user: "%s" is invalid)', $sessionIdentifier->sessionId, $sessionIdentifier->userIdentifier ?? '')); + } + } +} diff --git a/src/mcp-sdk/src/Server/Transport/StreamableHttp/SessionIdentifier.php b/src/mcp-sdk/src/Server/Transport/StreamableHttp/SessionIdentifier.php new file mode 100644 index 00000000..96385301 --- /dev/null +++ b/src/mcp-sdk/src/Server/Transport/StreamableHttp/SessionIdentifier.php @@ -0,0 +1,19 @@ +sessionId->toRfc4122() . ($this->userIdentifier ? '_' . $this->userIdentifier : ''); + } +} diff --git a/src/mcp-sdk/src/Server/Transport/StreamableHttp/SessionStorageInterface.php b/src/mcp-sdk/src/Server/Transport/StreamableHttp/SessionStorageInterface.php new file mode 100644 index 00000000..b03198ab --- /dev/null +++ b/src/mcp-sdk/src/Server/Transport/StreamableHttp/SessionStorageInterface.php @@ -0,0 +1,37 @@ +