1- # SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
1+ # SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES.
22# SPDX-License-Identifier: BSD-3-Clause
33
44"""
@@ -94,13 +94,12 @@ class CudaStream(Enum):
9494
9595
9696def synchronize_stream(stream: CudaStream = CudaStream.Default):
97-    import numba.cuda
97+    from ucxx._cuda_context import synchronize_default_stream
9898
9999    if stream == CudaStream.Default:
100-        numba_stream = numba.cuda.default_stream()
100+        synchronize_default_stream()
101101    else:
102102        raise ValueError("Unsupported stream")
103-    numba_stream.synchronize()
104103
105104
106105class gc_disabled:
@@ -246,11 +245,11 @@ def init_once():
246245        or ("cuda" in ucx_tls and "^cuda" not in ucx_tls)
247246    ):
248247        try:
249-            import numba.cuda
250-        except ImportError:
248+            from ucxx._cuda_context import ensure_cuda_context
249+        except ImportError as e:
251250            raise ImportError(
252-                "CUDA support with UCX requires Numba for context management"
253-            )
251+                "CUDA support with UCX requires cuda-core for context management. "
252+            ) from e
254253
255254        cuda_visible_device = get_device_index_and_uuid(
256255            os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
@@ -261,7 +260,7 @@ def init_once():
261260                pre_existing_cuda_context.device_info, os.getpid()
262261            )
263262
264-        numba.cuda.current_context()
263+        ensure_cuda_context(0)
265264
266265        cuda_context_created = has_cuda_context()
267266        if (
@@ -291,7 +290,8 @@ def init_once():
291290
292291    pool_size_str = get_rmm_config("pool-size")
293292
294- # Find the function, `cuda_array()`, to use when allocating new CUDA arrays
293+ # Find the function, `cuda_array()`, to use when allocating new CUDA arrays.
294+ # RMM is required for CUDA array allocation at runtime (numba is only for tests).
295295    try:
296296 import rmm
297297
@@ -304,22 +304,9 @@ def device_array(n):
304304            pool_allocator=True, managed_memory=False, initial_pool_size=pool_size
305305        )
306306    except ImportError:
307-        try:
308-            import numba.cuda
309-
310-            def numba_device_array(n):
311-                a = numba.cuda.device_array((n,), dtype="u1")
312-                weakref.finalize(a, numba.cuda.current_context)
313-                return a
314-
315-            device_array = numba_device_array
316307
317-        except ImportError:
318-
319-            def device_array(n):
320-                raise RuntimeError(
321-                    "In order to send/recv CUDA arrays, Numba or RMM is required"
322-                )
308+        def device_array(n):
309+            raise RuntimeError("In order to send/recv CUDA arrays, RMM is required.")
323310
324311    if pool_size_str is not None:
325312        logger.warning(
0 commit comments