Skip to content

Commit 5c9b807

Browse files
authored
[Core] Add reload_weights RPC method (#20096)
Signed-off-by: 22quinn <[email protected]>
1 parent 14bf19e commit 5c9b807

File tree

5 files changed: +51 additions, -34 deletions

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,16 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
460460
{"load_config": {
461461
"load_format": original_load_format
462462
}})
463-
model_runner_2.load_model() # Load real weights inplace
463+
model_runner_2.reload_weights() # Load real weights inplace
464464
assert str(model_runner.get_model().state_dict()) == str(
465465
model_runner_2.get_model().state_dict())
466466

467467

468+
def test_reload_weights_before_load_model(model_runner):
    """reload_weights() must refuse to run before load_model() has happened."""
    with pytest.raises(AssertionError):
        model_runner.reload_weights()
471+
472+
468473
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
469474
torch.set_default_dtype(torch.float16)
470475
layer_0 = "model.layers.0.self_attn.attn"

vllm/v1/worker/gpu_model_runner.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,17 +1873,9 @@ def load_model(self, eep_scale_up: bool = False) -> None:
18731873
with DeviceMemoryProfiler() as m:
18741874
time_before_load = time.perf_counter()
18751875
model_loader = get_model_loader(self.load_config)
1876-
if not hasattr(self, "model"):
1877-
logger.info("Loading model from scratch...")
1878-
self.model = model_loader.load_model(
1879-
vllm_config=self.vllm_config,
1880-
model_config=self.model_config)
1881-
else:
1882-
logger.info(
1883-
"Model was already initialized. Loading weights inplace..."
1884-
)
1885-
model_loader.load_weights(self.model,
1886-
model_config=self.model_config)
1876+
logger.info("Loading model from scratch...")
1877+
self.model = model_loader.load_model(
1878+
vllm_config=self.vllm_config, model_config=self.model_config)
18871879
if self.lora_config:
18881880
self.model = self.load_lora_model(self.model,
18891881
self.model_config,
@@ -1916,6 +1908,13 @@ def load_model(self, eep_scale_up: bool = False) -> None:
19161908
rank_mapping,
19171909
)
19181910

1911+
def reload_weights(self) -> None:
    """Reload weights into the already-constructed model, in place.

    Requires a prior successful ``load_model()``; raises ``AssertionError``
    otherwise (callers/tests rely on that exception type).
    """
    assert getattr(self, "model", None) is not None, \
        "Cannot reload weights before model is loaded."
    loader = get_model_loader(self.load_config)
    logger.info("Reloading weights inplace...")
    loader.load_weights(self.model, model_config=self.model_config)
1917+
19191918
def save_tensorized_model(
19201919
self,
19211920
tensorizer_config: "TensorizerConfig",

vllm/v1/worker/gpu_worker.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import copy
55
import gc
66
import os
7+
from contextlib import AbstractContextManager, nullcontext
78
from typing import TYPE_CHECKING, Any, Optional
89

910
import torch
@@ -118,6 +119,21 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
118119
buffer.data.copy_(self._sleep_saved_buffers[name].data)
119120
self._sleep_saved_buffers = {}
120121

122+
def _maybe_get_memory_pool_context(self,
123+
tag: str) -> AbstractContextManager:
124+
if self.vllm_config.model_config.enable_sleep_mode:
125+
from vllm.device_allocator.cumem import CuMemAllocator
126+
127+
allocator = CuMemAllocator.get_instance()
128+
if tag == "weights":
129+
assert allocator.get_current_usage() == 0, (
130+
"Sleep mode can only be "
131+
"used for one instance per process.")
132+
context = allocator.use_memory_pool(tag=tag)
133+
else:
134+
context = nullcontext()
135+
return context
136+
121137
def initialize_cache(self, num_gpu_blocks: int,
122138
num_cpu_blocks: int) -> None:
123139
self.cache_config.num_gpu_blocks = num_gpu_blocks
@@ -179,24 +195,17 @@ def init_device(self):
179195
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
180196
# to hijack tensor allocation.
181197
def load_model(self) -> None:
    """Build and load the model, allocating weights inside the sleep-mode
    memory pool when it is enabled."""
    # Elastic-EP scale-up is signalled via an environment flag by the launcher.
    scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
    with self._maybe_get_memory_pool_context(tag="weights"):
        self.model_runner.load_model(eep_scale_up=scale_up)
196201

197202
def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward config overrides to the underlying model runner."""
    self.model_runner.update_config(overrides)
199204

205+
def reload_weights(self) -> None:
    """Reload model weights in place, inside the weights memory pool when
    sleep mode is enabled."""
    with self._maybe_get_memory_pool_context(tag="weights"):
        self.model_runner.reload_weights()
208+
200209
@torch.inference_mode()
201210
def determine_available_memory(self) -> int:
202211
"""Profiles the peak memory usage of the model to determine how much

vllm/v1/worker/tpu_model_runner.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,16 +1174,10 @@ def load_model(self) -> None:
11741174
mesh=self.mesh)
11751175
else:
11761176
model_loader = get_model_loader(self.load_config)
1177-
if not hasattr(self, "model"):
1178-
logger.info("Loading model from scratch...")
1179-
model = model_loader.load_model(
1180-
vllm_config=self.vllm_config,
1181-
model_config=self.model_config)
1182-
else:
1183-
logger.info("Model was already initialized. \
1184-
Loading weights inplace...")
1185-
model_loader.load_weights(
1186-
self.model, model_config=self.model_config)
1177+
logger.info("Loading model from scratch...")
1178+
model = model_loader.load_model(
1179+
vllm_config=self.vllm_config,
1180+
model_config=self.model_config)
11871181
except RuntimeError as e:
11881182
raise RuntimeError(
11891183
f"Unable to load model, a likely reason is the model is "
@@ -1205,6 +1199,13 @@ def load_model(self) -> None:
12051199
self.model = model
12061200
self.sampler = TPUSampler()
12071201

1202+
def reload_weights(self) -> None:
    """Reload weights into the already-constructed TPU model, in place.

    Requires a prior successful ``load_model()``; raises ``AssertionError``
    otherwise.
    """
    assert getattr(self, "model", None) is not None, \
        "Cannot reload weights before model is loaded."
    loader = get_model_loader(self.load_config)
    logger.info("Reloading weights inplace...")
    loader.load_weights(self.model, model_config=self.model_config)
1208+
12081209
@torch.no_grad()
12091210
def _dummy_run(self, num_tokens: int, num_reqs: int,
12101211
num_blocks: int) -> None:

vllm/v1/worker/tpu_worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,9 @@ def load_model(self) -> None:
265265
def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward config overrides to the TPU model runner."""
    self.model_runner.update_config(overrides)
267267

268+
def reload_weights(self) -> None:
    """Forward the weight-reload request to the TPU model runner."""
    self.model_runner.reload_weights()
270+
268271
def compile_or_warm_up_model(self) -> None:
269272
if not self.model_config.enforce_eager:
270273
self.model_runner.capture_model()

Comments (0)