Skip to content
Draft
39 changes: 18 additions & 21 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#include <memory>
Expand Down Expand Up @@ -163,14 +163,12 @@ std::shared_ptr<Endpoint> createEndpointFromHostname(std::shared_ptr<Worker> wor
if (worker == nullptr || worker->getHandle() == nullptr)
throw ucxx::Error("Worker not initialized");

ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
UCP_EP_PARAM_FIELD_ERR_HANDLER,
.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER};
auto info = ucxx::utils::get_addrinfo(ipAddress.c_str(), port);

params.sockaddr.addrlen = info->ai_addrlen;
params.sockaddr.addr = info->ai_addr;
const auto info = ucxx::utils::get_addrinfo(ipAddress.c_str(), port);
ucp_ep_params_t params{.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
UCP_EP_PARAM_FIELD_ERR_HANDLER,
.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER,
.sockaddr = {.addr = info->ai_addr, .addrlen = info->ai_addrlen}};

auto ep = std::shared_ptr<Endpoint>(new Endpoint(worker, endpointErrorHandling));
ep->create(&params);
Expand All @@ -184,11 +182,11 @@ std::shared_ptr<Endpoint> createEndpointFromConnRequest(std::shared_ptr<Listener
if (listener == nullptr || listener->getHandle() == nullptr)
throw ucxx::Error("Worker not initialized");

ucp_ep_params_t params = {
.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER,
.flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK,
.conn_request = connRequest};
ucp_ep_params_t params{.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
UCP_EP_PARAM_FIELD_ERR_HANDLER,
.flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK,
.conn_request = connRequest};

auto ep = std::shared_ptr<Endpoint>(new Endpoint(listener, endpointErrorHandling));
ep->create(&params);
Expand All @@ -204,10 +202,10 @@ std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(std::shared_ptr<Worker
if (address == nullptr || address->getHandle() == nullptr || address->getLength() == 0)
throw ucxx::Error("Address not initialized");

ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
UCP_EP_PARAM_FIELD_ERR_HANDLER,
.address = address->getHandle()};
ucp_ep_params_t params{.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
UCP_EP_PARAM_FIELD_ERR_HANDLER,
.address = address->getHandle()};

auto ep = std::shared_ptr<Endpoint>(new Endpoint(worker, endpointErrorHandling));
ep->create(&params);
Expand Down Expand Up @@ -259,9 +257,8 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts)
_handle,
canceled);

ucp_request_param_t param{};
if (_endpointErrorHandling)
param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE};
const ucp_request_param_t param{.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS,
.flags = UCP_EP_CLOSE_FLAG_FORCE};

auto worker = ::ucxx::getWorker(_parent);
ucs_status_ptr_t status = nullptr;
Expand Down
4 changes: 3 additions & 1 deletion python/ucxx/ucxx/_lib/libucxx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1577,7 +1577,7 @@ cdef void _listener_callback(ucp_conn_request_h conn_request, void *args) with g
cb_data['cb_func'](
(
cb_data['listener']().create_endpoint_from_conn_request(
int(<uintptr_t>conn_request), True
int(<uintptr_t>conn_request), cb_data['endpoint_error_handling']
) if 'listener' in cb_data else
int(<uintptr_t>conn_request)
),
Expand All @@ -1601,6 +1601,7 @@ cdef class UCXListener():
cls,
UCXWorker worker,
uint16_t port,
bint endpoint_error_handling,
cb_func,
tuple cb_args=None,
dict cb_kwargs=None,
Expand All @@ -1619,6 +1620,7 @@ cdef class UCXListener():
"cb_func": cb_func,
"cb_args": cb_args,
"cb_kwargs": cb_kwargs,
"endpoint_error_handling": endpoint_error_handling,
}
if deliver_endpoint is True:
cb_data["listener"] = weakref.ref(listener)
Expand Down
4 changes: 2 additions & 2 deletions python/ucxx/ucxx/_lib/tests/test_cancel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: BSD-3-Clause

import multiprocessing as mp
Expand Down Expand Up @@ -28,7 +28,7 @@ def _listener_handler(conn_request):
ep[0] = listener.create_endpoint_from_conn_request(conn_request, True)

listener = ucx_api.UCXListener.create(
worker=worker, port=0, cb_func=_listener_handler
worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler
)
queue.put(listener.port)

Expand Down
4 changes: 2 additions & 2 deletions python/ucxx/ucxx/_lib/tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: BSD-3-Clause

import multiprocessing as mp
Expand Down Expand Up @@ -43,7 +43,7 @@ def _listener_handler(conn_request):
listener_finished[0] = True

listener = ucx_api.UCXListener.create(
worker=worker, port=0, cb_func=_listener_handler
worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler
)
queue.put(listener.port)

Expand Down
4 changes: 2 additions & 2 deletions python/ucxx/ucxx/_lib/tests/test_listener.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: BSD-3-Clause

import ucxx._lib.libucxx as ucx_api
Expand All @@ -12,7 +12,7 @@ def _listener_handler(conn_request):
pass

listener = ucx_api.UCXListener.create(
worker=worker, port=0, cb_func=_listener_handler
worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler
)

assert isinstance(listener.ip, str) and listener.ip
Expand Down
4 changes: 2 additions & 2 deletions python/ucxx/ucxx/_lib/tests/test_probe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: BSD-3-Clause

import multiprocessing as mp
Expand Down Expand Up @@ -37,7 +37,7 @@ def _listener_handler(conn_request):
)

listener = ucx_api.UCXListener.create(
worker=worker, port=0, cb_func=_listener_handler
worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler
)
queue.put(listener.port)

Expand Down
4 changes: 2 additions & 2 deletions python/ucxx/ucxx/_lib/tests/test_server_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: BSD-3-Clause

import multiprocessing as mp
Expand Down Expand Up @@ -68,7 +68,7 @@ def _listener_handler(conn_request):
ep[0] = listener.create_endpoint_from_conn_request(conn_request, True)

listener = ucx_api.UCXListener.create(
worker=worker, port=0, cb_func=_listener_handler
worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler
)
put_queue.put(listener.port)

Expand Down
3 changes: 2 additions & 1 deletion python/ucxx/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: BSD-3-Clause

import logging
Expand Down Expand Up @@ -304,6 +304,7 @@ def create_listener(
ucx_api.UCXListener.create(
worker=self.worker,
port=port,
endpoint_error_handling=endpoint_error_handling,
cb_func=_listener_handler,
cb_args=(
loop,
Expand Down
7 changes: 5 additions & 2 deletions python/ucxx/ucxx/benchmarks/backends/ucxx_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: BSD-3-Clause

from argparse import Namespace
Expand Down Expand Up @@ -146,7 +146,10 @@ def _listener_handler(conn_request):
)

listener = ucx_api.UCXListener.create(
worker=worker, port=self.args.port or 0, cb_func=_listener_handler
worker=worker,
port=self.args.port or 0,
endpoint_error_handling=True,
cb_func=_listener_handler,
)
self.queue.put(listener.port)

Expand Down
Loading