diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 4ba38216a..b459d867d 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -7,7 +7,9 @@ #include #include +#include #include +#include #include #include @@ -18,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -611,6 +614,61 @@ class Endpoint : public Component { RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + // Template version with named parameters + template + [[nodiscard]] std::enable_if_t< + detail::contains_type::value && + detail::contains_type::value && + detail::has_unique_types...>::value, + std::shared_ptr> + tagSend(Options&&... opts) + { + // Default values for optional parameters + std::shared_ptr endpoint = nullptr; + std::optional> requestData; + bool enablePythonFuture = false; + RequestCallbackUserFunction callbackFunction = nullptr; + RequestCallbackUserData callbackData = nullptr; + + // Helper to set parameters + auto setParam = [&](auto&& param) { + using ParamType = std::decay_t; + if constexpr (std::is_same_v) { + endpoint = std::move(param.value); + } else if constexpr (std::is_same_v) { + requestData.emplace(std::move(param.value)); + } else if constexpr (std::is_same_v) { + enablePythonFuture = param.value; + } else if constexpr (std::is_same_v) { + callbackFunction = param.value; + } else if constexpr (std::is_same_v) { + callbackData = param.value; + } + }; + + // Set all parameters + (setParam(std::forward(opts)), ...); + + // Ensure required parameters are present + if (!endpoint || !requestData) { + throw std::runtime_error("Missing required parameters for tagSend"); + } + + // Create the request with the collected parameters and register it + return registerInflightRequest(createRequestTag(std::forward(opts)...)); + } + + // Overload for template-style parameters (deprecated) + [[nodiscard]] std::shared_ptr tagSend( + request_tag_params::EndpointParam&& endpointParam, + request_tag_params::RequestDataParam&& requestDataParam, + request_tag_params::EnablePythonFutureParam&& enablePythonFutureParam = + request_tag_params::EnablePythonFutureParam{false}, + request_tag_params::CallbackFunctionParam&& callbackFunctionParam = + request_tag_params::CallbackFunctionParam{nullptr}, + request_tag_params::CallbackDataParam&& callbackDataParam = + request_tag_params::CallbackDataParam{nullptr}); + /** * @brief Enqueue a tag receive operation. * @@ -644,6 +702,61 @@ class Endpoint : public Component { RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + // Template version with named parameters + template + [[nodiscard]] std::enable_if_t< + detail::contains_type::value && + detail::contains_type::value && + detail::has_unique_types...>::value, + std::shared_ptr> + tagRecv(Options&&... opts) + { + // Default values for optional parameters + std::shared_ptr endpoint = nullptr; + std::optional> requestData; + bool enablePythonFuture = false; + RequestCallbackUserFunction callbackFunction = nullptr; + RequestCallbackUserData callbackData = nullptr; + + // Helper to set parameters + auto setParam = [&](auto&& param) { + using ParamType = std::decay_t; + if constexpr (std::is_same_v) { + endpoint = std::move(param.value); + } else if constexpr (std::is_same_v) { + requestData.emplace(std::move(param.value)); + } else if constexpr (std::is_same_v) { + enablePythonFuture = param.value; + } else if constexpr (std::is_same_v) { + callbackFunction = param.value; + } else if constexpr (std::is_same_v) { + callbackData = param.value; + } + }; + + // Set all parameters + (setParam(std::forward(opts)), ...); + + // Ensure required parameters are present + if (!endpoint || !requestData) { + throw std::runtime_error("Missing required parameters for tagRecv"); + } + + // Create the request with the collected parameters and register it + return registerInflightRequest(createRequestTag(std::forward(opts)...)); + } + + // Overload for template-style parameters (deprecated) + [[nodiscard]] std::shared_ptr tagRecv( + request_tag_params::EndpointParam&& endpointParam, + request_tag_params::RequestDataParam&& requestDataParam, + request_tag_params::EnablePythonFutureParam&& enablePythonFutureParam = + request_tag_params::EnablePythonFutureParam{false}, + request_tag_params::CallbackFunctionParam&& callbackFunctionParam = + request_tag_params::CallbackFunctionParam{nullptr}, + request_tag_params::CallbackDataParam&& callbackDataParam = + request_tag_params::CallbackDataParam{nullptr}); + /** * @brief Enqueue a multi-buffer tag send operation. * diff --git a/cpp/include/ucxx/request_tag.h b/cpp/include/ucxx/request_tag.h index 6ac5fe0e0..6c8eed760 100644 --- a/cpp/include/ucxx/request_tag.h +++ b/cpp/include/ucxx/request_tag.h @@ -1,10 +1,11 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once #include #include +#include #include #include @@ -53,37 +54,13 @@ class RequestTag : public Request { * @param[in] callbackData user-defined data to pass to the `callbackFunction`. */ RequestTag(std::shared_ptr endpointOrWorker, - const std::variant requestData, - const std::string operationName, + const std::variant& requestData, + const std::string& operationName, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); - public: - /** - * @brief Constructor for `std::shared_ptr`. - * - * The constructor for a `std::shared_ptr` object, creating a send or - * receive tag request, returning a pointer to a request object that can be later awaited - * and checked for errors. This is a non-blocking operation, and the status of the - * transfer must be verified from the resulting request object before the data can be - * released (for a send operation) or consumed (for a receive operation). - * - * @throws ucxx::Error if send is `true` and `endpointOrWorker` is not a - * `std::shared_ptr`. - * - * @param[in] endpointOrWorker the parent component, which may either be a - * `std::shared_ptr` or - * `std::shared_ptr`. - * @param[in] requestData container of the specified message type, including all - * type-specific data. - * @param[in] enablePythonFuture whether a python future should be created and - * subsequently notified. - * @param[in] callbackFunction user-defined callback function to call upon completion. - * @param[in] callbackData user-defined data to pass to the `callbackFunction`. - * - * @returns The `shared_ptr` object - */ + // Friend declarations for both createRequestTag functions friend std::shared_ptr createRequestTag( std::shared_ptr endpointOrWorker, const std::variant requestData, @@ -91,6 +68,16 @@ class RequestTag : public Request { RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); + // Friend the templated version + template + friend std::enable_if_t< + detail::contains_type::value && + detail::contains_type::value && + detail::has_unique_types...>::value, + std::shared_ptr> + createRequestTag(Options&&... opts); + + public: virtual void populateDelayedSubmission(); /** @@ -160,4 +147,61 @@ class RequestTag : public Request { void callback(void* request, ucs_status_t status, const ucp_tag_recv_info_t* info); }; +// Implementation of the templated createRequestTag function +template +std::enable_if_t::value && + detail::contains_type::value && + detail::has_unique_types...>::value, + std::shared_ptr> +createRequestTag(Options&&... opts) +{ + // Default values for optional parameters + std::shared_ptr endpointOrWorker; + std::optional> requestData; + bool enablePythonFuture = false; + RequestCallbackUserFunction callbackFunction = nullptr; + RequestCallbackUserData callbackData = nullptr; + std::string operationName = "tagOp"; + + // Helper to set parameters + auto setParam = [&](auto&& param) { + using ParamType = std::decay_t; + if constexpr (std::is_same_v) { + endpointOrWorker = std::move(param.value); + } else if constexpr (std::is_same_v) { + requestData.emplace(std::move(param.value)); + } else if constexpr (std::is_same_v) { + enablePythonFuture = param.value; + } else if constexpr (std::is_same_v) { + callbackFunction = param.value; + } else if constexpr (std::is_same_v) { + callbackData = param.value; + } else if constexpr (std::is_same_v) { + operationName = std::move(param.value); + } + }; + + // Set all parameters + (setParam(std::forward(opts)), ...); + + // Ensure required parameters are present + if (!endpointOrWorker || !requestData) { + throw std::runtime_error("Missing required parameters for RequestTag creation"); + } + + // Create the RequestTag with the collected parameters + auto req = std::shared_ptr(new RequestTag(std::move(endpointOrWorker), + std::move(*requestData), + std::move(operationName), + enablePythonFuture, + callbackFunction, + callbackData)); + + // Register delayed submission + req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; +} + } // namespace ucxx diff --git a/cpp/include/ucxx/request_tag_params.h b/cpp/include/ucxx/request_tag_params.h new file mode 100644 index 000000000..2bf6a8931 --- /dev/null +++ b/cpp/include/ucxx/request_tag_params.h @@ -0,0 +1,134 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace ucxx { + +namespace detail { +// Helper to remove const, volatile and reference qualifiers (C++17 compatible version of +// remove_cvref) +template +using remove_cvref = std::remove_reference_t>; + +// Type traits to detect parameter types +template +struct is_endpoint_param : std::false_type {}; + +template +struct is_request_data_param : std::false_type {}; + +// Helper to check if parameter pack contains a type +template +struct contains_type : std::disjunction>...> {}; + +// Helper to ensure no duplicate parameter types +template +struct has_unique_types; + +template <> +struct has_unique_types<> : std::true_type {}; + +template +struct has_unique_types { + static constexpr bool value = + (!std::disjunction...>::value) && has_unique_types::value; +}; +} // namespace detail + +/** + * @brief Parameter tag types for RequestTag creation + * + * These types provide a type-safe way to pass named parameters to createRequestTag. + * Each type wraps a specific parameter and provides a clear name at the call site. + */ +namespace request_tag_params { + +/** + * @brief Parameter wrapper for endpoint or worker component + */ +struct EndpointParam { + std::shared_ptr value; + explicit EndpointParam(std::shared_ptr ep) : value(std::move(ep)) {} +}; + +/** + * @brief Parameter wrapper for request data (TagSend or TagReceive) + */ +struct RequestDataParam { + std::variant value; + + explicit RequestDataParam(const data::TagSend& send) : value(send) {} + + explicit RequestDataParam(const data::TagReceive& recv) : value(recv) {} + + explicit RequestDataParam(const std::variant& data) : value(data) + { + } +}; + +/** + * @brief Parameter wrapper for Python future enablement + */ +struct EnablePythonFutureParam { + bool value; + explicit EnablePythonFutureParam(bool enable) : value(enable) {} +}; + +/** + * @brief Parameter wrapper for callback function + */ +struct CallbackFunctionParam { + RequestCallbackUserFunction value; + explicit CallbackFunctionParam(RequestCallbackUserFunction fn) : value(fn) {} +}; + +/** + * @brief Parameter wrapper for callback data + */ +struct CallbackDataParam { + RequestCallbackUserData value; + explicit CallbackDataParam(RequestCallbackUserData data) : value(data) {} +}; + +/** + * @brief Parameter wrapper for operation name + */ +struct OperationNameParam { + std::string value; + explicit OperationNameParam(std::string name) : value(std::move(name)) {} +}; + +} // namespace request_tag_params + +// Complete the type trait specializations after parameter types are defined +namespace detail { +template <> +struct is_endpoint_param : std::true_type {}; + +template <> +struct is_request_data_param : std::true_type {}; +} // namespace detail + +// Forward declarations +class RequestTag; + +// Forward declare the factory function to be friended +template +std::enable_if_t::value && + detail::contains_type::value && + detail::has_unique_types...>::value, + std::shared_ptr> +createRequestTag(Options&&... opts); + +} // namespace ucxx \ No newline at end of file diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index d8376afa1..e046e26e8 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -561,14 +561,27 @@ std::shared_ptr Endpoint::tagSend(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { - auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestTag(endpoint, - data::TagSend(buffer, length, tag), + return registerInflightRequest(createRequestTag(shared_from_this(), + data::TagSend{buffer, length, tag}, enablePythonFuture, callbackFunction, callbackData)); } +std::shared_ptr Endpoint::tagSend( + request_tag_params::EndpointParam&& endpointParam, + request_tag_params::RequestDataParam&& requestDataParam, + request_tag_params::EnablePythonFutureParam&& enablePythonFutureParam, + request_tag_params::CallbackFunctionParam&& callbackFunctionParam, + request_tag_params::CallbackDataParam&& callbackDataParam) +{ + return registerInflightRequest(createRequestTag(endpointParam.value, + requestDataParam.value, + enablePythonFutureParam.value, + callbackFunctionParam.value, + callbackDataParam.value)); +} + std::shared_ptr Endpoint::tagRecv(void* buffer, size_t length, Tag tag, @@ -577,14 +590,27 @@ std::shared_ptr Endpoint::tagRecv(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { - auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestTag(endpoint, - data::TagReceive(buffer, length, tag, tagMask), + return registerInflightRequest(createRequestTag(shared_from_this(), + data::TagReceive{buffer, length, tag, tagMask}, enablePythonFuture, callbackFunction, callbackData)); } +std::shared_ptr Endpoint::tagRecv( + request_tag_params::EndpointParam&& endpointParam, + request_tag_params::RequestDataParam&& requestDataParam, + request_tag_params::EnablePythonFutureParam&& enablePythonFutureParam, + request_tag_params::CallbackFunctionParam&& callbackFunctionParam, + request_tag_params::CallbackDataParam&& callbackDataParam) +{ + return registerInflightRequest(createRequestTag(endpointParam.value, + requestDataParam.value, + enablePythonFutureParam.value, + callbackFunctionParam.value, + callbackDataParam.value)); +} + std::shared_ptr Endpoint::tagMultiSend(const std::vector& buffer, const std::vector& size, const std::vector& isCUDA, diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index 9655c441a..74562af5e 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -14,6 +14,14 @@ namespace ucxx { +// Forward declare createRequestTag to return std::shared_ptr +template +std::enable_if_t::value && + detail::contains_type::value && + detail::has_unique_types...>::value, + std::shared_ptr> +createRequestTag(Options&&... opts); + std::shared_ptr createRequestTag( std::shared_ptr endpointOrWorker, const std::variant requestData, @@ -54,8 +62,8 @@ std::shared_ptr createRequestTag( } RequestTag::RequestTag(std::shared_ptr endpointOrWorker, - const std::variant requestData, - const std::string operationName, + const std::variant& requestData, + const std::string& operationName, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 5349bdc86..e8385c49f 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -301,6 +301,35 @@ TEST_P(RequestTest, ProgressTag) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTest, ProgressTagTemplated) +{ + allocate(); + + auto endpoint = std::dynamic_pointer_cast(_ep); + + // Submit and wait for transfers to complete using the new templated API + std::vector> requests; + + // Send using named parameters in arbitrary order + requests.push_back(_ep->tagSend(ucxx::request_tag_params::RequestDataParam{ucxx::data::TagSend{ + _sendPtr[0], _messageSize, ucxx::Tag{0}}}, + ucxx::request_tag_params::EndpointParam{endpoint}, + ucxx::request_tag_params::EnablePythonFutureParam{false})); + + // Receive using named parameters in different order + requests.push_back(_ep->tagRecv(ucxx::request_tag_params::EnablePythonFutureParam{false}, + ucxx::request_tag_params::EndpointParam{endpoint}, + ucxx::request_tag_params::RequestDataParam{ucxx::data::TagReceive{ + _recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull}})); + + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + TEST_P(RequestTest, ProgressTagMulti) { if (_progressMode == ProgressMode::Wait) {