@@ -217,12 +217,96 @@ static int accelerator_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type,
217217 return 0 ;
218218}
219219
220+ static int accelerator_cuda_check_mpool (CUdeviceptr dbuf , CUmemorytype * mem_type ,
221+ int * dev_id )
222+ {
223+ #if OPAL_CUDA_VMM_SUPPORT
224+ static int device_count = -1 ;
225+ static int mpool_supported = -1 ;
226+ CUresult result ;
227+ CUmemoryPool mpool ;
228+ CUmemAccess_flags flags ;
229+ CUmemLocation location ;
230+
231+ if (mpool_supported <= 0 ) {
232+ if (mpool_supported == -1 ) {
233+ if (device_count == -1 ) {
234+ result = cuDeviceGetCount (& device_count );
235+ if (result != CUDA_SUCCESS || (0 == device_count )) {
236+ mpool_supported = 0 ; /* never check again */
237+ device_count = 0 ;
238+ return 0 ;
239+ }
240+ }
241+
242+ /* assume uniformity of devices */
243+ result = cuDeviceGetAttribute (& mpool_supported ,
244+ CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED , 0 );
245+ if (result != CUDA_SUCCESS ) {
246+ mpool_supported = 0 ;
247+ }
248+ }
249+ if (0 == mpool_supported ) {
250+ return 0 ;
251+ }
252+ }
253+
254+ result = cuPointerGetAttribute (& mpool , CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE ,
255+ dbuf );
256+ if (CUDA_SUCCESS != result ) {
257+ return 0 ;
258+ }
259+
260+ /* check if device has access */
261+ for (int i = 0 ; i < device_count ; i ++ ) {
262+ location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
263+ location .id = i ;
264+ result = cuMemPoolGetAccess (& flags , mpool , & location );
265+ if ((CUDA_SUCCESS == result ) &&
266+ (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags )) {
267+ * mem_type = CU_MEMORYTYPE_DEVICE ;
268+ * dev_id = i ;
269+ return 1 ;
270+ }
271+ }
272+
273+ /* host must have access as device access possibility is exhausted */
274+ * mem_type = CU_MEMORYTYPE_HOST ;
275+ * dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
276+ return 0 ;
277+ #endif
278+
279+ return 0 ;
280+ }
281+
282+ static int accelerator_cuda_get_primary_context (CUdevice dev_id , CUcontext * pctx )
283+ {
284+ CUresult result ;
285+ unsigned int flags ;
286+ int active ;
287+
288+ result = cuDevicePrimaryCtxGetState (dev_id , & flags , & active );
289+ if (CUDA_SUCCESS != result ) {
290+ return OPAL_ERROR ;
291+ }
292+
293+ if (active ) {
294+ result = cuDevicePrimaryCtxRetain (pctx , dev_id );
295+ return OPAL_SUCCESS ;
296+ }
297+
298+ return OPAL_ERROR ;
299+ }
300+
220301static int accelerator_cuda_check_addr (const void * addr , int * dev_id , uint64_t * flags )
221302{
222303 CUresult result ;
223304 int is_vmm = 0 ;
305+ int is_mpool_ptr = 0 ;
224306 int vmm_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
307+ int mpool_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
225308 CUmemorytype vmm_mem_type = 0 ;
309+ CUmemorytype mpool_mem_type = 0 ;
226310 CUmemorytype mem_type = 0 ;
227311 CUdeviceptr dbuf = (CUdeviceptr ) addr ;
228312 CUcontext ctx = NULL , mem_ctx = NULL ;
@@ -235,6 +319,7 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
235319 * flags = 0 ;
236320
237321 is_vmm = accelerator_cuda_check_vmm (dbuf , & vmm_mem_type , & vmm_dev_id );
322+ is_mpool_ptr = accelerator_cuda_check_mpool (dbuf , & mpool_mem_type , & mpool_dev_id );
238323
239324#if OPAL_CUDA_GET_ATTRIBUTES
240325 uint32_t is_managed = 0 ;
@@ -268,6 +353,9 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
268353 if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
269354 mem_type = CU_MEMORYTYPE_DEVICE ;
270355 * dev_id = vmm_dev_id ;
356+ } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
357+ mem_type = CU_MEMORYTYPE_DEVICE ;
358+ * dev_id = mpool_dev_id ;
271359 } else {
272360 /* Host memory, nothing to do here */
273361 return 0 ;
@@ -278,6 +366,8 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
278366 } else {
279367 if (is_vmm ) {
280368 * dev_id = vmm_dev_id ;
369+ } else if (is_mpool_ptr ) {
370+ * dev_id = mpool_dev_id ;
281371 } else {
282372 /* query the device from the context */
283373 * dev_id = accelerator_cuda_get_device_id (mem_ctx );
@@ -296,13 +386,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
296386 if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
297387 mem_type = CU_MEMORYTYPE_DEVICE ;
298388 * dev_id = vmm_dev_id ;
389+ } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
390+ mem_type = CU_MEMORYTYPE_DEVICE ;
391+ * dev_id = mpool_dev_id ;
299392 } else {
300393 /* Host memory, nothing to do here */
301394 return 0 ;
302395 }
303396 } else {
304397 if (is_vmm ) {
305398 * dev_id = vmm_dev_id ;
399+ } else if (is_mpool_ptr ) {
400+ * dev_id = mpool_dev_id ;
306401 } else {
307402 result = cuPointerGetAttribute (& mem_ctx ,
308403 CU_POINTER_ATTRIBUTE_CONTEXT , dbuf );
@@ -336,14 +431,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
336431 return OPAL_ERROR ;
337432 }
338433#endif /* OPAL_CUDA_GET_ATTRIBUTES */
339- if (is_vmm ) {
340- /* This function is expected to set context if pointer is device
341- * accessible but VMM allocations have NULL context associated
342- * which cannot be set against the calling thread */
343- opal_output (0 ,
344- "CUDA: unable to set context with the given pointer"
345- "ptr=%p aborting..." , addr );
346- return OPAL_ERROR ;
434+ if (is_vmm || is_mpool_ptr ) {
435+ if (OPAL_SUCCESS ==
436+ accelerator_cuda_get_primary_context (
437+ is_vmm ? vmm_dev_id : mpool_dev_id , & mem_ctx )) {
438+ /* As VMM/mempool allocations have no context associated
439+ * with them, check if device primary context can be set */
440+ } else {
441+ opal_output (0 ,
442+ "CUDA: unable to set ctx with the given pointer"
443+ "ptr=%p aborting..." , addr );
444+ return OPAL_ERROR ;
445+ }
347446 }
348447
349448 result = cuCtxSetCurrent (mem_ctx );
0 commit comments