
Commit fe6d825

[gpt-oss] Support tool call and implement MCP tool server (#22427)
Signed-off-by: Chen Zhang <[email protected]>
1 parent e290594 commit fe6d825

File tree (4 files changed: +234 -83 lines changed)

vllm/entrypoints/harmony_utils.py
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/serving_responses.py
vllm/entrypoints/tool_server.py

vllm/entrypoints/harmony_utils.py

Lines changed: 4 additions & 1 deletion
@@ -237,7 +237,10 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
                 id=f"rs_{random_uuid()}",
                 summary=[],
                 type="reasoning",
-                text=content.text,
+                content=[
+                    ResponseReasoningTextContent(text=content.text,
+                                                 type="reasoning_text")
+                ],
                 status=None,
             )
             output_items.append(reasoning_item)
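
In effect, a reasoning item's text moves from a flat text field into a typed
content list. A minimal before/after sketch of the item shape (values are
illustrative, written as plain dicts rather than the actual Pydantic models):

# Before: reasoning text as a flat field.
reasoning_item_before = {
    "id": "rs_...",
    "summary": [],
    "type": "reasoning",
    "text": "The user wants ...",
    "status": None,
}

# After: the same text wrapped in a typed content entry, mirroring the
# ResponseReasoningTextContent(..., type="reasoning_text") call above.
reasoning_item_after = {
    "id": "rs_...",
    "summary": [],
    "type": "reasoning",
    "content": [{"type": "reasoning_text", "text": "The user wants ..."}],
    "status": None,
}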

vllm/entrypoints/openai/api_server.py

Lines changed: 5 additions & 1 deletion
@@ -94,7 +94,8 @@
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription, OpenAIServingTranslation)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.tool_server import DemoToolServer, ToolServer
+from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer,
+                                          ToolServer)
 from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
                                     log_non_default_args, with_cancellation)
 from vllm.logger import init_logger
@@ -1635,6 +1636,9 @@ async def init_app_state(

     if args.tool_server == "demo":
         tool_server: Optional[ToolServer] = DemoToolServer()
+    elif args.tool_server:
+        tool_server = MCPToolServer()
+        await tool_server.add_tool_server(args.tool_server)
     else:
         tool_server = None

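Any value of the tool-server argument other than "demo" is treated as a
comma-separated list of MCP server addresses and handed to
MCPToolServer.add_tool_server, which (per tool_server.py below) connects to
http://<host:port>/sse for each entry. As a sketch, here is a toy MCP server
one might point vLLM at; it assumes the FastMCP helper from the mcp Python
SDK, and the server name, tool, and port are all illustrative:

# toy_tool_server.py - a stand-in MCP tool server exposed over SSE.
# Requires `pip install mcp`; check your SDK version for the exact API.
from mcp.server.fastmcp import FastMCP

# The server name is reported as serverInfo.name during initialization and
# becomes the Harmony tool namespace (e.g. "python" or "browser").
mcp = FastMCP("python", port=8000)

@mcp.tool()
def run(code: str) -> str:
    """Toy stand-in for a real Python-execution tool."""
    return f"received {len(code)} characters"

if __name__ == "__main__":
    mcp.run(transport="sse")  # serves http://localhost:8000/sse

vLLM would then be launched with something like --tool-server localhost:8000
(the flag backing args.tool_server), or a comma-separated list to register
several servers at once.
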
vllm/entrypoints/openai/serving_responses.py

Lines changed: 107 additions & 80 deletions
@@ -4,6 +4,7 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
+from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
 from typing import Any, Callable, Final, Optional, Union
@@ -226,99 +227,125 @@ async def create_responses(

         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[ConversationContext, None]] = []
-        try:
-            tool_sessions: dict[str, Any] = {}
-            for i, engine_prompt in enumerate(engine_prompts):
-                default_max_tokens = self.max_model_len - len(
-                    engine_prompt["prompt_token_ids"])
-                sampling_params = request.to_sampling_params(
-                    default_max_tokens, self.default_sampling_params)
-
-                trace_headers = (None if raw_request is None else await
-                                 self._get_trace_headers(raw_request.headers))
-
-                context: ConversationContext
-                if self.use_harmony:
-                    if request.stream:
-                        context = StreamingHarmonyContext(
-                            messages, tool_sessions)
-                    else:
-                        context = HarmonyContext(messages, tool_sessions)
+
+        builtin_tool_list: list[str] = []
+        if self.use_harmony and self.tool_server is not None:
+            if self.tool_server.has_tool("browser"):
+                builtin_tool_list.append("browser")
+            if self.tool_server.has_tool("python"):
+                builtin_tool_list.append("python")
+        async with AsyncExitStack() as exit_stack:
+            try:
+                if self.tool_server is not None:
+                    # TODO: initialize tool sessions lazily when the session
+                    # is actually used.
+                    tool_session_ctxs: dict[str, Any] = {
+                        tool_name:
+                        exit_stack.enter_async_context(
+                            self.tool_server.new_session(tool_name))
+                        for tool_name in builtin_tool_list
+                    }
+                    tool_sessions = {}
+                    for tool_name in builtin_tool_list:
+                        tool_sessions[tool_name] = (
+                            await tool_session_ctxs[tool_name])
                 else:
-                    context = SimpleContext()
-                generator = self._generate_with_builtin_tools(
-                    request_id=request.request_id,
-                    request_prompt=request_prompts[i],
-                    engine_prompt=engine_prompt,
-                    sampling_params=sampling_params,
-                    context=context,
-                    lora_request=lora_request,
-                    priority=request.priority,
-                    trace_headers=trace_headers,
+                    assert len(builtin_tool_list) == 0
+                    tool_sessions = {}
+                for i, engine_prompt in enumerate(engine_prompts):
+                    default_max_tokens = self.max_model_len - len(
+                        engine_prompt["prompt_token_ids"])
+                    sampling_params = request.to_sampling_params(
+                        default_max_tokens, self.default_sampling_params)
+
+                    trace_headers = (None if raw_request is None else await
+                                     self._get_trace_headers(
+                                         raw_request.headers))
+
+                    context: ConversationContext
+                    if self.use_harmony:
+                        if request.stream:
+                            context = StreamingHarmonyContext(
+                                messages, tool_sessions)
+                        else:
+                            context = HarmonyContext(messages, tool_sessions)
+                    else:
+                        context = SimpleContext()
+                    generator = self._generate_with_builtin_tools(
+                        request_id=request.request_id,
+                        request_prompt=request_prompts[i],
+                        engine_prompt=engine_prompt,
+                        sampling_params=sampling_params,
+                        context=context,
+                        lora_request=lora_request,
+                        priority=request.priority,
+                        trace_headers=trace_headers,
+                    )
+                    generators.append(generator)
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
+
+            assert len(generators) == 1
+            result_generator, = generators
+
+            # Store the input messages.
+            if request.store:
+                self.msg_store[request.request_id] = messages
+
+            if request.background:
+                created_time = int(time.time())
+                response = ResponsesResponse.from_request(
+                    request,
+                    sampling_params,
+                    model_name=model_name,
+                    created_time=created_time,
+                    output=[],
+                    status="queued",
+                    usage=None,
                 )
-                generators.append(generator)
-        except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+                async with self.response_store_lock:
+                    self.response_store[response.id] = response

-        assert len(generators) == 1
-        result_generator, = generators
+                # Run the request in the background.
+                task = asyncio.create_task(
+                    self._run_background_request(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{response.id}",
+                )

-        # Store the input messages.
-        if request.store:
-            self.msg_store[request.request_id] = messages
+                # For cleanup.
+                response_id = response.id
+                self.background_tasks[response_id] = task
+                task.add_done_callback(
+                    lambda _: self.background_tasks.pop(response_id, None))
+                return response

-        if request.background:
-            created_time = int(time.time())
-            response = ResponsesResponse.from_request(
-                request,
-                sampling_params,
-                model_name=model_name,
-                created_time=created_time,
-                output=[],
-                status="queued",
-                usage=None,
-            )
-            async with self.response_store_lock:
-                self.response_store[response.id] = response
+            if request.stream:
+                raise NotImplementedError(
+                    "Streaming responses are not supported")

-            # Run the request in the background.
-            task = asyncio.create_task(
-                self._run_background_request(
+            try:
+                return await self.responses_full_generator(
                     request,
                     sampling_params,
                     result_generator,
                     context,
                     model_name,
                     tokenizer,
                     request_metadata,
-                    created_time,
-                ),
-                name=f"create_{response.id}",
-            )
-
-            # For cleanup.
-            response_id = response.id
-            self.background_tasks[response_id] = task
-            task.add_done_callback(
-                lambda _: self.background_tasks.pop(response_id, None))
-            return response
-
-        if request.stream:
-            raise NotImplementedError("Streaming responses are not supported")
-
-        try:
-            return await self.responses_full_generator(
-                request,
-                sampling_params,
-                result_generator,
-                context,
-                model_name,
-                tokenizer,
-                request_metadata,
-            )
-        except Exception as e:
-            return self.create_error_response(str(e))
+                )
+            except Exception as e:
+                return self.create_error_response(str(e))
+        return self.create_error_response("Should not reach here")

     async def _make_request(
         self,
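
The AsyncExitStack is what ties tool-session lifetime to the request: every
session entered via exit_stack.enter_async_context stays open for the whole
request and is closed when the async with block exits, whether the request
succeeds or fails. A minimal sketch of the pattern, using a hypothetical
session factory in place of ToolServer.new_session:

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager

@asynccontextmanager
async def new_session(tool_name: str):
    # Hypothetical stand-in for ToolServer.new_session(tool_name).
    print(f"open {tool_name}")
    try:
        yield f"{tool_name}-session"
    finally:
        print(f"close {tool_name}")

async def main():
    builtin_tool_list = ["browser", "python"]
    async with AsyncExitStack() as exit_stack:
        # enter_async_context() runs __aenter__ immediately and defers
        # __aexit__ to the stack, as in the tool_session_ctxs setup above.
        tool_sessions = {
            name: await exit_stack.enter_async_context(new_session(name))
            for name in builtin_tool_list
        }
        print(tool_sessions)
    # Leaving the block closes every session, in reverse order of entry.

asyncio.run(main())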

vllm/entrypoints/tool_server.py

Lines changed: 118 additions & 1 deletion
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
 from contextlib import AbstractAsyncContextManager, asynccontextmanager
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from openai_harmony import ToolNamespaceConfig

@@ -11,6 +11,61 @@

 logger = init_logger(__name__)

+if TYPE_CHECKING:
+    from mcp.types import ListToolsResult
+
+
+async def list_server_and_tools(server_url: str):
+    from mcp import ClientSession
+    from mcp.client.sse import sse_client
+
+    async with sse_client(url=server_url) as streams, ClientSession(
+            *streams) as session:
+        initialize_response = await session.initialize()
+        list_tools_response = await session.list_tools()
+        return initialize_response, list_tools_response
+
+
+def trim_schema(schema: dict) -> dict:
+    # Turn JSON Schema from MCP generated into Harmony's variant.
+    if "title" in schema:
+        del schema["title"]
+    if "default" in schema and schema["default"] is None:
+        del schema["default"]
+    if "anyOf" in schema:
+        # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
+        # into "type": ["type-1", "type-2"]
+        # if there's more than 1 types, also remove "null" type as Harmony will
+        # just ignore it
+        types = [
+            type_dict["type"] for type_dict in schema["anyOf"]
+            if type_dict["type"] != 'null'
+        ]
+        schema["type"] = types
+        del schema["anyOf"]
+    if "properties" in schema:
+        schema["properties"] = {
+            k: trim_schema(v)
+            for k, v in schema["properties"].items()
+        }
+    return schema
+
+
+def post_process_tools_description(
+        list_tools_result: "ListToolsResult") -> "ListToolsResult":
+    # Adapt the MCP tool result for Harmony
+    for tool in list_tools_result.tools:
+        tool.inputSchema = trim_schema(tool.inputSchema)
+
+    # Some tools schema don't need to be part of the prompt (e.g. simple text
+    # in text out for Python)
+    list_tools_result.tools = [
+        tool for tool in list_tools_result.tools
+        if getattr(tool.annotations, "include_in_prompt", True)
+    ]
+
+    return list_tools_result
+

 class ToolServer(ABC):

@@ -38,6 +93,66 @@ def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
         ...


+class MCPToolServer(ToolServer):
+
+    def __init__(self):
+        try:
+            import mcp  # noqa: F401
+        except ImportError:
+            raise ImportError(
+                "mcp is not installed. Please run `pip install mcp` to use "
+                "MCPToolServer.") from None
+        self.harmony_tool_descriptions = {}
+
+    async def add_tool_server(self, server_url: str):
+        from mcp.types import ToolDescription
+        tool_urls = server_url.split(",")
+        self.harmony_tool_descriptions = {}
+        self.urls: dict[str, str] = {}
+        for url in tool_urls:
+            url = f"http://{url}/sse"
+            initialize_response, list_tools_response = (
+                await list_server_and_tools(url))
+
+            list_tools_response = post_process_tools_description(
+                list_tools_response)
+
+            tool_from_mcp = ToolNamespaceConfig(
+                name=initialize_response.serverInfo.name,
+                description=initialize_response.instructions,
+                tools=[
+                    ToolDescription.new(name=tool.name,
+                                        description=tool.description,
+                                        parameters=tool.inputSchema)
+                    for tool in list_tools_response.tools
+                ])
+            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
+            if tool_from_mcp.name not in self.urls:
+                self.urls[tool_from_mcp.name] = url
+            else:
+                logger.warning(
+                    "Tool %s already exists. Ignoring duplicate tool server %s",
+                    tool_from_mcp.name, url)
+
+    def has_tool(self, tool_name: str):
+        return tool_name in self.harmony_tool_descriptions
+
+    def get_tool_description(self, tool_name: str):
+        return self.harmony_tool_descriptions.get(tool_name)
+
+    @asynccontextmanager
+    async def new_session(self, tool_name: str):
+        from mcp import ClientSession
+        from mcp.client.sse import sse_client
+        url = self.urls.get(tool_name)
+        if not url:
+            raise KeyError(f"Tool '{tool_name}' is not supported")
+        async with sse_client(url=url) as streams, ClientSession(
+                *streams) as session:
+            await session.initialize()
+            yield session
+
+
 class DemoToolServer(ToolServer):

     def __init__(self):
@@ -67,4 +182,6 @@ def get_tool_description(self,

     @asynccontextmanager
     async def new_session(self, tool_name: str):
+        if tool_name not in self.tools:
+            raise KeyError(f"Tool '{tool_name}' is not supported")
         yield self.tools[tool_name]
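
A worked example of the schema rewrite: given an MCP-generated schema for one
optional string parameter, trim_schema drops "title" keys and None defaults
and collapses anyOf into a type list. The input/output pair below is
illustrative, traced by hand through the function above:

# What an MCP server might emit for an optional string parameter.
mcp_schema = {
    "title": "Arguments",
    "properties": {
        "query": {
            "title": "Query",
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "default": None,
        },
    },
}

# What trim_schema(mcp_schema) returns: titles and null defaults removed,
# anyOf collapsed into a "type" list with "null" filtered out.
harmony_schema = {
    "properties": {
        "query": {"type": ["string"]},
    },
}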
