@@ -29,6 +29,11 @@ typedef struct ucx_iovec {
2929 size_t len ;
3030} ucx_iovec_t ;
3131
32+ OBJ_CLASS_INSTANCE (thread_local_info_t , opal_list_item_t , NULL , NULL ); /* Class instance for thread_local_info_t, naming opal_list_item_t as second argument (presumably the parent class); the two NULLs are presumably the constructor/destructor slots -- confirm against the OBJ_CLASS_INSTANCE definition. */
33+
34+ __thread thread_local_info_t * my_thread_info = NULL ; /* Thread-local (__thread) pointer, initially NULL; the put/get hunks visible in this change never read it -- they look up per-thread state through my_thread_key instead. */
35+ pthread_key_t my_thread_key = {0 }; /* Key passed to pthread_getspecific() in ompi_osc_ucx_put/get to fetch the calling thread's thread_local_info_t. NOTE(review): {0} is only a static initializer; pthread_key_create() and pthread_setspecific() are not visible here -- confirm where they happen. */
36+
3237static inline int check_sync_state (ompi_osc_ucx_module_t * module , int target ,
3338 bool is_req_ops ) {
3439 if (is_req_ops == false) {
@@ -367,19 +372,42 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
367372 int target , ptrdiff_t target_disp , int target_count ,
368373 struct ompi_datatype_t * target_dt , struct ompi_win_t * win ) {
369374 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
370- ucp_ep_h ep = OSC_UCX_GET_EP ( module -> comm , target ) ;
375+ ucp_ep_h ep ;
371376 uint64_t remote_addr = (module -> win_info_array [target ]).addr + target_disp * OSC_UCX_GET_DISP (module , target );
372377 ucp_rkey_h rkey ;
373378 bool is_origin_contig = false, is_target_contig = false;
374379 ptrdiff_t origin_lb , origin_extent , target_lb , target_extent ;
375380 ucs_status_t status ;
381+ pthread_t tid = pthread_self ();
376382 int ret = OMPI_SUCCESS ;
377383
378384 ret = check_sync_state (module , target , false);
379385 if (ret != OMPI_SUCCESS ) {
380386 return ret ;
381387 }
382388
389+ if (pthread_equal (tid , mca_osc_ucx_component .main_tid )) {
390+ ep = OSC_UCX_GET_EP (module -> comm , target );
391+ rkey = (module -> win_info_array [target ]).rkey ;
392+ } else {
393+ thread_local_info_t * curr_thread_info ;
394+ if ((curr_thread_info = pthread_getspecific (my_thread_key )) == NULL ) {
395+ ret = opal_common_ucx_create_local_worker (mca_osc_ucx_component .ucp_context ,
396+ ompi_comm_size (module -> comm ),
397+ mca_osc_ucx_component .worker_addr_buf ,
398+ mca_osc_ucx_component .worker_addr_disps ,
399+ mca_osc_ucx_component .mem_addr_buf ,
400+ mca_osc_ucx_component .mem_addr_disps );
401+ if (ret != OMPI_SUCCESS ) {
402+ return ret ;
403+ }
404+ }
405+
406+ curr_thread_info = pthread_getspecific (my_thread_key );
407+ rkey = curr_thread_info -> rkeys [target ];
408+ ep = curr_thread_info -> eps [target ];
409+ }
410+
383411 if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
384412 status = get_dynamic_win_info (remote_addr , module , ep , target );
385413 if (status != UCS_OK ) {
@@ -393,8 +421,6 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
393421 return OMPI_SUCCESS ;
394422 }
395423
396- rkey = (module -> win_info_array [target ]).rkey ;
397-
398424 ompi_datatype_get_true_extent (origin_dt , & origin_lb , & origin_extent );
399425 ompi_datatype_get_true_extent (target_dt , & target_lb , & target_extent );
400426
@@ -427,19 +453,42 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
427453 int target , ptrdiff_t target_disp , int target_count ,
428454 struct ompi_datatype_t * target_dt , struct ompi_win_t * win ) {
429455 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
430- ucp_ep_h ep = OSC_UCX_GET_EP ( module -> comm , target ) ;
456+ ucp_ep_h ep ;
431457 uint64_t remote_addr = (module -> win_info_array [target ]).addr + target_disp * OSC_UCX_GET_DISP (module , target );
432458 ucp_rkey_h rkey ;
433459 ptrdiff_t origin_lb , origin_extent , target_lb , target_extent ;
434460 bool is_origin_contig = false, is_target_contig = false;
435461 ucs_status_t status ;
462+ pthread_t tid = pthread_self ();
436463 int ret = OMPI_SUCCESS ;
437464
438465 ret = check_sync_state (module , target , false);
439466 if (ret != OMPI_SUCCESS ) {
440467 return ret ;
441468 }
442469
470+ if (pthread_equal (tid , mca_osc_ucx_component .main_tid )) {
471+ ep = OSC_UCX_GET_EP (module -> comm , target );
472+ rkey = (module -> win_info_array [target ]).rkey ;
473+ } else {
474+ thread_local_info_t * curr_thread_info ;
475+ if ((curr_thread_info = pthread_getspecific (my_thread_key )) == NULL ) {
476+ ret = opal_common_ucx_create_local_worker (mca_osc_ucx_component .ucp_context ,
477+ ompi_comm_size (module -> comm ),
478+ mca_osc_ucx_component .worker_addr_buf ,
479+ mca_osc_ucx_component .worker_addr_disps ,
480+ mca_osc_ucx_component .mem_addr_buf ,
481+ mca_osc_ucx_component .mem_addr_disps );
482+ if (ret != OMPI_SUCCESS ) {
483+ return ret ;
484+ }
485+ }
486+
487+ curr_thread_info = pthread_getspecific (my_thread_key );
488+ rkey = curr_thread_info -> rkeys [target ];
489+ ep = curr_thread_info -> eps [target ];
490+ }
491+
443492 if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
444493 status = get_dynamic_win_info (remote_addr , module , ep , target );
445494 if (status != UCS_OK ) {
@@ -453,8 +502,6 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
453502 return OMPI_SUCCESS ;
454503 }
455504
456- rkey = (module -> win_info_array [target ]).rkey ;
457-
458505 ompi_datatype_get_true_extent (origin_dt , & origin_lb , & origin_extent );
459506 ompi_datatype_get_true_extent (target_dt , & target_lb , & target_extent );
460507
0 commit comments