# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformers model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPUs 1–2.

The example performs the following steps:

* Load the training model on GPU 0.
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
  and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
  to the inference engine through a Ray collective RPC group. Note that
  for demonstration purposes we simply zero out the weights.

For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF

This example assumes a single-node cluster with three GPUs, but Ray also
supports multi-node clusters. vLLM expects that the GPUs are used only for
vLLM workloads; residual GPU activity interferes with vLLM's memory
profiling and causes unexpected behavior.
"""

import json
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from torchao.core.config import config_to_dict
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
)
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.utils.network_utils import get_ip, get_open_port


class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
        # so that vLLM can manage its own device placement within the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        super().__init__(*args, **kwargs)


# Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")

# Restrict the visible devices to GPUs 1 and 2, then initialize Ray so that
# the vLLM inference engine is placed on those GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()
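# Because CUDA_VISIBLE_DEVICES is set before ray.init(), Ray only detects
# GPUs 1 and 2, so the placement group below reserves exactly those two
# devices for the inference engine.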

# Create a placement group that reserves GPUs 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)
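# With tensor_parallel_size=2 and the Ray distributed executor, vLLM is
# expected to spawn one worker per GPU bundle captured by this placement
# group, so both inference workers land on the reserved GPUs.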

# Generate a torchao quantization config for the RL rollout.
# See https://github.com/vllm-project/vllm/pull/23014 for instructions on
# using serialized config files instead of passing a JSON string around.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

json_str = json.dumps(config_to_dict(config))

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model="facebook/opt-125m",
    hf_overrides={"quantization_config_dict_json": json_str},
    enforce_eager=True,
    worker_extension_cls="rlhf_utils.WorkerExtension",
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)
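# num_gpus=0 for the actor itself: the MyLLM driver process holds no GPU;
# the workers spawned by vLLM's Ray executor claim the GPUs reserved in the
# placement group.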

# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0)
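# temperature=0 selects greedy decoding, which makes the outputs before and
# after the weight update directly comparable.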

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# Set up the communication channel between the training process and the
# inference engine.
master_address = get_ip()
master_port = get_open_port()

handle = llm.collective_rpc.remote(
    "init_weight_update_group", args=(master_address, master_port, 1, 3)
)
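# Rank layout for the weight-update group: the training process is rank 0,
# and the two tensor-parallel vLLM workers join at rank offset 1, for a
# world size of 3.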

model_update_group = stateless_init_process_group(
    master_address, master_port, 0, 3, torch.device("cuda:0")
)
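# Only wait on the RPC after the trainer has joined the group: both sides
# block until all three ranks have connected.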
ray.get(handle)

# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop, the weights would be updated with gradients
# from an RL objective such as PPO, driven by a reward model.
for name, p in train_model.named_parameters():
    p.data.zero_()

# Synchronize the updated weights to the inference engine.
for name, p in train_model.named_parameters():
    dtype_name = str(p.dtype).split(".")[-1]
    handle = llm.collective_rpc.remote(
        "update_weight", args=(name, dtype_name, p.shape)
    )
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)
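# Protocol per parameter: the collective_rpc call tells every vLLM worker the
# parameter's name, dtype, and shape; the broadcast from rank 0 (the trainer)
# then transfers the tensor data, and ray.get(handle) waits until the workers
# have loaded it.
#
# For reference, the worker-side ``update_weight`` in
# ``rlhf_utils.WorkerExtension`` is assumed to roughly mirror this protocol
# (a sketch under that assumption, not the verbatim helper):
#
#     def update_weight(self, name, dtype_name, shape):
#         dtype = getattr(torch, dtype_name)
#         weight = torch.empty(shape, dtype=dtype, device="cuda")
#         self.model_update_group.broadcast(
#             weight, src=0, stream=torch.cuda.current_stream()
#         )
#         self.model_runner.model.load_weights(weights=[(name, weight)])
#         del weight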

# Verify that the inference weights have been updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
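# check_weights_changed is implemented by rlhf_utils.WorkerExtension (not
# shown here); each worker presumably verifies that its parameters are now
# all zeros.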

# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)