Commit a0e827e

[BugFix] make utils.current_stream thread-safe (#21252) (#21253)
Signed-off-by: simpx <[email protected]>
1 parent a15a50f commit a0e827e
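
Before this fix, current_stream() cached the active stream in a single module-level global, so a stream set in one thread leaked into every other thread. A minimal, hypothetical pure-Python stand-in for the bug (no CUDA involved; none of these names are vLLM's) looks like this:

    import threading

    _current = None  # one slot shared by every thread -- the root cause

    def set_current(value):
        global _current
        _current = value

    def child():
        # Intended to affect only this thread, but it writes the shared global.
        set_current("child-stream")

    t = threading.Thread(target=child)
    t.start()
    t.join()
    print(_current)  # "child-stream": the main thread has been contaminated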

File tree

2 files changed: 48 additions, 11 deletions

tests/test_utils.py (41 additions, 3 deletions)

@@ -23,9 +23,9 @@
 from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                         MemorySnapshot, PlaceholderModule, StoreBoolean,
                         bind_kv_cache, common_broadcastable_dtype,
-                        deprecate_kwargs, get_open_port, get_tcp_uri,
-                        is_lossless_cast, join_host_port, make_zmq_path,
-                        make_zmq_socket, memory_profiling,
+                        current_stream, deprecate_kwargs, get_open_port,
+                        get_tcp_uri, is_lossless_cast, join_host_port,
+                        make_zmq_path, make_zmq_socket, memory_profiling,
                         merge_async_iterators, sha256, split_host_port,
                         split_zmq_path, supports_kw, swap_dict_values)

@@ -957,3 +957,41 @@ def test_convert_ids_list_to_tokens():
     ]
     tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
     assert tokens == ['Hello', ',', ' world', '!']
+
+
+def test_current_stream_multithread():
+    import threading
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    main_default_stream = torch.cuda.current_stream()
+    child_stream = torch.cuda.Stream()
+
+    thread_stream_ready = threading.Event()
+    thread_can_exit = threading.Event()
+
+    def child_thread_func():
+        with torch.cuda.stream(child_stream):
+            thread_stream_ready.set()
+            thread_can_exit.wait(timeout=10)
+
+    child_thread = threading.Thread(target=child_thread_func)
+    child_thread.start()
+
+    try:
+        assert thread_stream_ready.wait(
+            timeout=5), "Child thread failed to enter stream context in time"
+
+        main_current_stream = current_stream()
+
+        assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread"
+        assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream"
+
+        # Notify child thread it can exit
+        thread_can_exit.set()
+
+    finally:
+        # Ensure child thread exits properly
+        child_thread.join(timeout=5)
+        if child_thread.is_alive():
+            pytest.fail("Child thread failed to exit properly")
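
Assuming a CUDA-capable machine, the new test should be runnable in isolation with something like:

    pytest tests/test_utils.py::test_current_stream_multithread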

vllm/utils/__init__.py (7 additions, 8 deletions)

@@ -1383,12 +1383,11 @@ def find_nccl_library() -> str:
 
 prev_set_stream = torch.cuda.set_stream
 
-_current_stream = None
+_current_stream_tls = threading.local()
 
 
 def _patched_set_stream(stream: torch.cuda.Stream) -> None:
-    global _current_stream
-    _current_stream = stream
+    _current_stream_tls.value = stream
     prev_set_stream(stream)
 
 

@@ -1407,16 +1406,16 @@ def current_stream() -> torch.cuda.Stream:
     from C/C++ code.
     """
     from vllm.platforms import current_platform
-    global _current_stream
-    if _current_stream is None:
+    if not hasattr(_current_stream_tls,
+                   "value") or _current_stream_tls.value is None:
         # when this function is called before any stream is set,
         # we return the default stream.
         # On ROCm using the default 0 stream in combination with RCCL
         # is hurting performance. Therefore creating a dedicated stream
         # per process
-        _current_stream = torch.cuda.Stream() if current_platform.is_rocm(
-        ) else torch.cuda.current_stream()
-    return _current_stream
+        _current_stream_tls.value = torch.cuda.Stream(
+        ) if current_platform.is_rocm() else torch.cuda.current_stream()
+    return _current_stream_tls.value
 
 
 def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
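
The core of the fix is swapping the shared global for threading.local(), so each thread lazily caches its own value. A standalone sketch of that pattern follows; the names (set_resource, get_current_resource, _tls) are illustrative only, and the real code stores a torch.cuda.Stream, falling back to torch.cuda.current_stream() (or a dedicated stream on ROCm):

    import threading

    _tls = threading.local()  # an independent .value slot per thread

    def set_resource(value):
        # Mirrors _patched_set_stream: record the choice for this thread only.
        _tls.value = value

    def get_current_resource():
        # Mirrors current_stream: lazily fill in a per-thread default.
        if getattr(_tls, "value", None) is None:
            _tls.value = "default"
        return _tls.value

    def worker():
        set_resource("child")
        assert get_current_resource() == "child"

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    # The main thread never sees the child's setting.
    assert get_current_resource() == "default"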
