
Commit da94c7c

Move online quantization to model.load_weights (#26327)
Signed-off-by: Jerry Zhang <[email protected]>
1 parent 1395461 commit da94c7c
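
The new example below drives online quantization by serializing a torchao config and passing it to the engine through hf_overrides; with this change, the actual quantization of the loaded weights happens inside model.load_weights. As a minimal sketch of that entry point, drawn from the example added in this commit (the model name is illustrative only, not part of the change):

import json

from torchao.core.config import config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

from vllm import LLM

# Serialize a torchao config and hand it to vLLM; the checkpoint weights are
# quantized on the fly while model.load_weights runs.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
json_str = json.dumps(config_to_dict(config))

llm = LLM(
    model="facebook/opt-125m",  # illustrative model choice
    hf_overrides={"quantization_config_dict_json": json_str},
)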

6 files changed: +314 -113 lines changed

examples/offline_inference/rlhf.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def __init__(self, *args, **kwargs):
 
 # Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
 # Learn more about Ray placement groups:
-# https://docs.ray.io/en/latest/placement-groups.html
+# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
 pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
 ray.get(pg_inference.ready())
 scheduling_inference = PlacementGroupSchedulingStrategy(
Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPUs 1–2.

The example performs the following steps:

* Load the training model on GPU 0.
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
  and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
  to the inference engine through a Ray collective RPC group. Note that
  for demonstration purposes we simply zero out the weights.

For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF

This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects that the GPUs are used only for
vLLM workloads; residual GPU activity interferes with vLLM memory profiling
and causes unexpected behavior.
"""

import json
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from torchao.core.config import config_to_dict
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
)
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.utils.network_utils import get_ip, get_open_port


class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
        # so that vLLM can manage its own device placement within the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        super().__init__(*args, **kwargs)


# Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")

# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

# Create a placement group that reserves GPUs 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.

# Generate a torchao quantization config for the RL rollout.
# See https://github.com/vllm-project/vllm/pull/23014 for instructions on
# using serialized config files instead of passing around a JSON string.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

json_str = json.dumps(config_to_dict(config))

llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model="facebook/opt-125m",
    hf_overrides={"quantization_config_dict_json": json_str},
    enforce_eager=True,
    worker_extension_cls="rlhf_utils.WorkerExtension",
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)

# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# Set up the communication channel between the training process and the
# inference engine.
master_address = get_ip()
master_port = get_open_port()

handle = llm.collective_rpc.remote(
    "init_weight_update_group", args=(master_address, master_port, 1, 3)
)

model_update_group = stateless_init_process_group(
    master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle)

# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop the weights would be updated using the gradient
# from an RL objective such as PPO on a reward model.
for name, p in train_model.named_parameters():
    p.data.zero_()

# Synchronize the updated weights to the inference engine.
for name, p in train_model.named_parameters():
    dtype_name = str(p.dtype).split(".")[-1]
    handle = llm.collective_rpc.remote(
        "update_weight", args=(name, dtype_name, p.shape)
    )
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)

# Verify that the inference weights have been updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
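
The string RPCs used above ("init_weight_update_group", "update_weight", "check_weights_changed") are implemented by rlhf_utils.WorkerExtension, which is not part of this diff. As orientation only, the worker-side receive path for one weight could look roughly like the sketch below; the attribute names (self.model_update_group, self.model_runner) and the exact signature are assumptions about that helper, not code from this commit.

import torch


class WorkerExtensionSketch:
    """Hypothetical sketch of the worker-side hook reached via collective_rpc."""

    def update_weight(self, name: str, dtype_name: str, shape: tuple) -> None:
        # Allocate a buffer matching the tensor announced by the trainer ...
        dtype = getattr(torch, dtype_name)  # e.g. "float32" -> torch.float32
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        # ... receive the broadcast issued from the training process (rank 0);
        # model_update_group is assumed to be created in init_weight_update_group ...
        self.model_update_group.broadcast(
            weight, src=0, stream=torch.cuda.current_stream()
        )
        # ... and feed the received tensor through vLLM's normal weight-loading path.
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight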

vllm/model_executor/model_loader/default_loader.py

Lines changed: 11 additions & 35 deletions
@@ -22,6 +22,7 @@
     fastsafetensors_weights_iterator,
     filter_duplicate_safetensors_files,
     filter_files_not_needed_for_inference,
+    get_quant_config,
     maybe_download_from_modelscope,
     multi_thread_pt_weights_iterator,
     multi_thread_safetensors_weights_iterator,
@@ -273,42 +274,17 @@ def download_model(self, model_config: ModelConfig) -> None:
         )
 
     def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        if model_config.quantization == "torchao" and torchao_version_at_least(
-            "0.14.0"
-        ):
-            self.load_config.safetensors_load_strategy = "torchao"
-        weights_to_load = {name for name, _ in model.named_parameters()}
-
-        # if we don't have `model.weight_metadata_and_attr_saved` defined and
-        # set to True, it means that this is either offline quantization case
-        # or the first run of online quantization
-        # see online_quantization.py for detailed notes
-        offline_quantization_or_first_run_of_online_quantization = not getattr(
-            model, "weight_metadata_and_attr_saved", False
-        )
+        if model_config.quantization == "torchao":
+            quant_config = get_quant_config(model_config, self.load_config)
+            if (
+                hasattr(quant_config, "is_checkpoint_torchao_serialized")
+                and quant_config.is_checkpoint_torchao_serialized
+                and torchao_version_at_least("0.14.0")
+            ):
+                self.load_config.safetensors_load_strategy = "torchao"
 
-        if model_config.quantization is None:
-            # model is not quantized
-            loaded_weights = model.load_weights(
-                self.get_all_weights(model_config, model)
-            )
-        elif offline_quantization_or_first_run_of_online_quantization:
-            # case 1: offline quantized checkpoint
-            # case 2: Step I1 first run of weight loading with
-            # online quantization
-            # see online_quantization.py for detailed notes
-            loaded_weights = model.load_weights(
-                self.get_all_weights(model_config, model)
-            )
-        else:
-            # to avoid circular dependency
-            from vllm.model_executor.model_loader.online_quantization import (
-                load_weights_and_online_quantize,
-            )
-
-            # subsequent runs of weight loading with online
-            # quantization
-            loaded_weights = load_weights_and_online_quantize(self, model, model_config)
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
 
         self.counter_after_loading_weights = time.perf_counter()
         logger.info_once(
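
Read without diff markers, the hunk above leaves load_weights in default_loader.py with a single, unconditional load call; the new body begins as follows (the trailing timing and logging lines are unchanged context in the diff, and the explanatory comments here are added for this write-up, not present in the source):

def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    if model_config.quantization == "torchao":
        quant_config = get_quant_config(model_config, self.load_config)
        # Use the torchao safetensors load strategy only for checkpoints that
        # were serialized with torchao, and only with a new-enough torchao.
        if (
            hasattr(quant_config, "is_checkpoint_torchao_serialized")
            and quant_config.is_checkpoint_torchao_serialized
            and torchao_version_at_least("0.14.0")
        ):
            self.load_config.safetensors_load_strategy = "torchao"

    # Online quantization is now handled inside model.load_weights itself,
    # so offline checkpoints and online quantization share this one path.
    weights_to_load = {name for name, _ in model.named_parameters()}
    loaded_weights = model.load_weights(self.get_all_weights(model_config, model))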
