Skip to content

Commit b72af8f

Browse files
zhaoyang-star and zhaoyang-star authored
Fix error when tp > 1 (#2644)
Co-authored-by: zhaoyang-star <[email protected]>
1 parent 9090bf0 commit b72af8f

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

vllm/engine/llm_engine.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
236236
model_config = copy.deepcopy(self.model_config)
237237
parallel_config = copy.deepcopy(self.parallel_config)
238238
scheduler_config = copy.deepcopy(self.scheduler_config)
239-
cache_config = copy.deepcopy(self.cache_config)
240239

241240
for rank, (worker, (node_id,
242241
_)) in enumerate(zip(self.workers,
@@ -252,7 +251,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
252251
rank,
253252
distributed_init_method,
254253
lora_config=self.lora_config,
255-
cache_config=cache_config,
254+
kv_cache_dtype=self.cache_config.cache_dtype,
256255
))
257256

258257
driver_rank = 0
@@ -265,7 +264,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
265264
driver_rank,
266265
distributed_init_method,
267266
lora_config=self.lora_config,
268-
cache_config=cache_config,
267+
kv_cache_dtype=self.cache_config.cache_dtype,
269268
is_driver_worker=True,
270269
)
271270

0 commit comments

Comments
 (0)