@@ -11,6 +11,7 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.utils import CpuGpuBuffer
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -21,7 +22,8 @@
 class CPUModelRunner(GPUModelRunner):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
-        super().__init__(vllm_config, device)
+        with _torch_cuda_wrapper():
+            super().__init__(vllm_config, device)
 
         assert device == torch.device("cpu")
         assert self.speculative_config is None, "spec decode is not supported."
@@ -71,8 +73,8 @@ def replace_tensor(obj: Any, cpu_attr_name: str,
                 setattr(obj, device_attr_name, cpu_tensor)
 
         for k, v in vars(self).items():
-            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
-                replace_tensor(self, k, k[:-4])
+            if isinstance(v, CpuGpuBuffer):
+                v.gpu = v.cpu
 
         for k, v in vars(self.input_batch).items():
             if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
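The hunk above is the behavioral core of the change: instead of scanning for `*_cpu` tensor attributes and copying them over their device counterparts, the runner now looks for `CpuGpuBuffer` objects and aliases their `gpu` attribute to their `cpu` tensor, so both names refer to the same host memory. Below is a minimal sketch of that aliasing idea; `_Buffer` is a hypothetical stand-in, not the real `vllm.v1.worker.utils.CpuGpuBuffer`.

```python
import torch


class _Buffer:
    """Hypothetical stand-in for a CpuGpuBuffer-style cpu/gpu tensor pair."""

    def __init__(self, size: int) -> None:
        self.cpu = torch.zeros(size, dtype=torch.int64)
        # On a real GPU runner this would be a separate tensor on the device;
        # here a clone stands in for it.
        self.gpu = self.cpu.clone()


buf = _Buffer(4)
buf.gpu = buf.cpu        # what the runner now does for every CpuGpuBuffer attribute
buf.cpu[0] = 42
assert buf.gpu[0].item() == 42  # both attributes share the same storage
```

After the alias, a write through either attribute is visible through the other, which is exactly the behavior a CPU-only runner needs.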
@@ -108,6 +110,26 @@ def _init_device_properties(self) -> None:
     def _sync_device(self) -> None:
         pass
 
+    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
+        return sampled_token_ids.tolist()
+
+
+@contextmanager
+def _torch_cuda_wrapper():
+
+    class _EventPlaceholder:
+
+        def __init__(self, *args, **kwargs) -> None:
+            self.record = lambda: None
+            self.synchronize = lambda: None
+
+    try:
+        cuda_event = torch.cuda.Event
+        torch.cuda.Event = _EventPlaceholder
+        yield
+    finally:
+        torch.cuda.Event = cuda_event
+
 
 @contextmanager
 def _set_global_compilation_settings(config: VllmConfig):
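The other half of the change is `_torch_cuda_wrapper`: for the duration of `super().__init__()` it swaps `torch.cuda.Event` for a placeholder whose `record`/`synchronize` are no-ops, so CUDA event creation inside the GPU-oriented initialization path does nothing on a CPU-only worker, and the `finally` restores the real class even if initialization raises. Below is a self-contained usage sketch of the same pattern; `_init_like_gpu_runner` is an illustrative stand-in, not code from this PR.

```python
from contextlib import contextmanager

import torch


@contextmanager
def _torch_cuda_wrapper():
    """Temporarily replace torch.cuda.Event with a do-nothing placeholder."""

    class _EventPlaceholder:

        def __init__(self, *args, **kwargs) -> None:
            self.record = lambda: None
            self.synchronize = lambda: None

    try:
        cuda_event = torch.cuda.Event
        torch.cuda.Event = _EventPlaceholder
        yield
    finally:
        # Restore the real class even if the wrapped code raised.
        torch.cuda.Event = cuda_event


def _init_like_gpu_runner() -> None:
    # Illustrative stand-in for GPU-oriented __init__ code that creates a
    # CUDA event without checking whether CUDA is available.
    event = torch.cuda.Event(enable_timing=False)
    event.record()
    event.synchronize()


with _torch_cuda_wrapper():
    _init_like_gpu_runner()  # completes on a CPU-only build without touching CUDA
```

Restoring the attribute in `finally` is the important design choice: code that later uses `torch.cuda.Event` outside the context manager still sees the genuine class.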