
Commit cb3e73e

sleepwalker2017, weilong.yu, and jeejeelee authored
[BugFix] fix wrong output when using lora and num_scheduler_steps=8 (#11161)
FIX issues #9688, #11086, #12487

---------

Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: weilong.yu <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
1 parent b1340f9 · commit cb3e73e
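Background for the fix: during memory profiling the model runner registers dummy LoRA adapters, and, as the diff below suggests, with multi-step scheduling (num_scheduler_steps > 1) they were not removed before real requests ran, corrupting LoRA outputs. A minimal sketch of a setup that exercises the reported bug; the base model and adapter path are placeholders, not taken from the commit:

    # Sketch only: model name and LoRA path below are placeholders.
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",  # placeholder base model
        enable_lora=True,
        num_scheduler_steps=8,  # multi-step scheduling, as in the bug title
    )
    outputs = llm.generate(
        ["Give me a short introduction to large language models."],
        SamplingParams(temperature=0.0, max_tokens=64),
        lora_request=LoRARequest("adapter", 1, "/path/to/lora"),  # placeholder path
    )
    print(outputs[0].outputs[0].text)  # produced wrong text before this fix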

2 files changed: 4 additions and 3 deletions


vllm/worker/model_runner.py

Lines changed: 4 additions & 0 deletions

@@ -1346,6 +1346,10 @@ def _dummy_run(self,
 
             self.execute_model(model_input, kv_caches, intermediate_tensors)
             torch.cuda.synchronize()
+            if self.lora_config:
+                # Remove dummy loras.
+                assert self.lora_manager is not None
+                self.remove_all_loras()
             return
 
     def remove_all_loras(self):
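The hunk above applies a standard rule: temporary state should be torn down in the same scope that created it. _dummy_run registers dummy LoRA adapters for profiling; previously their removal was deferred to the worker (see the deletion below), and under multi-step scheduling that deferred cleanup left the dummies active. A self-contained toy, with entirely hypothetical names, showing the failure mode and what the added lines guarantee:

    # Toy illustration only; names are hypothetical, not vLLM internals.
    class ToyRunner:
        def __init__(self):
            self.adapters = {}  # stands in for the LoRA manager's registry

        def _dummy_run(self, cleanup: bool):
            self.adapters = {1: "dummy-lora"}  # registered for profiling
            # ... dummy forward pass for memory profiling would go here ...
            if cleanup:
                self.adapters.clear()  # what the added lines guarantee

        def real_run(self):
            assert not self.adapters, "stale dummy LoRA corrupts real output"

    runner = ToyRunner()
    runner._dummy_run(cleanup=True)  # with the fix: no stale state remains
    runner.real_run()                # passes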

vllm/worker/worker.py

Lines changed: 0 additions & 3 deletions

@@ -264,10 +264,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
             f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
 
         logger.info(msg)
-
         # Final cleanup
-        if self.model_runner.lora_manager:
-            self.model_runner.remove_all_loras()
         gc.collect()
 
         return num_gpu_blocks, num_cpu_blocks
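Why this deletion is safe: the worker-level teardown duplicated what _dummy_run now does itself, and deferring cleanup to determine_num_available_blocks was exactly what left dummy adapters alive in multi-step runs; gc.collect() still reclaims the profiling allocations. For reference, a sketch of the remove_all_loras callee that the relocated cleanup relies on, assuming the adapter-manager API of this era (remove_all_adapters); check model_runner.py for the exact body:

    # Sketch of the callee's behavior; simplified class stub, not the
    # actual vLLM ModelRunner.
    class ModelRunnerSketch:
        def __init__(self, lora_manager=None):
            self.lora_manager = lora_manager

        def remove_all_loras(self):
            if not self.lora_manager:
                raise RuntimeError("LoRA is not enabled.")
            self.lora_manager.remove_all_adapters()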
