
Commit e46f700

hchings and davidmlw authored
[trtllm] fix: Fixes for TRTLLM rollout (#5032)
### What does this PR do?

Several fixes/improvements for trtllm rollout:

- Remove the dependency on `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` being set before the Ray cluster starts.
- Fix a memory issue during IPC handle creation in `update_weights()`.
- Fix the reward model path. (Note: this code is for future use, once TRTLLM rollout switches from `TRTLLMSampler` to `TorchSampler`. Currently, TRTLLM rollout still doesn't support training involving reward models.)

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

---------

Co-authored-by: Liwei Ma <liweim@nvidia.com>
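The title convention in the checklist above is mechanically checkable. Below is a hypothetical sketch of such a check (the actual verl CI script may be implemented differently; the function name `check_pr_title` is invented here):

```python
import re

# Hypothetical validator for the "[{modules}] {type}: {description}" convention
# described in the checklist; the real verl CI check may differ.
MODULES = {
    "fsdp", "megatron", "veomni", "sglang", "vllm", "rollout", "trainer", "ci",
    "training_utils", "recipe", "hardware", "deployment", "ray", "worker",
    "single_controller", "misc", "perf", "model", "algo", "env", "tool",
    "ckpt", "doc", "data", "cfg", "reward",
}
TYPES = {"feat", "fix", "refactor", "chore", "test"}

def check_pr_title(title: str) -> bool:
    # Optional [BREAKING] prefix, then [mod1, mod2] type: description
    m = re.match(r"^(\[BREAKING\])?\[([^\]]+)\]\s*(\w+):\s*(.+)$", title)
    if not m:
        return False
    modules = [s.strip() for s in m.group(2).split(",")]
    return all(mod in MODULES for mod in modules) and m.group(3) in TYPES

print(check_pr_title("[BREAKING][fsdp, megatron] feat: dynamic batching"))  # True
```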
1 parent b53f0f1 commit e46f700

File tree

9 files changed: +52 −48 lines changed


.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml

Lines changed: 0 additions & 2 deletions
```diff
@@ -116,7 +116,6 @@ jobs:
       NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
       HF_ENDPOINT: "https://hf-mirror.com"
       HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
-      RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES: "1"
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
@@ -152,7 +151,6 @@ jobs:
       NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
       HF_ENDPOINT: "https://hf-mirror.com"
       HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
-      RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES: "1"
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
```

docs/workers/trtllm_worker.rst

Lines changed: 0 additions & 3 deletions
```diff
@@ -36,9 +36,6 @@ Install verl with TensorRT-LLM:
     unset "$v"
   done

-  # Required for IPC UUID detection
-  export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
-
 Using TensorRT-LLM as the Rollout Engine for GRPO
 -------------------------------------------------
```

examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh

Lines changed: 0 additions & 1 deletion
```diff
@@ -7,7 +7,6 @@ for v in $(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print $1}'); do
   unset "$v"
 done

-export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
 export RAY_DEDUP_LOGS=0

 # -----
```

examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh

Lines changed: 0 additions & 1 deletion
```diff
@@ -5,7 +5,6 @@ for v in $(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print $1}'); do
   unset "$v"
 done

-export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
 export RAY_DEDUP_LOGS=0

 # -----
```

tests/special_sanity/check_license.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@

 license_head_bytedance = "Copyright 2024 Bytedance Ltd. and/or its affiliates"
 license_head_bytedance_25 = "Copyright 2025 Bytedance Ltd. and/or its affiliates"
+license_head_bytedance_26 = "Copyright 2026 Bytedance Ltd. and/or its affiliates"
 # Add custom license headers below
 license_head_prime = "Copyright 2024 PRIME team and/or its affiliates"
 license_head_individual = "Copyright 2025 Individual Contributor:"
@@ -29,6 +30,7 @@
 license_headers = [
     license_head_bytedance,
     license_head_bytedance_25,
+    license_head_bytedance_26,
     license_head_prime,
     license_head_individual,
     license_head_sglang,
```
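The allowlist above feeds a header check; a simplified, self-contained sketch of how such a check works (the real `check_license.py` logic may differ in details such as how much of the file it scans):

```python
# Simplified sketch of a license-header allowlist check;
# the actual verl check_license.py may differ.
license_headers = [
    "Copyright 2024 Bytedance Ltd. and/or its affiliates",
    "Copyright 2025 Bytedance Ltd. and/or its affiliates",
    "Copyright 2026 Bytedance Ltd. and/or its affiliates",  # entry added by this PR
]

def has_valid_license(source: str) -> bool:
    # A file passes if any known header string appears near the top.
    head = source[:500]
    return any(h in head for h in license_headers)

print(has_valid_license("# Copyright 2026 Bytedance Ltd. and/or its affiliates\n"))  # True
```

Without the 2026 entry, files stamped with the new copyright year (such as the two files retouched in this PR) would fail the sanity check.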

verl/experimental/reward_loop/reward_loop.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -218,6 +218,23 @@ async def compute_score_disrm(self, data: DataProto) -> dict:
             }
             output = await self._post_request(payloads, "v1/embeddings")
             rm_score = output["data"][-1]["embedding"][-1]
+        elif engine_name == "trtllm":
+            # TODO: remove this once TRT-LLM switches to TorchSampler
+            raise ValueError("TensorRT-LLM backend does not support reward models currently.")
+
+            payloads = {
+                "model": model_name,
+                "prompt": disrm_prompt,
+                "return_context_logits": True,
+            }
+            output = await self._post_request(payloads, "v1/completions")
+            rm_score = output["choices"][0]["context_logits"]
+            assert isinstance(rm_score, list) and len(rm_score) > 0, (
+                "TensorRT-LLM OpenAI server response for reward score is not in the expected format."
+            )
+
+            rm_score = float(rm_score[0][0])
+            logger.debug(f"rm score: {rm_score}")
         else:
             raise NotImplementedError(f"RewardLoopManager does not support {engine_name}")
```
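The (currently unreachable) trtllm branch above parses the `v1/completions` response to extract a scalar reward. The parsing logic can be exercised against a mocked response; the field names follow the diff, while the mock values here are invented for illustration:

```python
# Mocked TRT-LLM OpenAI-server response; the shape follows the diff above,
# the numeric values are invented for illustration.
output = {"choices": [{"context_logits": [[0.73, -1.2], [0.1, 0.2]]}]}

rm_score = output["choices"][0]["context_logits"]
assert isinstance(rm_score, list) and len(rm_score) > 0, (
    "TensorRT-LLM OpenAI server response for reward score is not in the expected format."
)
# The scalar reward is taken from the first logit of the first position.
rm_score = float(rm_score[0][0])
print(rm_score)  # 0.73
```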

verl/workers/rollout/trtllm_rollout/trtllm_async_rollout.md

Lines changed: 0 additions & 3 deletions
````diff
@@ -23,9 +23,6 @@ Note that using the TRT-LLM rollout requires setting the following environment v
 for v in $(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print $1}'); do
   unset "$v"
 done
-
-# Required for IPC UUID detection
-export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
 ```

 ## 2. Architecture Design
````

verl/workers/rollout/trtllm_rollout/trtllm_async_server.py

Lines changed: 26 additions & 10 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2026 Bytedance Ltd. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,14 +17,14 @@
 from typing import Any, Optional

 import ray
+import torch
 from omegaconf import DictConfig
 from ray.actor import ActorHandle
 from ray.util import placement_group_table
 from ray.util.placement_group import PlacementGroup

 from verl.single_controller.ray import RayClassWithInitArgs, SubRayResourcePool
 from verl.utils.config import omega_conf_to_dataclass
-from verl.utils.device import is_cuda_available
 from verl.utils.net_utils import is_valid_ipv6_address
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
@@ -42,6 +42,7 @@ class TRTLLMHttpServer:
     Args:
         config (DictConfig): full config.
         model_config (HFModelConfig): model config.
+        is_reward_model (bool): whether this is a reward model.
         rollout_mode (RolloutMode): rollout mode.
         workers (list[ActorHandle]): list of rollout workers.
         replica_rank (int): replica rank, a replica may contain multiple nodes.
@@ -54,6 +55,7 @@ def __init__(
         self,
         config: RolloutConfig,
         model_config: HFModelConfig,
+        is_reward_model: bool,
         rollout_mode: RolloutMode,
         workers: list[ActorHandle],
         replica_rank: int,
@@ -62,10 +64,11 @@ def __init__(
         bundle_indices: list[list[int]] = None,
     ):
         os.environ["TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL"] = "1"
-        assert is_cuda_available, "TRTLLM http server should run on GPU node"
+        assert torch.cuda.is_available(), "TRTLLM http server should run on GPU node"

         self.config: RolloutConfig = omega_conf_to_dataclass(config)
         self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
+        self.is_reward_model = is_reward_model
         self.config.max_model_len = self.config.prompt_length + self.config.response_length
         self.rollout_mode = rollout_mode
         self.workers = workers
@@ -82,7 +85,7 @@ def __init__(
         self._server_address = ray.util.get_node_ip_address().strip("[]")
         self._server_port = None

-        logger.info(f"TRTLLMHttpServer, replica_rank: {self.replica_rank}, ")
+        logger.info(f"TRTLLMHttpServer, replica_rank: {self.replica_rank}")

         self.sampling_args = {
             "detokenize": False,
@@ -107,11 +110,6 @@ async def launch_server(self):
             enable_block_reuse=True,
             free_gpu_memory_fraction=self.config.gpu_memory_utilization,
         )
-        cuda_graph_config = CudaGraphConfig(
-            enable_padding=True,
-            batch_sizes=self.config.cudagraph_capture_sizes,
-            max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs,
-        )

         per_worker_gpu_share = 1.0 / self.max_colocate_count

@@ -121,7 +119,6 @@ async def launch_server(self):
             "orchestrator_type": "ray",
             "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
             "kv_cache_config": kv_cache_config,
-            "cuda_graph_config": cuda_graph_config,
             "max_seq_len": self.config.max_model_len,
             "max_batch_size": self.config.max_num_seqs,
             "max_num_tokens": self.config.max_num_batched_tokens,
@@ -136,6 +133,24 @@ async def launch_server(self):
             **engine_kwargs,
         }

+        if self.is_reward_model:
+            llm_kwargs.update(
+                {
+                    "cuda_graph_config": None,
+                    "disable_overlap_scheduler": True,
+                }
+            )
+        else:
+            llm_kwargs.update(
+                {
+                    "cuda_graph_config": CudaGraphConfig(
+                        enable_padding=True,
+                        batch_sizes=self.config.cudagraph_capture_sizes,
+                        max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs,
+                    )
+                }
+            )
+
         self.llm = await AsyncLLM(**llm_kwargs)

         trtllm_server = OpenAIServer(
@@ -313,6 +328,7 @@ async def launch_servers(self):
         ).remote(
             config=self.config,
             model_config=self.model_config,
+            is_reward_model=self.is_reward_model,
             rollout_mode=self.rollout_mode,
             workers=self.workers,
             replica_rank=self.replica_rank,
```
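The `launch_server()` change above moves `cuda_graph_config` out of the shared kwargs and into a reward-model-conditional branch. The control flow can be isolated as a plain-dict sketch; `CudaGraphConfig` and `AsyncLLM` are TRT-LLM classes, so a stand-in dict replaces the config object here:

```python
# Sketch of the conditional engine-kwargs logic from launch_server().
# CudaGraphConfig is a TRT-LLM class; a plain dict stands in for it here.
def build_llm_kwargs(is_reward_model: bool, capture_sizes, max_num_seqs: int) -> dict:
    llm_kwargs = {"max_batch_size": max_num_seqs}
    if is_reward_model:
        # Reward models skip CUDA graphs and the overlap scheduler.
        llm_kwargs.update({"cuda_graph_config": None, "disable_overlap_scheduler": True})
    else:
        llm_kwargs.update({
            "cuda_graph_config": {
                "enable_padding": True,
                "batch_sizes": capture_sizes,
                # Capture sizes take priority; 0 defers max batch size to them.
                "max_batch_size": 0 if capture_sizes else max_num_seqs,
            }
        })
    return llm_kwargs
```

Disabling CUDA graphs for the reward-model path avoids graph capture for a workload that only needs context logits, at the cost of some per-step launch overhead.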

verl/workers/rollout/trtllm_rollout/trtllm_rollout.py

Lines changed: 7 additions & 28 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2026 Bytedance Ltd. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.multiprocessing.reductions import reduce_tensor

+from verl.utils.memory_utils import aggressive_empty_cache
 from verl.utils.net_utils import is_valid_ipv6_address
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.base import BaseRollout
@@ -46,21 +47,6 @@
 DEFAULT_MAX_WAIT_TIME = 300.0


-def device_id_to_physical_device_id(id: int) -> int:
-    """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES."""
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
-        try:
-            physical_device_id = int(device_ids[id])
-            return physical_device_id
-        except (ValueError, IndexError) as err:
-            raise RuntimeError(
-                f"Failed to convert logical device ID {id} to physical device ID. Available devices are: {device_ids}."
-            ) from err
-    else:
-        return id
-
-
 @contextlib.contextmanager
 def nvml_context():
     """Context manager for NVML initialization and shutdown.
@@ -95,24 +81,19 @@ def get_device_uuid(id: int) -> str:
     except pynvml.NVMLError as e:
         raise RuntimeError(f"Failed to initialize NVML: {e}") from e

-    # The process has visibility to all GPUs within the TP group
-    global_device_idx = device_id_to_physical_device_id(id)
-
     # Get the device handle and UUID
     try:
-        handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
+        handle = pynvml.nvmlDeviceGetHandleByIndex(id)
         uuid = pynvml.nvmlDeviceGetUUID(handle)
         # Ensure the UUID is returned as a string, not bytes
         if isinstance(uuid, bytes):
             return uuid.decode("utf-8")
         elif isinstance(uuid, str):
             return uuid
         else:
-            raise RuntimeError(
-                f"Unexpected UUID type: {type(uuid)} for device {id} (global index: {global_device_idx})"
-            )
+            raise RuntimeError(f"Unexpected UUID type: {type(uuid)} for device {id} (global index: {id})")
     except pynvml.NVMLError as e:
-        raise RuntimeError(f"Failed to get device UUID for device {id} (global index: {global_device_idx}): {e}") from e
+        raise RuntimeError(f"Failed to get device UUID for device {id} (global index: {id}): {e}") from e


 async def _read_async_response(resp: aiohttp.ClientResponse) -> dict[str, Any]:
@@ -330,8 +311,6 @@ def __init__(
         assert self.replica_rank >= 0, "replica_rank is not set"
         assert self.is_leader_rank is not None, "is_leader_rank is not set"

-        print(f"ServerAdapter, replica_rank: {self.replica_rank}, is_leader_rank: {self.is_leader_rank}")
-
         self.node_ip = ray.util.get_node_ip_address().strip("[]")

     async def _init_server_adapter(self):
@@ -412,8 +391,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         try:
             device_uuid = get_device_uuid(self.gpu_id)
         except Exception as e:
-            logger.error(f"Failed to get device UUID: {e}")
-            logger.error("Did you miss to set RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 before ray start?")
+            logger.error(f"Failed to get device UUID in update_weights(): {e}")
             device_uuid = None
             raise e

@@ -447,3 +425,4 @@ async def flush():
         # Finalize update weights
         await self._adapter.update_weights(None)
         await asyncio.to_thread(dist.barrier, group=self.hybrid_device_mesh["exclude_dp"].get_group())
+        aggressive_empty_cache(force_sync=False)
```
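For reference, the deleted `device_id_to_physical_device_id` helper performed the CUDA_VISIBLE_DEVICES remapping shown below; after this PR, the worker-local index is passed to NVML directly, so the remapping (and the `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1` requirement) goes away. A standalone sketch of the old mapping, with the environment passed in explicitly so it can run without a GPU:

```python
import os

def device_id_to_physical_device_id(logical_id: int, env=os.environ) -> int:
    """Sketch of the deleted helper: translate a logical CUDA device index
    to a physical index using CUDA_VISIBLE_DEVICES."""
    if "CUDA_VISIBLE_DEVICES" in env:
        device_ids = env["CUDA_VISIBLE_DEVICES"].split(",")
        try:
            return int(device_ids[logical_id])
        except (ValueError, IndexError) as err:
            raise RuntimeError(
                f"Failed to convert logical device ID {logical_id}; available: {device_ids}."
            ) from err
    return logical_id

# With Ray restricting a worker to GPU 3, logical index 0 maps to physical 3:
print(device_id_to_physical_device_id(0, {"CUDA_VISIBLE_DEVICES": "3"}))  # 3
```

This mapping only works if Ray has not already remapped visibility, which is exactly why the old code needed `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1`; dropping the helper removes that coupling.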
