1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
3
4
+ import os
5
+ import tempfile
6
+ import textwrap
4
7
import time
5
- import uuid
6
- from collections import defaultdict
7
- from typing import Optional
8
8
from unittest .mock import patch
9
9
10
10
import pytest
11
+ import ray
11
12
12
13
from vllm import LLM
13
14
from vllm .config import KVTransferConfig
14
15
from vllm .distributed .kv_transfer .kv_connector .v1 .nixl_connector import (
15
16
KVConnectorRole , NixlAgentMetadata , NixlConnector , NixlConnectorMetadata ,
16
17
NixlConnectorWorker )
17
18
from vllm .forward_context import ForwardContext
19
+ from vllm .mocks .mock_nixl_connector import FakeNixlWrapper
18
20
from vllm .sampling_params import SamplingParams
19
21
20
22
from .utils import create_request , create_scheduler , create_vllm_config
21
23
22
24
25
+ def _make_stub_pkg () -> str :
26
+ """Return a directory that makes
27
+ `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
28
+ td = tempfile .mkdtemp ()
29
+ pkg_root = os .path .join (td , "nixl" , "_api" )
30
+ os .makedirs (pkg_root , exist_ok = True )
31
+
32
+ stub = textwrap .dedent ("""\
33
+ # Forward the real FakeNixlWrapper that the driver already defined.
34
+ print("In fake package")
35
+ from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
36
+ """ )
37
+ with open (os .path .join (pkg_root , "__init__.py" ), "w" ) as f :
38
+ f .write (stub )
39
+
40
+ # touch parent package
41
+ open (os .path .join (td , "nixl" , "__init__.py" ), "w" ).close ()
42
+ return td
43
+
44
+
23
45
def test_basic_interface ():
24
46
"""Unit test for basic NixlConnector interface functionality."""
25
47
@@ -87,77 +109,6 @@ def test_prompt_less_than_block_size():
87
109
assert len (scheduler_output .scheduled_new_reqs ) == 1
88
110
89
111
90
- class FakeNixlWrapper :
91
- """Mock implementation of NixlWrapper for testing.
92
-
93
- We don't inherit from nixl._api.nixl_agent because nixl may not be
94
- installed.
95
- """
96
-
97
- AGENT_METADATA = b"fake_agent_metadata"
98
- REMOTE_AGENT_NAME = "remote_agent"
99
-
100
- def __init__ (self , agent_name : str , * args , ** kwargs ):
101
- self ._cycles_before_xfer_done = 0
102
- self ._check_xfer_state_cycles : defaultdict [int , int ] = defaultdict (
103
- lambda : 0 )
104
-
105
- def get_reg_descs (self , caches_data , memory_type : str ) -> list :
106
- return [str (uuid .uuid4 ()) for _ in caches_data ]
107
-
108
- def register_memory (self , descs ) -> None :
109
- pass
110
-
111
- def get_xfer_descs (self , blocks_data , memory_type : str ) -> list :
112
- return [str (uuid .uuid4 ()) for _ in blocks_data ]
113
-
114
- def prep_xfer_dlist (self , agent_name : str , descs : list ) -> int :
115
- return uuid .uuid4 ().int
116
-
117
- def get_agent_metadata (self ) -> bytes :
118
- return self .AGENT_METADATA
119
-
120
- def add_remote_agent (self , agent_metadata : bytes ) -> str :
121
- return self .REMOTE_AGENT_NAME
122
-
123
- def get_new_notifs (self ) -> dict [str , list [bytes ]]:
124
- # Used to collect done_sending, which we don't test yet.
125
- return {}
126
-
127
- def check_xfer_state (self , handle : int ) -> str :
128
- if self ._check_xfer_state_cycles [
129
- handle ] >= self ._cycles_before_xfer_done :
130
- return "DONE"
131
- self ._check_xfer_state_cycles [handle ] += 1
132
- return "PROC"
133
-
134
- def release_xfer_handle (self , handle : int ) -> None :
135
- pass
136
-
137
- def send_notif (self , agent_name : str , notif_msg : bytes ) -> None :
138
- pass
139
-
140
- def make_prepped_xfer (self ,
141
- xfer_type : str ,
142
- local_xfer_side_handle : int ,
143
- local_block_descs_ids : list [int ],
144
- remote_xfer_side_handle : int ,
145
- remote_block_descs_ids : list [int ],
146
- notif_msg : Optional [bytes ] = None ) -> int :
147
- return uuid .uuid4 ().int
148
-
149
- def transfer (self , handle : int ) -> str :
150
- return "PROC"
151
-
152
- ############################################################
153
- # Follow are for changing the behavior during testing.
154
- ############################################################
155
-
156
- def set_cycles_before_xfer_done (self , cycles : int ):
157
- """Set the number of cycles before a transfer is considered done."""
158
- self ._cycles_before_xfer_done = cycles
159
-
160
-
161
112
class FakeNixlConnectorWorker (NixlConnectorWorker ):
162
113
163
114
REMOTE_ENGINE_ID = "remote_engine"
@@ -378,10 +329,14 @@ def test_concurrent_load_kv(
378
329
raise TimeoutError ("Took too long to complete async handshake." )
379
330
380
331
332
+ # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
333
+ # we put here is important. First run ray, it will clean up the resources, then
334
+ # the rest of the tests.
335
+ @pytest .mark .parametrize ("distributed_executor_backend" , ["ray" , None ])
381
336
@patch (
382
337
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" ,
383
338
FakeNixlWrapper )
384
- def test_abort_timeout_on_prefiller (monkeypatch ):
339
+ def test_abort_timeout_on_prefiller (monkeypatch , distributed_executor_backend ):
385
340
"""
386
341
Test lifecycle of an aborted Remote Prefill request hitting the timeout.
387
342
-----> P
@@ -399,11 +354,23 @@ def test_abort_timeout_on_prefiller(monkeypatch):
399
354
timeout = 6
400
355
monkeypatch .setenv ("VLLM_ENABLE_V1_MULTIPROCESSING" , "0" )
401
356
monkeypatch .setenv ("VLLM_NIXL_ABORT_REQUEST_TIMEOUT" , str (timeout ))
357
+
358
+ # Build runtime_env only if we’re using Ray
359
+ if distributed_executor_backend == "ray" :
360
+ runtime_env = {
361
+ "working_dir" : _make_stub_pkg (), # ship stub package
362
+ "env_vars" : {
363
+ "VLLM_NIXL_ABORT_REQUEST_TIMEOUT" : str (timeout ),
364
+ },
365
+ }
366
+ ray .init (runtime_env = runtime_env )
367
+
402
368
llm = LLM (
403
369
model = model_name ,
404
370
enforce_eager = True ,
405
371
gpu_memory_utilization = 0.5 ,
406
372
kv_transfer_config = kv_transfer_config ,
373
+ distributed_executor_backend = distributed_executor_backend ,
407
374
)
408
375
remote_prefill_opts = {
409
376
"do_remote_decode" : True ,
0 commit comments