---
layout: post
title: "Accelerating RLHF with vLLM Ray Executor (OpenRLHF)"
author: "The OpenRLHF Team"
image: /assets/figures/openrlhf-vllm/ray.png
thumbnail-img: /assets/figures/openrlhf-vllm/ray.png
share-img: /assets/figures/openrlhf-vllm/ray.png
---

As demand grows for training reasoning-focused large language models (LLMs), Reinforcement Learning from Human Feedback (RLHF) has become a pivotal technique. However, traditional RLHF training pipelines, especially those involving Proximal Policy Optimization (PPO), often face significant computational bottlenecks, with long chain-of-thought generation consuming up to 90% of the total training time.

## Design Philosophy

To address these challenges, OpenRLHF is designed as a user-friendly, high-performance RLHF framework that integrates key technologies such as Ray, vLLM, ZeRO-3, and AutoTP:

**Ray** serves as the backbone for distributed programming within OpenRLHF. Its robust scheduling and orchestration capabilities make it ideal for managing the complex data flows and computations inherent in RLHF training, including the distribution of reward models across multiple nodes.
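
As a minimal sketch of this pattern (not OpenRLHF's actual code), the snippet below uses Ray actors to spread reward-model scoring across GPUs; the model name and scoring logic are placeholders for illustration.

```python
import ray
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

@ray.remote(num_gpus=1)
class RewardModelWorker:
    """One reward-model replica per GPU; Ray handles placement and scheduling."""
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).cuda().eval()

    def score(self, prompts, responses):
        # Score each (prompt, response) pair with the reward model
        texts = [p + r for p, r in zip(prompts, responses)]
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to("cuda")
        with torch.no_grad():
            return self.model(**inputs).logits.squeeze(-1).tolist()

ray.init()
# Two replicas; on a multi-node cluster Ray may place them on different nodes.
workers = [RewardModelWorker.remote("my-org/my-reward-model") for _ in range(2)]
print(ray.get(workers[0].score.remote(["What is RLHF?"], [" A technique that ..."])))
```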

**vLLM with Ray Executor and AutoTP** is central to accelerating inference within OpenRLHF. It natively supports Ray executors and integrates with Hugging Face Transformers, enabling efficient weight updates through AutoTP. This combination delivers high-throughput, memory-efficient serving of large language models.
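
As a rough illustration of this setup, the sketch below launches a tensor-parallel vLLM engine with the Ray executor backend; the model name and parameter values are illustrative, not OpenRLHF's configuration.

```python
from vllm import LLM, SamplingParams

# Shard the model across 2 GPUs and let Ray manage the tensor-parallel workers.
llm = LLM(
    model="facebook/opt-6.7b",            # placeholder model
    tensor_parallel_size=2,               # tensor parallelism across 2 GPUs
    distributed_executor_backend="ray",   # use the Ray executor instead of multiprocessing
    gpu_memory_utilization=0.85,
)

outputs = llm.generate(
    ["Explain PPO in one sentence."],
    SamplingParams(temperature=0.8, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```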

**ZeRO-3**, a memory optimization strategy from DeepSpeed, enables OpenRLHF to train large-scale models without the need for complex frameworks like Megatron. This allows for seamless integration with Hugging Face Transformers, facilitating straightforward loading and fine-tuning of pre-trained models.
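
For reference, the sketch below shows the general shape of wrapping a Hugging Face model with DeepSpeed ZeRO-3 (run via the `deepspeed` launcher); the config values are illustrative, not OpenRLHF's defaults.

```python
import deepspeed
from transformers import AutoModelForCausalLM

# ZeRO-3 partitions parameters, gradients, and optimizer states across data-parallel ranks.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-6}},
}

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
# In the training loop: engine.backward(loss) followed by engine.step().
```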

By combining Ray, vLLM, ZeRO-3, and Hugging Face Transformers, OpenRLHF offers a leading solution for accelerating RLHF training. This architecture has influenced other frameworks, such as veRL, which adopt a similar paradigm for efficient and scalable RLHF training.

<img align="center" src="/assets/figures/openrlhf-vllm/ray.png" alt="Ray and vLLM in OpenRLHF" width="90%" height="90%">

As illustrated in the figure, OpenRLHF utilizes Ray's placement group API to flexibly schedule various RLHF components, including the vLLM engine, Actor, Critic, Reference, and Reward models. While these models are depicted separately, they can be co-located within shared placement groups to optimize resource utilization. For instance, all modules can share the same GPU group in a Hybrid Engine configuration, or specific components like the Actor and Critic can be assigned to the same GPU group. Weight synchronization between the Actor and the vLLM engine is achieved through high-performance communication mechanisms such as NVIDIA's NCCL or CUDA IPC memory copying, particularly in Hybrid Engine setups.
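
To make the co-location idea concrete, here is a minimal sketch (separate from the complete example in the next section) in which two roles share one GPU bundle of a placement group by each requesting a fraction of the GPU; the roles are dummy classes for illustration.

```python
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

# One bundle per physical GPU.
pg = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg.ready())

@ray.remote
class Role:
    """Stand-in for an RLHF component (e.g., Actor trainer or rollout engine)."""
    def __init__(self, name: str):
        self.name = name

# Pin two roles to the same bundle (GPU 0); each claims 0.4 GPU so they can co-exist.
strategy = PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
trainer = Role.options(num_gpus=0.4, scheduling_strategy=strategy).remote("actor-train")
rollout = Role.options(num_gpus=0.4, scheduling_strategy=strategy).remote("actor-rollout")
```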

## Implementing RLHF Acceleration with vLLM Ray Executor

vLLM provides examples demonstrating how to accelerate RLHF training using Ray. By defining a custom worker extension class (passed to vLLM via `worker_extension_cls`), users can implement the logic for weight synchronization between training and inference components. The `VLLM_RAY_PER_WORKER_GPUS` environment variable controls how much of a GPU each worker is allocated, enabling configurations such as hybrid engines where multiple components share the same GPU group.

[An example](https://docs.vllm.ai/en/latest/getting_started/examples/rlhf_colocate.html) setup involves initializing Ray with a specified number of GPUs, creating a placement group for resource allocation, and defining training actors and inference engines. The training actors handle model initialization and weight updates, while the inference engines serve the models using vLLM. Weight synchronization between these components is achieved through inter-process communication mechanisms like CUDA IPC or NCCL, ensuring consistency across the training pipeline.

```python
import os
import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm import LLM
from transformers import AutoModelForCausalLM

class ColocateWorkerExtension:
    """
    Extension class for vLLM workers to handle weight synchronization.
    This class ensures compatibility with both vLLM V0 and V1.
    """
    def report_device_id(self) -> str:
        """Report the unique device ID for this worker"""
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        """Update model weights using IPC handles"""
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            list_args[6] = device_id  # Update device ID for the current process
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()
    def check_weights_changed(self):
        """Verify that weights were synchronized (the demo's training weights are all zeros)"""
        return all(torch.allclose(p, torch.zeros_like(p))
                   for p in self.model_runner.model.parameters())

class MyLLM(LLM):
    """
    Custom LLM class to handle GPU resource allocation and bundle indices.
    This ensures proper GPU utilization and placement group management.
    """
    def __init__(self, *args, bundle_indices: list, **kwargs):
        # Prevent Ray from manipulating CUDA_VISIBLE_DEVICES at the top level
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        # Configure GPU utilization per worker
        os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
        super().__init__(*args, **kwargs)

class TrainingActor:
    """
    Actor class for model training.
    Handles model initialization and weight synchronization.
    """
    def __init__(self):
        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        self.model.to("cuda:0")
        # Initialize weights to zero for demonstration
        for p in self.model.parameters():
            p.data.zero_()
        torch.cuda.synchronize()
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(0)

    def report_device_id(self) -> str:
        return self.device_uuid

    def get_weight_ipc_handles(self):
        """Get IPC handles for model weights"""
        from torch.multiprocessing.reductions import reduce_tensor
        return {self.device_uuid: {name: reduce_tensor(p.detach())
                                   for name, p in self.model.named_parameters()}}

# Initialize Ray with 4 GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.init()

# Create placement group for GPU allocation
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
ray.get(pg.ready())

# Create training actors
training_actors = []
for i in range(4):
    actor = ray.remote(
        num_gpus=0.4,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_bundle_index=i
        )
    )(TrainingActor).remote()
    training_actors.append(actor)

# Create inference engines
inference_engines = []
for bundle_indices in [[0, 1], [2, 3]]:
    llm = ray.remote(
        num_gpus=0,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg
        )
    )(MyLLM).remote(
        model="facebook/opt-125m",
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
        gpu_memory_utilization=0.4,
        worker_extension_cls="__main__.ColocateWorkerExtension",
        bundle_indices=bundle_indices
    )
    inference_engines.append(llm)

# Collect device IDs for verification
training_device_ids = [ray.get(actor.report_device_id.remote())
                       for actor in training_actors]
inference_device_ids = [ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
                        for llm in inference_engines]

# Verify device placement
assert training_device_ids[:2] == inference_device_ids[0]
assert training_device_ids[2:] == inference_device_ids[1]

# Synchronize weights
ipc_handles = {}
for actor in training_actors:
    ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))

for llm in inference_engines:
    ray.get(llm.collective_rpc.remote("update_weights_from_ipc_handles",
                                      args=(ipc_handles,)))
    assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
```