106 changes: 106 additions & 0 deletions tests/entrypoints/openai/test_response_api_mcp_tools.py
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import pytest_asyncio
from openai import OpenAI

from ...utils import RemoteOpenAIServer

MODEL_NAME = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def monkeypatch_module():
    from _pytest.monkeypatch import MonkeyPatch
    mpatch = MonkeyPatch()
    yield mpatch
    mpatch.undo()


@pytest.fixture(scope="module")
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
    args = ["--enforce-eager", "--tool-server", "demo"]

    with monkeypatch_module.context() as m:
        m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
        m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
            yield remote_server


@pytest.fixture(scope="function")
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
    args = ["--enforce-eager", "--tool-server", "demo"]

    with monkeypatch_module.context() as m:
        m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
        m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
        m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS",
                 "code_interpreter,container")
        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
            yield remote_server


@pytest_asyncio.fixture
async def mcp_disabled_client(mcp_disabled_server):
    async with mcp_disabled_server.get_async_client() as async_client:
        yield async_client


@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
    async with mcp_enabled_server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
                                         model_name: str):
    response = await mcp_enabled_client.responses.create(
        model=model_name,
        # TODO: Ideally we would cap the number of tool calls to prevent
        # multi-turn runs, but that is not currently supported; it would
        # speed up the test.
        input=("What's the first 4 digits after the decimal point of "
               "cube root of `19910212 * 20250910`? "
               "Show only the digits. The python interpreter is not stateful "
               "and you must print to see the output."),
        tools=[{
            "type": "mcp",
            "server_label": "code_interpreter",
            # URL unused for DemoToolServer
            "server_url": "http://localhost:8888"
        }],
    )
    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens > 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI,
                                          model_name: str):
    response = await mcp_disabled_client.responses.create(
        model=model_name,
        # TODO: Ideally we would cap the number of tool calls to prevent
        # multi-turn runs, but that is not currently supported; it would
        # speed up the test.
        input=("What's the first 4 digits after the decimal point of "
               "cube root of `19910212 * 20250910`? "
               "Show only the digits. The python interpreter is not stateful "
               "and you must print to see the output."),
        tools=[{
            "type": "mcp",
            "server_label": "code_interpreter",
            # URL unused for DemoToolServer
            "server_url": "http://localhost:8888"
        }],
    )
    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens == 0
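
For local runs outside CI, a minimal reproduction sketch (not part of this PR) can exercise the enabled path. The base URL, port, and API key below are assumptions; the server is presumed started with --tool-server demo and GPT_OSS_SYSTEM_TOOL_MCP_LABELS=code_interpreter,container exported:

# Reproduction sketch only, assuming a vLLM server is already listening on
# localhost:8000 with --tool-server demo and
# GPT_OSS_SYSTEM_TOOL_MCP_LABELS=code_interpreter,container set.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.responses.create(
    model="openai/gpt-oss-20b",
    input="Use python to compute 2**10 and print the result.",
    tools=[{
        "type": "mcp",
        "server_label": "code_interpreter",
        # URL unused for DemoToolServer
        "server_url": "http://localhost:8888",
    }],
)
# With the label allowlisted, some output tokens should come from tool output.
print(response.status,
      response.usage.output_tokens_details.tool_output_tokens)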
@@ -454,7 +454,13 @@ async def test_web_search(client: OpenAI, model_name: str):
async def test_code_interpreter(client: OpenAI, model_name: str):
    response = await client.responses.create(
        model=model_name,
        input="Multiply 64548*15151 using builtin python interpreter.",
        # TODO: Ideally we would cap the number of tool calls to prevent
        # multi-turn runs, but that is not currently supported; it would
        # speed up the test.
        input=("What's the first 4 digits after the decimal point of "
               "cube root of `19910212 * 20250910`? "
               "Show only the digits. The python interpreter is not stateful "
               "and you must print to see the output."),
        tools=[{
            "type": "code_interpreter",
            "container": {
@@ -464,6 +470,7 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
    )
    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens > 0


def get_weather(latitude, longitude):
37 changes: 30 additions & 7 deletions vllm/entrypoints/context.py
@@ -8,6 +8,7 @@
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union

from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent

from vllm.entrypoints.harmony_utils import (
@@ -21,6 +22,24 @@

logger = logging.getLogger(__name__)

# This is currently needed because the tool type doesn't map 1:1 to the
# tool namespace, which is what is used to look up the connection to the
# tool server.
_TOOL_NAME_TO_TYPE_MAP = {
    "browser": "web_search_preview",
    "python": "code_interpreter",
    "container": "container",
}


def _map_tool_name_to_tool_type(tool_name: str) -> str:
    if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
        available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
        raise ValueError(
            f"Built-in tool name '{tool_name}' not defined in mapping. "
            f"Available tools: {available_tools}")
    return _TOOL_NAME_TO_TYPE_MAP[tool_name]


class TurnTokens:
"""Tracks token counts for a single conversation turn."""
@@ -59,8 +78,8 @@ def render_for_completion(self) -> list[int]:

    @abstractmethod
    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                                 exit_stack: AsyncExitStack,
                                 request_id: str) -> None:
                                 exit_stack: AsyncExitStack, request_id: str,
                                 mcp_tools: dict[str, Mcp]) -> None:
        pass

    @abstractmethod
@@ -96,8 +115,8 @@ def render_for_completion(self) -> list[int]:
        raise NotImplementedError("Should not be called.")

    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                                 exit_stack: AsyncExitStack,
                                 request_id: str) -> None:
                                 exit_stack: AsyncExitStack, request_id: str,
                                 mcp_tools: dict[str, Mcp]) -> None:
        pass

    async def cleanup_session(self) -> None:
@@ -318,13 +337,17 @@ async def call_python_tool(self, tool_session: Union["ClientSession",
        ]

    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                                 exit_stack: AsyncExitStack,
                                 request_id: str) -> None:
                                 exit_stack: AsyncExitStack, request_id: str,
                                 mcp_tools: dict[str, Mcp]):
        if tool_server:
            for tool_name in self.available_tools:
                if tool_name not in self._tool_sessions:
                    tool_type = _map_tool_name_to_tool_type(tool_name)
                    headers = mcp_tools[
                        tool_type].headers if tool_type in mcp_tools else None
                    tool_session = await exit_stack.enter_async_context(
                        tool_server.new_session(tool_name, request_id))
                        tool_server.new_session(tool_name, request_id,
                                                headers))
                    self._tool_sessions[tool_name] = tool_session
        exit_stack.push_async_exit(self.cleanup_session)
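A quick sketch of how the header lookup above resolves for the demo setup. The values are hypothetical, and this assumes the openai SDK's Mcp model accepts a headers field as used here; in serving_responses.py the mcp_tools dict is keyed by server_label, so the label must equal the mapped tool type for the lookup to hit:

# Illustration only (not the PR's code): header resolution for the
# built-in "python" tool. _map_tool_name_to_tool_type("python") returns
# "code_interpreter", which is then used as the key into mcp_tools.
from openai.types.responses.tool import Mcp

mcp_tools = {
    "code_interpreter": Mcp(
        type="mcp",
        server_label="code_interpreter",
        server_url="http://localhost:8888",
        headers={"authorization": "Bearer dev-token"},  # hypothetical value
    )
}
tool_type = "code_interpreter"  # == _map_tool_name_to_tool_type("python")
headers = mcp_tools[tool_type].headers if tool_type in mcp_tools else None
# headers == {"authorization": "Bearer dev-token"} and is forwarded to
# tool_server.new_session(tool_name, request_id, headers).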
4 changes: 3 additions & 1 deletion vllm/entrypoints/harmony_utils.py
@@ -126,8 +126,10 @@ def get_developer_message(
    function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
    for tool in tools:
        if tool.type in ("web_search_preview", "code_interpreter",
                         "container"):
                         "container", "mcp"):
            # These are built-in tools that are added to the system message.
            # MCP is included here for now, until we support MCP tools
            # executed server-side.
            pass

        elif tool.type == "function":
25 changes: 19 additions & 6 deletions vllm/entrypoints/openai/serving_responses.py
@@ -460,8 +460,12 @@ async def responses_full_generator(

        async with AsyncExitStack() as exit_stack:
            try:
                mcp_tools = {
                    tool.server_label: tool
                    for tool in request.tools if tool.type == "mcp"
                }
                await context.init_tool_sessions(self.tool_server, exit_stack,
                                                 request.request_id, mcp_tools)
                async for _ in result_generator:
                    pass
            except asyncio.CancelledError:
@@ -748,11 +752,16 @@ def _construct_input_messages_with_harmony(
            # New conversation.
            reasoning_effort = (request.reasoning.effort
                                if request.reasoning else None)
            # Temporary: the OpenAI types don't have a container tool,
            # so we use MCP to cover that; subject to change.
            tool_types = [tool.type for tool in request.tools]
            if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
                tool_types.append("container")

            # Allow the MCP tool type to enable built-in tools if the
            # server_label is allowlisted in
            # envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
            if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
                for tool in request.tools:
                    if (tool.type == "mcp" and tool.server_label
                            in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS):
                        tool_types.append(tool.server_label)
            enable_browser = ("web_search_preview" in tool_types
                              and self.tool_server is not None
                              and self.tool_server.has_tool("browser"))
@@ -1653,8 +1662,12 @@ def _increment_sequence_number_and_return(
        async with AsyncExitStack() as exit_stack:
            processer = None
            if self.use_harmony:
                mcp_tools = {
                    tool.server_label: tool
                    for tool in request.tools if tool.type == "mcp"
                }
                await context.init_tool_sessions(self.tool_server, exit_stack,
                                                 request.request_id, mcp_tools)
                processer = self._process_harmony_streaming_events
            else:
                processer = self._process_simple_streaming_events
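To make the allowlist behavior concrete, a standalone sketch of the label filtering above (not the PR's code; it assumes GPT_OSS_SYSTEM_TOOL_MCP_LABELS is exposed as a collection of labels, as the membership check implies):

# Standalone sketch of the label-allowlist logic in serving_responses.py.
allowed_labels = {"code_interpreter", "container"}  # stands in for the env var

request_tools = [
    {"type": "mcp", "server_label": "code_interpreter"},  # allowlisted
    {"type": "mcp", "server_label": "github"},            # not allowlisted
    {"type": "web_search_preview"},
]

tool_types = [tool["type"] for tool in request_tools]
for tool in request_tools:
    if tool["type"] == "mcp" and tool.get("server_label") in allowed_labels:
        # Only allowlisted MCP labels are promoted to built-in tool types.
        tool_types.append(tool["server_label"])

print(tool_types)
# ['mcp', 'mcp', 'web_search_preview', 'code_interpreter']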
29 changes: 20 additions & 9 deletions vllm/entrypoints/tool_server.py
@@ -18,7 +18,6 @@
async def list_server_and_tools(server_url: str):
    from mcp import ClientSession
    from mcp.client.sse import sse_client

    async with sse_client(url=server_url) as streams, ClientSession(
            *streams) as session:
        initialize_response = await session.initialize()
@@ -86,8 +85,12 @@ def get_tool_description(self,
        pass

    @abstractmethod
    def new_session(self, tool_name: str,
                    session_id: str) -> AbstractAsyncContextManager[Any]:
    def new_session(
        self,
        tool_name: str,
        session_id: str,
        headers: Optional[dict[str, str]] = None
    ) -> AbstractAsyncContextManager[Any]:
        """
        Create a session for the tool.
        """
@@ -144,16 +147,21 @@ def get_tool_description(self, tool_name: str):
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def new_session(self, tool_name: str, session_id: str):
    async def new_session(self,
                          tool_name: str,
                          session_id: str,
                          headers: Optional[dict[str, str]] = None):
        from mcp import ClientSession
        from mcp.client.sse import sse_client
        url = self.urls.get(tool_name)
        headers = {"x-session-id": session_id}
        request_headers = {"x-session-id": session_id}
        if headers is not None:
            request_headers.update(headers)
        if not url:
            raise KeyError(f"Tool '{tool_name}' is not supported")
        async with sse_client(url=url,
                              headers=headers) as streams, ClientSession(
                                  *streams) as session:
        async with sse_client(
                url=url, headers=request_headers) as streams, ClientSession(
                    *streams) as session:
            await session.initialize()
            yield session

@@ -189,7 +197,10 @@ def get_tool_description(self,
        raise ValueError(f"Unknown tool {tool_name}")

    @asynccontextmanager
    async def new_session(self, tool_name: str, session_id: str):
    async def new_session(self,
                          tool_name: str,
                          session_id: str,
                          headers: Optional[dict[str, str]] = None):
        if tool_name not in self.tools:
            raise KeyError(f"Tool '{tool_name}' is not supported")
        yield self.tools[tool_name]
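
A quick sketch of the header-merge behavior in MCPToolServer.new_session (illustrative values, not the PR's code): the per-session x-session-id is set first and caller-supplied headers are layered on top, so they can add auth but could also override the session id if a caller passes that key:

# Illustration of the merge performed in MCPToolServer.new_session.
from typing import Optional


def merge_session_headers(session_id: str,
                          headers: Optional[dict[str, str]] = None
                          ) -> dict[str, str]:
    request_headers = {"x-session-id": session_id}
    if headers is not None:
        # Caller headers (e.g. from the MCP tool's headers field) win on
        # key collisions because update() overwrites existing entries.
        request_headers.update(headers)
    return request_headers


print(merge_session_headers("req-123", {"authorization": "Bearer dev-token"}))
# {'x-session-id': 'req-123', 'authorization': 'Bearer dev-token'}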