
Commit 464dd98

Fix num_gpus when TP > 1 (#1852)
1 parent c07a442 · commit 464dd98

File tree: 2 files changed (+15, -2 lines)


vllm/engine/async_llm_engine.py

Lines changed: 10 additions & 1 deletion
@@ -301,7 +301,16 @@ def _init_engine(self, *args,
         elif self.worker_use_ray:
             engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
         else:
-            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
+            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
+            # order of the arguments.
+            cache_config = args[1]
+            parallel_config = args[2]
+            if parallel_config.tensor_parallel_size == 1:
+                num_gpus = cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
+            engine_class = ray.remote(num_gpus=num_gpus)(
+                self._engine_class).remote
         return engine_class(*args, **kwargs)
 
     async def engine_step(self) -> bool:
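
Why this change works: Ray accepts fractional values for num_gpus, so when tensor parallelism is disabled the engine actor can request only gpu_memory_utilization of a GPU and share the device with the single worker; with TP > 1 each actor still pins a whole GPU. Below is a minimal, hedged sketch of that scheduling pattern, assuming a node that advertises at least one GPU; ToyEngine and make_engine are illustrative names, not part of vLLM.

import ray


class ToyEngine:
    """Stand-in engine class; purely illustrative, not vLLM's."""

    def ping(self) -> str:
        return "ok"


def make_engine(tensor_parallel_size: int, gpu_memory_utilization: float):
    # Mirror the commit's branch: request a fractional GPU only when TP == 1,
    # so the engine actor can share one device with the single worker.
    if tensor_parallel_size == 1:
        num_gpus = gpu_memory_utilization  # e.g. 0.9 of one GPU
    else:
        num_gpus = 1
    # Same call shape as the commit: wrap the plain class with ray.remote.
    engine_class = ray.remote(num_gpus=num_gpus)(ToyEngine).remote
    return engine_class()


if __name__ == "__main__":
    ray.init()  # assumes a GPU is visible to Ray
    engine = make_engine(tensor_parallel_size=1, gpu_memory_utilization=0.9)
    print(ray.get(engine.ping.remote()))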

vllm/engine/llm_engine.py

Lines changed: 5 additions & 1 deletion
@@ -159,9 +159,13 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         for bundle in placement_group.bundle_specs:
             if not bundle.get("GPU", 0):
                 continue
+            if self.parallel_config.tensor_parallel_size == 1:
+                num_gpus = self.cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
             worker = ray.remote(
                 num_cpus=0,
-                num_gpus=self.cache_config.gpu_memory_utilization,
+                num_gpus=num_gpus,
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),