Skip to content

Commit 0ed646b

Browse files
authored
[Distributed][Core] Support Py39 and Py38 for PP (#6120)
Signed-off-by: Muralidhar Andoorveedu <[email protected]>
1 parent 1dab9bc commit 0ed646b

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

vllm/executor/executor_base.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -123,12 +123,7 @@ def __init__(
123123
multimodal_config: Optional[MultiModalConfig],
124124
speculative_config: Optional[SpeculativeConfig],
125125
) -> None:
126-
# This locks each pipeline parallel stage so multiple virtual engines
127-
# can't execute on the same stage at the same time
128-
self.pp_locks = [
129-
asyncio.Lock()
130-
for _ in range(parallel_config.pipeline_parallel_size)
131-
]
126+
self.pp_locks: Optional[List[asyncio.Lock]] = None
132127

133128
super().__init__(model_config, cache_config, parallel_config,
134129
scheduler_config, device_config, load_config,

vllm/executor/ray_gpu_executor.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -349,6 +349,15 @@ async def _driver_execute_model_async(
349349
self,
350350
execute_model_req: Optional[ExecuteModelRequest] = None
351351
) -> List[SamplerOutput]:
352+
if self.pp_locks is None:
353+
# This locks each pipeline parallel stage so multiple virtual
354+
# engines can't execute on the same stage at the same time
355+
# We create the locks here to avoid creating them in the constructor
356+
# which uses a different asyncio loop.
357+
self.pp_locks = [
358+
asyncio.Lock()
359+
for _ in range(self.parallel_config.pipeline_parallel_size)
360+
]
352361

353362
async def _run_task_with_lock(task, lock, *args, **kwargs):
354363
async with lock:

0 commit comments

Comments
 (0)