Commit 038e9be

andylolu2 and jeejeelee authored
[LoRA] Much faster startup when LoRA is enabled (#23777)
Signed-off-by: Andy Lo <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
1 parent 68a3491 commit 038e9be

File tree

3 files changed: +33 -13 lines changed
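The speedup comes from how dummy LoRA adapters are handled during warmup. Previously, every `_dummy_run` created dummy LoRAs on entry and destroyed them on exit, so compilation warmup and CUDA graph capture paid that cost once per batch size. This commit threads a `remove_lora` flag through `_dummy_run` and the LoRA context managers so the adapters survive between runs, and a single `maybe_remove_all_loras` call tears them down after the last warmup run. Below is a minimal, self-contained sketch of that pattern; the `Runner` class and its adapter bookkeeping are simplified stand-ins for the real vLLM internals, not the actual implementation.

    from contextlib import contextmanager


    class Runner:
        """Simplified stand-in for the model runner's LoRA bookkeeping."""

        def __init__(self):
            self.adapters = []

        @contextmanager
        def maybe_setup_dummy_adapters(self, enabled: bool, remove: bool = True):
            # Create dummy adapters for the duration of one dummy run.
            if not enabled:
                yield
                return
            if not self.adapters:          # reuse adapters kept from a prior run
                self.adapters = ["dummy-lora"]
            try:
                yield
            finally:
                if remove:                 # old behavior: tear down after every run
                    self.adapters.clear()

        def remove_all_adapters(self):
            self.adapters.clear()


    runner = Runner()
    for size in (8192, 4096, 2048):        # warmup batch sizes, largest first
        with runner.maybe_setup_dummy_adapters(enabled=True, remove=False):
            pass                           # dummy forward pass for `size` would run here
    runner.remove_all_adapters()           # single cleanup once all warmup runs finish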

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 3 deletions

@@ -2213,6 +2213,7 @@ def _dummy_run(
         uniform_decode: bool = False,
         skip_eplb: bool = False,
         is_profile: bool = False,
+        remove_lora: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -2230,6 +2231,7 @@ def _dummy_run(
             uniform_decode: If True, the batch is a uniform decode batch.
             skip_eplb: If True, skip EPLB state update.
             is_profile: If True, this is a profile run.
+            remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
         assert cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
@@ -2317,7 +2319,7 @@ def _dummy_run(
                 attn_metadata[layer_name] = attn_metadata_i

         with self.maybe_dummy_run_with_lora(self.lora_config,
-                                            num_scheduled_tokens):
+                                            num_scheduled_tokens, remove_lora):
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -2708,11 +2710,14 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                 force_attention=force_attention,
                                 uniform_decode=uniform_decode,
-                                skip_eplb=True)
+                                skip_eplb=True,
+                                remove_lora=False)
             self._dummy_run(num_tokens,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             uniform_decode=uniform_decode,
-                            skip_eplb=True)
+                            skip_eplb=True,
+                            remove_lora=False)
+        self.maybe_remove_all_loras(self.lora_config)

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """

vllm/v1/worker/gpu_worker.py

Lines changed: 4 additions & 1 deletion

@@ -308,7 +308,10 @@ def compile_or_warm_up_model(self) -> None:
         # We skip EPLB here since we don't want to record dummy metrics
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)
-            self.model_runner._dummy_run(size, skip_eplb=True)
+            self.model_runner._dummy_run(size,
+                                         skip_eplb=True,
+                                         remove_lora=False)
+        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

         # Warmup and tune the kernels used during model execution before
         # cuda graph capture.

vllm/v1/worker/lora_model_runner_mixin.py

Lines changed: 21 additions & 9 deletions

@@ -5,7 +5,7 @@
 """

 from contextlib import contextmanager
-from typing import Union
+from typing import Optional, Union

 import numpy as np
 import torch
@@ -87,7 +87,9 @@ def set_active_loras(self, input_batch: InputBatch,
                                lora_requests)

     @contextmanager
-    def maybe_setup_dummy_loras(self, lora_config):
+    def maybe_setup_dummy_loras(self,
+                                lora_config: Optional[LoRAConfig],
+                                remove_lora: bool = True):
         if lora_config is None:
             yield
         else:
@@ -114,10 +116,11 @@ def maybe_setup_dummy_loras(self, lora_config):
             yield

             # __exit__ code
-            self.lora_manager.remove_all_adapters()
+            if remove_lora:
+                self.lora_manager.remove_all_adapters()

     @contextmanager
-    def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
+    def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig],
                                  num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
@@ -151,13 +154,22 @@ def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
             yield

     @contextmanager
-    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
-                                  num_scheduled_tokens: np.ndarray):
-        with self.maybe_setup_dummy_loras(
-                lora_config), self.maybe_select_dummy_loras(
-                    lora_config, num_scheduled_tokens):
+    def maybe_dummy_run_with_lora(self,
+                                  lora_config: Optional[LoRAConfig],
+                                  num_scheduled_tokens: np.ndarray,
+                                  remove_lora: bool = True):
+        with (
+                self.maybe_setup_dummy_loras(lora_config, remove_lora),
+                self.maybe_select_dummy_loras(lora_config,
+                                              num_scheduled_tokens),
+        ):
             yield

+    def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]):
+        if lora_config is None:
+            return
+        self.lora_manager.remove_all_adapters()
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
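The rewritten `maybe_dummy_run_with_lora` also switches to the parenthesized `with (...)` form for stacking the two context managers, which is equivalent to listing them on one line but lets each manager sit on its own line. A small standalone example of that equivalence follows; the `tracked` helper is purely illustrative, and note that parenthesized context managers are officially supported from Python 3.10.

    from contextlib import contextmanager


    @contextmanager
    def tracked(name, log):
        # Record enter/exit order so the two forms can be compared.
        log.append(f"enter {name}")
        try:
            yield
        finally:
            log.append(f"exit {name}")


    events = []

    # Classic form: multiple managers on one line.
    with tracked("setup", events), tracked("select", events):
        events.append("body")

    # Parenthesized form used in the refactor; behaves identically.
    with (
            tracked("setup", events),
            tracked("select", events),
    ):
        events.append("body")

    assert events[:3] == ["enter setup", "enter select", "body"]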
