## `gpt-oss` vLLM Usage Guide

`gpt-oss-20b` and `gpt-oss-120b` are powerful reasoning models open-sourced by OpenAI.
In vLLM, you can run them on NVIDIA H100, H200, and B200, as well as AMD MI300x, MI325x, MI355x, and Radeon AI PRO R9700.
We are actively working on ensuring these models work on Ampere, Ada Lovelace, and RTX 5090.
Specifically, vLLM optimizes for the `gpt-oss` family of models with:

* **Flexible parallelism options**: the model can be sharded across 2, 4, or 8 GPUs, scaling throughput.
* **High performance attention and MoE kernels**: attention kernel is specifically optimized for the attention sinks mechanism and sliding window shapes.
* **Asynchronous scheduling**: optimizing for maximum utilization and high throughput by overlapping CPU operations with GPU operations.

This is a living document and we welcome contributions, corrections, and creation of new recipes!

## Quickstart

GPT-OSS works on Ampere devices by default, using the `TRITON_ATTN` attention backend:

```
# openai/gpt-oss-20b should run on a single A100
vllm serve openai/gpt-oss-20b --async-scheduling

# gpt-oss-120b will fit on a single A100 (80GB), but scaling it to higher TP sizes can help with throughput
vllm serve openai/gpt-oss-120b --async-scheduling
vllm serve openai/gpt-oss-120b --tensor-parallel-size 4 --async-scheduling
```

GPT-OSS works on Hopper devices by default, using the FlashAttention3 backend and Marlin MXFP4 MoE:

* `--async-scheduling` can be enabled for higher performance. Currently it is not compatible with structured output.
* We recommend TP=2 for H100 and H200 as the best performance tradeoff point.

```
# openai/gpt-oss-20b should run on a single GPU
vllm serve openai/gpt-oss-20b --async-scheduling

# gpt-oss-120b will fit in a single H100/H200, but scaling it to higher TP sizes can help with throughput
vllm serve openai/gpt-oss-120b --async-scheduling
```

NVIDIA Blackwell requires installation of the [FlashInfer library](https://github.com/flashinfer-ai/flashinfer):

```
uv pip install vllm[flashinfer]==0.10.1 --torch-backend=auto
```

We recommend TP=1 as a starting point for a performant option. We are actively working on the performance of vLLM on Blackwell.

```
# Pick only one out of the two for MoE implementation
# bf16 activation for MoE. matching reference precision (default).
export VLLM_USE_FLASHINFER_MXFP4_BF16_MOE=1
# mxfp8 activation for MoE. faster, but higher risk for accuracy.
export VLLM_USE_FLASHINFER_MXFP4_MOE=1

# openai/gpt-oss-20b
vllm serve openai/gpt-oss-20b --async-scheduling

# gpt-oss-120b
vllm serve openai/gpt-oss-120b --async-scheduling
vllm serve openai/gpt-oss-120b --tensor-parallel-size 2 --async-scheduling
vllm serve openai/gpt-oss-120b --tensor-parallel-size 4 --async-scheduling
```

ROCm supports the OpenAI gpt-oss-120b and gpt-oss-20b models on these three GPU families on day one, along with pre-built Docker containers:

* gfx950: MI350x series, `rocm/vllm-dev:open-mi355-08052025`
* gfx942: MI300x/MI325 series, `rocm/vllm-dev:open-mi300-08052025`
* gfx1201: Radeon AI PRO R9700, `rocm/vllm-dev:open-r9700-08052025`

To run the container:
For MI300x:

```
export VLLM_ROCM_USE_AITER=1
export VLLM_USE_AITER_UNIFIED_ATTENTION=1
export VLLM_ROCM_USE_AITER_MHA=0

vllm serve openai/gpt-oss-120b --compilation-config '{"full_cuda_graph": true}'
```

For MI355x:
```
export VLLM_USE_AITER_UNIFIED_ATTENTION=1
export VLLM_ROCM_USE_AITER_MHA=0
export TRITON_HIP_PRESHUFFLE_SCALES=1

vllm serve openai/gpt-oss-120b --compilation-config '{"compile_sizes": [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 4096, 8192], "full_cuda_graph": true}' --block-size 64
```

#### Known Issues
- If you encounter the error `The link interface of target "torch::nvtoolsext" contains: CUDA::nvToolsExt but the target was not found.`, double-check that your PyTorch version has the `+cu128` suffix.
- If the output you see is garbage, you may not have set `CUDA_HOME` properly. The CUDA version needs to be greater than or equal to 12.8 and must be the same for installation and serving.
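
A quick way to check both conditions from Python (a minimal sketch; exact version strings depend on your install):

```
# Verify the PyTorch build and the CUDA toolkit version used at serving time.
import torch

print(torch.__version__)   # should end with +cu128
print(torch.version.cuda)  # should be 12.8 or newer and match the CUDA used at install time
```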

## Usage

Once `vllm serve` is running and `INFO: Application startup complete` has been displayed, you can send requests via HTTP or the OpenAI SDK to the following endpoints:

* `/v1/responses` can perform tool use (browsing, python, MCP) in between chain-of-thought steps and deliver a final response. This endpoint leverages the `openai-harmony` library for input rendering and output parsing. Stateful operation and the full streaming API are work in progress. The Responses API is recommended by OpenAI as the way to interact with this model.
* `/v1/chat/completions` offers a familiar interface to this model. No tools will be invoked, but reasoning and the final text output are returned structurally. Function calling is work in progress. You can also set `include_reasoning: false` in the request to omit the CoT from the output.
* `/v1/completions` is a simple input-output endpoint without any template rendering.

All endpoints accept `stream: true` to enable incremental token streaming. Please note that vLLM currently does not cover the full scope of the Responses API; for more detail, please see the Limitations section below.
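
For example, both endpoints can be exercised with the OpenAI SDK pointed at the local server. This is a minimal sketch, assuming the server from the Quickstart is listening on the default port 8000; the model name and prompts are illustrative:

```
from openai import OpenAI

# vLLM does not require a real API key by default.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Responses API (recommended by OpenAI for this model family).
response = client.responses.create(
    model="openai/gpt-oss-120b",
    input="Explain KV caching in two sentences.",
)
print(response.output_text)

# Chat Completions API; `include_reasoning` is the vLLM-specific flag mentioned above.
chat = client.chat.completions.create(
    model="openai/gpt-oss-120b",
    messages=[{"role": "user", "content": "Explain KV caching in two sentences."}],
    extra_body={"include_reasoning": False},
)
print(chat.choices[0].message.content)
```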

### Tool Use

```
uv pip install gpt-oss
vllm serve ... --tool-server demo
```

* Please note that the default options are simply for demo purposes. For production usage, vLLM itself can act as an MCP client to multiple services.
Here is an [example tool server](https://github.com/openai/gpt-oss/tree/main/gpt-oss-mcp-server) that vLLM can work with; it wraps the demo tools:

```
mcp run -t sse browser_server.py:mcp
mcp run -t sse python_server.py:mcp
vllm serve ... --tool-server ip-1:port-1,ip-2:port-2
```

The URLs are expected to be MCP SSE servers that implement `instructions` in their server info and expose well-documented tools. The tools will be injected into the system prompt so the model can use them.
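
As a rough illustration of what such a server looks like, here is a minimal sketch using the MCP Python SDK's `FastMCP`. The `clock` server and its `get_time` tool are hypothetical, not part of the demo tools; only the overall shape (server `instructions` plus documented tools served over SSE) reflects what vLLM expects:

```
# clock_server.py -- hypothetical example; run with: mcp run -t sse clock_server.py:mcp
from datetime import datetime, timezone

from mcp.server.fastmcp import FastMCP

mcp = FastMCP(
    name="clock",
    instructions="Utilities for reading the current time.",
)

@mcp.tool(
    name="get_time",
    description="Return the current UTC time as an ISO-8601 string.",
)
def get_time() -> str:
    return datetime.now(timezone.utc).isoformat()
```

vLLM would then be pointed at this server with `vllm serve ... --tool-server <host>:<port>`.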

GPT-OSS also expects a built-in tool called `container`, which does not yet have an exposed tool type in the OpenAI types.
For reference, the container tool is a stateful Docker container that can be used to run command-line tools.
The enabled tool namespace is `container`, and the most commonly used tool name is `exec`.
An MCP server needs to implement the following function to support the container tool:
```
Tool: exec
Arguments:
- cmd (List[str]): command to execute
- workdir (Optional[str]): current working directory
- env (Optional[Dict[str, str]]): environment variables
- session_name (Optional[str]): session name
- timeout (Optional[int]): timeout in seconds
- user (Optional[str]): user name
Signature:
async def exec(ctx: Context, **kwargs) -> str
# Note: `ctx` is expected to contain a session id to identify the container session and make it stateful.
```
A container tool runtime implementation can be referenced from https://github.com/SWE-agent/SWE-ReX.
To enable the container tool in vLLM before the OpenAI types support it, set:
```
export VLLM_GPT_OSS_USE_CONTAINER_TOOL=1
```
To properly run the container tool, follow the example in sample_container_mcp.md
and run:
```
mcp run -t sse container_server.py:mcp
```
Note that the names here are placeholders; you need to implement your own server.

## Accuracy Evaluation Panels

## Limitations

* Streaming is fairly barebone at the moment, for example:
  * Item IDs and indexing need more work.
  * Tool invocations and outputs are not properly streamed; they are batched.
* Proper error handling is missing.

## Troubleshooting

- Attention sink dtype error on Blackwell:

```
ERROR 08-05 07:31:10 [multiproc_executor.py:559] assert sinks.dtype == torch.float32, "Sinks must be of type float32"
**(VllmWorker TP0 pid=174579)** ERROR 08-05 07:31:10 [multiproc_executor.py:559] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**(VllmWorker TP0 pid=174579)** ERROR 08-05 07:31:10 [multiproc_executor.py:559] AssertionError: Sinks must be of type float32
```

If you want to use offline inference, you can treat vLLM as a token-in-token-out service and pass in tokens that are already formatted with Harmony.
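
A minimal sketch of this token-in-token-out flow, assuming the `openai-harmony` package and vLLM's offline `LLM` API; the model name, prompt, and sampling settings are illustrative:

```
# Render a Harmony-formatted conversation to tokens, generate with vLLM offline,
# and parse the completion tokens back into structured messages.
from openai_harmony import (
    Conversation,
    HarmonyEncodingName,
    Message,
    Role,
    load_harmony_encoding,
)
from vllm import LLM, SamplingParams

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
convo = Conversation.from_messages(
    [Message.from_role_and_content(Role.USER, "What is 2 + 2?")]
)
prefill_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)

llm = LLM(model="openai/gpt-oss-120b")
params = SamplingParams(
    max_tokens=256,
    stop_token_ids=encoding.stop_tokens_for_assistant_actions(),
)
outputs = llm.generate([{"prompt_token_ids": prefill_ids}], params)

completion_ids = outputs[0].outputs[0].token_ids
messages = encoding.parse_messages_from_completion_tokens(completion_ids, Role.ASSISTANT)
print(messages)
```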

For function calling, only `tool_choice="auto"` is supported.
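
Below is a hedged sketch of what that looks like through the Chat Completions endpoint; the `get_weather` tool schema is hypothetical, and, as noted above, function-calling support is still evolving:

```
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool; your own code executes it
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

resp = client.chat.completions.create(
    model="openai/gpt-oss-120b",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)
print(resp.choices[0].message.tool_calls)
```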

Harmony also only supports instructions in the developer message, but to achieve better alignment with training
it is preferred to place them in the system message.
Enable the flag below to move instructions to the system message:
```
export VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1
```
# Container MCP Server Example

This is an incomplete example showing how to implement a container tool for GPT-OSS using MCP.

```
import logging
from typing import Optional

from mcp.server.fastmcp import Context, FastMCP
# dummy import showing how a container tool backend might be wired in
from swe_rex import SweRexManager

logger = logging.getLogger(__name__)

# Create the MCP server for the container tool namespace
mcp = FastMCP(
name="container",
instructions=r"""
Utilities for interacting with a container, for example, a Docker container.\n
(container_tool, version 1.X.X)\n
(lean_terminal, version 1.X.X)
""".strip(),
)

# assumed config; the SweRexManager sketch below expects a dict with at least a "host"
swe_rex_manager = SweRexManager({"host": "localhost"})

def _get_session_id(ctx: Context) -> str:
"""Extract session ID from headers, URL query parameter or fallback to client_id"""
request = ctx.request_context.request
return request.headers.get("session_id") or request.query_params.get(
"session_id"
) or ctx.client_id

@mcp.tool(
name="exec",
title="container exec",
description="""
Returns the output of the command.
Allocates an interactive pseudo-TTY if (and only if) 'session_name' is set.
""",
)
async def exec(
ctx: Context,
cmd: list[str],
session_name: Optional[str] = None,
workdir: Optional[str] = None,
timeout: Optional[int] = None,
env: Optional[dict[str, str]] = None,
user: Optional[str] = None,
) -> str:
session_id = _get_session_id(ctx)
try:
logger.debug(f"cmd for container exec: {cmd}")

res = await swe_rex_manager.execute_in_session(
session_id,
cmd=cmd,
workdir=workdir,
env=env,
execution_timeout=360 if timeout is None else timeout,
# Below fields are not used right now
session_name=session_name, # This could be overriding session_id
user=user,
)
        logger.info(f"container execution result: {res}")
        return res
    except Exception as e:
        # surface failures to the model as text instead of crashing the tool call
        logger.exception("container exec failed")
        return f"error: {e}"

@mcp.tool(
name="cleanup_session",
title="clean container session",
description="cleanup a specific session",
annotations={
"include_in_prompt": False,
})
async def cleanup_session(ctx: Context) -> None:
"""Cleanup a specific session"""
session_id = _get_session_id(ctx)
logger.info(f"Cleaning up session: {session_id}")
await swe_rex_manager.cleanup_session(session_id)
```

### SweRexManager Implementation Pattern

Based on the RemoteRuntime pattern, your SweRexManager could be implemented as below.
Note that this is a dummy implementation and you should implement your own version:
```
from typing import Any, Dict, Optional
from swerex.runtime.remote import RemoteRuntime
from swerex.runtime.config import RemoteRuntimeConfig

class SweRexManager:
def __init__(self, config: Dict[str, Any]):
"""Initialize SweRexManager with dict configuration.

Args:
config: Dictionary containing:
- host: Server host (required)
- port: Server port (optional)
- timeout: Request timeout in seconds (optional, default 30.0)
- auth_token: Authentication token (optional)
"""
self.config = RemoteRuntimeConfig(**config)
self.runtime = RemoteRuntime.from_config(self.config)
self.sessions: Dict[str, str] = {} # session_id -> runtime_session mapping

async def execute_in_session(
self,
session_id: str,
cmd: list[str],
workdir: Optional[str] = None,
env: Optional[Dict[str, str]] = None,
execution_timeout: int = 360,
**kwargs
) -> str:
"""Execute command in a session."""
# Ensure session exists
if session_id not in self.sessions:
await self.create_session(session_id)

from swerex.runtime.abstract import Command

command = Command(
command=cmd,
timeout=execution_timeout,
cwd=workdir,
env=env or {}
)

response = await self.runtime.execute(command)
return response.stdout if response.exit_code == 0 else response.stderr

async def create_session(self, session_id: str) -> None:
"""Create a new session."""
from swerex.runtime.abstract import CreateSessionRequest

request = CreateSessionRequest(session_id=session_id)
await self.runtime.create_session(request)
self.sessions[session_id] = session_id

async def cleanup_session(self, session_id: str) -> None:
"""Cleanup a session."""
if session_id in self.sessions:
from swerex.runtime.abstract import CloseSessionRequest

request = CloseSessionRequest(session_id=session_id)
await self.runtime.close_session(request)
del self.sessions[session_id]
```
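
For completeness, here is a hedged sketch of how this manager might be exercised on its own; the host value and session name are assumptions, and in practice the MCP server above is what drives these calls:

```
import asyncio

async def main() -> None:
    # "host" is required by the config above; point it at wherever your SWE-ReX runtime listens.
    manager = SweRexManager({"host": "localhost"})
    try:
        out = await manager.execute_in_session("demo-session", cmd=["echo", "hello"])
        print(out)
    finally:
        await manager.cleanup_session("demo-session")

asyncio.run(main())
```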