
Commit 9594aea

Merge branch 'openrlhf' of github.com:vllm-project/vllm-project.github.io into openrlhf

2 parents 4ebcb7e + 62ca284
File tree: 1 file changed (+5 −3 lines)

_posts/2025-04-23-openrlhf-vllm.md (5 additions, 3 deletions)
@@ -19,7 +19,7 @@ To strike a balance between performance and usability in RLHF frameworks, [OpenR
 
 **ZeRO-3 with [HuggingFace Transformers](https://github.com/huggingface/transformers)**, a memory optimization approach from [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), empowers OpenRLHF to train large models without requiring heavyweight frameworks like Megatron. This seamless integration with HuggingFace allows for simple loading and fine-tuning of pre-trained models.
 
-Together, Ray, vLLM, ZeRO-3, and HuggingFace Transformers create a cutting-edge yet streamlined solution for accelerating RLHF training. The architecture has also influenced other frameworks such as [veRL](https://github.com/volcengine/verl), which adopt similar paradigms for scalable and efficient RLHF training. OpenRLHF is also the first open-source RLHF framework developed based on Ray and vLLM, and has been used by Google, Bytedance, Alibaba, Meituan, Berkeley Starling Team etc.
+Together, Ray, vLLM, ZeRO-3, and HuggingFace Transformers create a cutting-edge yet streamlined solution for accelerating RLHF training. The architecture has also influenced other frameworks such as [veRL](https://github.com/volcengine/verl), which adopt similar paradigms for scalable and efficient RLHF training. OpenRLHF is also the first open-source RLHF framework developed based on Ray, vLLM and ZeRO-3, and has been used by Google, Bytedance, Alibaba, Meituan, Berkeley Starling Team etc.
 
 <img align="center" src="/assets/figures/openrlhf-vllm/ray.png" alt="Ray and vLLM in OpenRLHF" width="90%" height="90%">
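
As background for the ZeRO-3 mention above, here is a minimal sketch of how ZeRO-3 is typically enabled through HuggingFace `TrainingArguments` with a DeepSpeed config dict. The values are illustrative assumptions, not OpenRLHF's actual training settings.

```python
# Minimal sketch (assumed values): enabling DeepSpeed ZeRO-3 via HuggingFace Transformers.
from transformers import TrainingArguments

ds_config = {
    # Stage 3 shards parameters, gradients, and optimizer states across GPUs,
    # which is what lets large models be trained without Megatron-style parallelism.
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

# The same dict (or a path to a JSON file) is passed straight to the HF Trainer stack.
training_args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=1,
    bf16=True,
    deepspeed=ds_config,
)
```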

@@ -30,6 +30,7 @@ As illustrated above, OpenRLHF uses [Ray’s Placement Group API](https://docs.r
 OpenRLHF and vLLM provide a clean and efficient set of APIs to simplify interaction within RLHF pipelines. By implementing a custom `WorkerExtension` class, users can handle weight synchronization between training and inference components. The environment variables `VLLM_RAY_PER_WORKER_GPUS` and `VLLM_RAY_BUNDLE_INDICES` allow fine-grained GPU resource allocation per worker, enabling hybrid engine configurations where multiple components share a GPU group:
 
 ```python
+# rlhf_utils.py
 class ColocateWorkerExtension:
     """
     Extension class for vLLM workers to handle weight synchronization.
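
The two environment variables named above are read by vLLM's Ray executor when the engine spins up its workers. A hedged sketch of how they might be set before constructing a colocated engine follows; the helper name `configure_colocated_worker_env` and the 0.4 GPU fraction are illustrative assumptions, not code from this commit.

```python
# Illustrative sketch: preparing the environment for a colocated vLLM engine.
import os

def configure_colocated_worker_env(bundle_indices, gpu_fraction=0.4):
    # Fraction of a GPU each vLLM Ray worker requests, so an engine and the
    # training actors (or two engine instances) can share the same physical GPUs.
    os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(gpu_fraction)
    # Comma-separated placement-group bundle indices this engine should occupy.
    os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))

configure_colocated_worker_env([0, 1])
```
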
@@ -55,6 +56,7 @@ class ColocateWorkerExtension:
         self.model_runner.model.load_weights(weights=weights)
         torch.cuda.synchronize()
 
+# main.py
 class MyLLM(LLM):
     """
     Custom LLM class to handle GPU resource allocation and bundle indices.
@@ -69,7 +71,7 @@ class MyLLM(LLM):
         super().__init__(*args, **kwargs)
 
 
-# Create placement group for GPU allocation
+# Create Ray's placement group for GPU allocation
 pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
 ray.get(pg.ready())
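
Taken together with the next hunk, the placement group above reserves four single-GPU bundles that are split into two pairs (`bundle_indices` of `[0, 1]` and `[2, 3]`), one per tensor-parallel-2 engine; with `gpu_memory_utilization=0.4`, each engine leaves GPU memory free for the colocated training components described earlier.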

@@ -86,7 +88,7 @@ for bundle_indices in [[0, 1], [2, 3]]:
         tensor_parallel_size=2,
         distributed_executor_backend="ray",
         gpu_memory_utilization=0.4,
-        worker_extension_cls="__main__.ColocateWorkerExtension",
+        worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
         bundle_indices=bundle_indices
     )
     inference_engines.append(llm)
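
The `WorkerExtension` paragraph earlier describes weight synchronization between training and inference, which the hunks above only show from the worker side (`load_weights` followed by `torch.cuda.synchronize()`). Below is a hedged sketch of how the driver might trigger that path on the engines created above; the method name `update_weights_from_ipc_handles` and the `ipc_handles` payload are assumptions and do not appear in this diff.

```python
# Hedged sketch: pushing updated policy weights from the trainer into each
# colocated vLLM engine. The extension method name and the ipc_handles payload
# are assumptions, not part of this commit.
ipc_handles = {}  # placeholder: per-device CUDA IPC handles gathered from the training actors

for llm in inference_engines:
    # Broadcasts the call to every worker of this engine; on the worker it
    # resolves to a ColocateWorkerExtension method that ends in
    # model_runner.model.load_weights(...).
    llm.collective_rpc("update_weights_from_ipc_handles", args=(ipc_handles,))
```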
