diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index d8376afa1..68c34da37 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -163,14 +163,12 @@ std::shared_ptr createEndpointFromHostname(std::shared_ptr wor if (worker == nullptr || worker->getHandle() == nullptr) throw ucxx::Error("Worker not initialized"); - ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR | - UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | - UCP_EP_PARAM_FIELD_ERR_HANDLER, - .flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER}; - auto info = ucxx::utils::get_addrinfo(ipAddress.c_str(), port); - - params.sockaddr.addrlen = info->ai_addrlen; - params.sockaddr.addr = info->ai_addr; + const auto info = ucxx::utils::get_addrinfo(ipAddress.c_str(), port); + ucp_ep_params_t params{.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | + UCP_EP_PARAM_FIELD_ERR_HANDLER, + .flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER, + .sockaddr = {.addr = info->ai_addr, .addrlen = info->ai_addrlen}}; auto ep = std::shared_ptr(new Endpoint(worker, endpointErrorHandling)); ep->create(¶ms); @@ -184,11 +182,11 @@ std::shared_ptr createEndpointFromConnRequest(std::shared_ptrgetHandle() == nullptr) throw ucxx::Error("Worker not initialized"); - ucp_ep_params_t params = { - .field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST | - UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER, - .flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK, - .conn_request = connRequest}; + ucp_ep_params_t params{.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | + UCP_EP_PARAM_FIELD_ERR_HANDLER, + .flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK, + .conn_request = connRequest}; auto ep = std::shared_ptr(new Endpoint(listener, endpointErrorHandling)); ep->create(¶ms); @@ -204,10 +202,10 @@ std::shared_ptr createEndpointFromWorkerAddress(std::shared_ptrgetHandle() == nullptr || address->getLength() == 0) throw ucxx::Error("Address not initialized"); - ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | - UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | - UCP_EP_PARAM_FIELD_ERR_HANDLER, - .address = address->getHandle()}; + ucp_ep_params_t params{.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | + UCP_EP_PARAM_FIELD_ERR_HANDLER, + .address = address->getHandle()}; auto ep = std::shared_ptr(new Endpoint(worker, endpointErrorHandling)); ep->create(¶ms); @@ -259,9 +257,8 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) _handle, canceled); - ucp_request_param_t param{}; - if (_endpointErrorHandling) - param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE}; + const ucp_request_param_t param{.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, + .flags = UCP_EP_CLOSE_FLAG_FORCE}; auto worker = ::ucxx::getWorker(_parent); ucs_status_ptr_t status = nullptr; diff --git a/python/ucxx/ucxx/_lib/libucxx.pyx b/python/ucxx/ucxx/_lib/libucxx.pyx index 3f6ea9a62..159b257c0 100644 --- a/python/ucxx/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/ucxx/_lib/libucxx.pyx @@ -1577,7 +1577,7 @@ cdef void _listener_callback(ucp_conn_request_h conn_request, void *args) with g cb_data['cb_func']( ( cb_data['listener']().create_endpoint_from_conn_request( - int(conn_request), True + int(conn_request), cb_data['endpoint_error_handling'] ) if 'listener' in cb_data else int(conn_request) ), @@ -1601,6 +1601,7 @@ cdef class UCXListener(): cls, UCXWorker worker, uint16_t port, + bint endpoint_error_handling, cb_func, tuple cb_args=None, dict cb_kwargs=None, @@ -1619,6 +1620,7 @@ cdef class UCXListener(): "cb_func": cb_func, "cb_args": cb_args, "cb_kwargs": cb_kwargs, + "endpoint_error_handling": endpoint_error_handling, } if deliver_endpoint is True: cb_data["listener"] = weakref.ref(listener) diff --git a/python/ucxx/ucxx/_lib/tests/test_cancel.py b/python/ucxx/ucxx/_lib/tests/test_cancel.py index fbddd0bde..5bb5891b3 100644 --- a/python/ucxx/ucxx/_lib/tests/test_cancel.py +++ b/python/ucxx/ucxx/_lib/tests/test_cancel.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import multiprocessing as mp @@ -28,7 +28,7 @@ def _listener_handler(conn_request): ep[0] = listener.create_endpoint_from_conn_request(conn_request, True) listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) queue.put(listener.port) diff --git a/python/ucxx/ucxx/_lib/tests/test_endpoint.py b/python/ucxx/ucxx/_lib/tests/test_endpoint.py index e3de5394d..85f3eb2cd 100644 --- a/python/ucxx/ucxx/_lib/tests/test_endpoint.py +++ b/python/ucxx/ucxx/_lib/tests/test_endpoint.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import multiprocessing as mp @@ -43,7 +43,7 @@ def _listener_handler(conn_request): listener_finished[0] = True listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) queue.put(listener.port) diff --git a/python/ucxx/ucxx/_lib/tests/test_listener.py b/python/ucxx/ucxx/_lib/tests/test_listener.py index d814372ba..75054e745 100644 --- a/python/ucxx/ucxx/_lib/tests/test_listener.py +++ b/python/ucxx/ucxx/_lib/tests/test_listener.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import ucxx._lib.libucxx as ucx_api @@ -12,7 +12,7 @@ def _listener_handler(conn_request): pass listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) assert isinstance(listener.ip, str) and listener.ip diff --git a/python/ucxx/ucxx/_lib/tests/test_probe.py b/python/ucxx/ucxx/_lib/tests/test_probe.py index 004fc6232..34884be0e 100644 --- a/python/ucxx/ucxx/_lib/tests/test_probe.py +++ b/python/ucxx/ucxx/_lib/tests/test_probe.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import multiprocessing as mp @@ -37,7 +37,7 @@ def _listener_handler(conn_request): ) listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) queue.put(listener.port) diff --git a/python/ucxx/ucxx/_lib/tests/test_server_client.py b/python/ucxx/ucxx/_lib/tests/test_server_client.py index b36653023..0aeabf86c 100644 --- a/python/ucxx/ucxx/_lib/tests/test_server_client.py +++ b/python/ucxx/ucxx/_lib/tests/test_server_client.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import multiprocessing as mp @@ -68,7 +68,7 @@ def _listener_handler(conn_request): ep[0] = listener.create_endpoint_from_conn_request(conn_request, True) listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) put_queue.put(listener.port) diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index 6aac43468..d14ec31f0 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import logging @@ -304,6 +304,7 @@ def create_listener( ucx_api.UCXListener.create( worker=self.worker, port=port, + endpoint_error_handling=endpoint_error_handling, cb_func=_listener_handler, cb_args=( loop, diff --git a/python/ucxx/ucxx/benchmarks/backends/ucxx_core.py b/python/ucxx/ucxx/benchmarks/backends/ucxx_core.py index 9360e4bca..f5a0e322d 100644 --- a/python/ucxx/ucxx/benchmarks/backends/ucxx_core.py +++ b/python/ucxx/ucxx/benchmarks/backends/ucxx_core.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause from argparse import Namespace @@ -146,7 +146,10 @@ def _listener_handler(conn_request): ) listener = ucx_api.UCXListener.create( - worker=worker, port=self.args.port or 0, cb_func=_listener_handler + worker=worker, + port=self.args.port or 0, + endpoint_error_handling=True, + cb_func=_listener_handler, ) self.queue.put(listener.port)