|
23 | 23 | from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
|
24 | 24 | MemorySnapshot, PlaceholderModule, StoreBoolean,
|
25 | 25 | bind_kv_cache, common_broadcastable_dtype,
|
26 |
| - deprecate_kwargs, get_open_port, get_tcp_uri, |
27 |
| - is_lossless_cast, join_host_port, make_zmq_path, |
28 |
| - make_zmq_socket, memory_profiling, |
| 26 | + current_stream, deprecate_kwargs, get_open_port, |
| 27 | + get_tcp_uri, is_lossless_cast, join_host_port, |
| 28 | + make_zmq_path, make_zmq_socket, memory_profiling, |
29 | 29 | merge_async_iterators, sha256, split_host_port,
|
30 | 30 | split_zmq_path, supports_kw, swap_dict_values)
|
31 | 31 |
|
@@ -957,3 +957,41 @@ def test_convert_ids_list_to_tokens():
|
957 | 957 | ]
|
958 | 958 | tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
|
959 | 959 | assert tokens == ['Hello', ',', ' world', '!']
|
| 960 | + |
| 961 | + |
def test_current_stream_multithread():
    """Check that a child thread's CUDA stream context does not leak into
    the main thread's view of ``current_stream()``.

    A child thread enters ``torch.cuda.stream(child_stream)`` and parks
    there. While it is parked, the main thread calls ``current_stream()``
    and verifies that the result is not ``child_stream`` and still equals
    the stream that ``torch.cuda.current_stream()`` reported in the main
    thread before the child was started.

    The test is skipped when CUDA is not available.
    """
    import threading
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    # Captured before the child thread exists, so it reflects only the
    # main thread's state.
    main_default_stream = torch.cuda.current_stream()
    child_stream = torch.cuda.Stream()

    thread_stream_ready = threading.Event()
    thread_can_exit = threading.Event()

    def child_thread_func():
        # Hold the stream context open until the main thread has finished
        # its checks (or the timeout elapses as a safety net).
        with torch.cuda.stream(child_stream):
            thread_stream_ready.set()
            thread_can_exit.wait(timeout=10)

    child_thread = threading.Thread(target=child_thread_func)
    child_thread.start()

    try:
        assert thread_stream_ready.wait(
            timeout=5), "Child thread failed to enter stream context in time"

        main_current_stream = current_stream()

        assert main_current_stream != child_stream, (
            "Main thread's current_stream was contaminated by child thread")
        assert main_current_stream == main_default_stream, (
            "Main thread's current_stream is not the default stream")
    finally:
        # Release the child unconditionally. If this only happened on the
        # success path, a failed assertion above would leave the child
        # blocked in its 10 s wait, the 5 s join below would time out, and
        # pytest.fail() would mask the real assertion error.
        thread_can_exit.set()
        child_thread.join(timeout=5)
        if child_thread.is_alive():
            pytest.fail("Child thread failed to exit properly")
0 commit comments