Skip to content

Commit 4042d19

Browse files
authored
fix "tansformers_module" ModuleNotFoundError when load model with trust_remote_code=True (#871)
1 parent 1117aa1 commit 4042d19

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

vllm/engine/llm_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
153153
placement_group=placement_group,
154154
placement_group_capture_child_tasks=True),
155155
**ray_remote_kwargs,
156-
)(RayWorker).remote()
156+
)(RayWorker).remote(self.model_config.trust_remote_code)
157157
self.workers.append(worker)
158158

159159
# Initialize torch distributed process group for the workers.

vllm/engine/ray_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ class RayWorker(TorchDistributedWorker):
1111
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
1212
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
1313

14-
def __init__(self) -> None:
14+
def __init__(self, init_cached_hf_modules=False) -> None:
15+
if init_cached_hf_modules:
16+
# pylint: disable=import-outside-toplevel
17+
from transformers.dynamic_module_utils import init_hf_modules
18+
init_hf_modules()
1519
self.worker = None
1620

1721
def init_worker(self, worker_init_fn):

0 commit comments

Comments
 (0)