Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions cpp/include/ucxx/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,52 @@ class Endpoint : public Component {
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Enqueue an active message send operation with explicit policy parameters.
*
* This overload extends `amSend()` with explicit UCX datatype/flags controls and receive
* allocation policy metadata while keeping callback behavior identical to the legacy API.
* All send options travel in `params` (see `AmSendParams`): UCP send flags, datatype,
* sender memory type, receiver allocation policy, optional receiver callback info and
* opaque user header bytes.
*
* NOTE(review): `buffer` is a raw, non-owning pointer; it presumably must remain valid
* until the returned request completes — confirm against the request implementation.
*
* @param[in] buffer a raw pointer to the data to be sent.
* @param[in] length the size in bytes of the message to be sent.
* @param[in] params active message send parameters (flags, datatype, memory
* type, receiver allocation policy, receiver callback info and
* user header).
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified. Defaults to `false`.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* Defaults to `nullptr` (no callback).
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
* Defaults to `nullptr`.
*
* @returns Request to be subsequently checked for the completion and its state. The
* result is marked `[[nodiscard]]` and should not be discarded.
*/
[[nodiscard]] std::shared_ptr<Request> amSend(
const void* const buffer,
const size_t length,
const AmSendParams& params,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Enqueue an active message send operation with IOV datatype.
*
* This overload submits `UCP_DATATYPE_IOV` active message sends. The vector of segment
* descriptors is taken by value and moved into the request data, so the caller does not
* need to keep the vector itself alive after this call returns.
*
* NOTE(review): only the segment descriptors are owned by the request; the memory each
* segment points to presumably must remain valid until the returned request completes —
* confirm against the request implementation.
*
* NOTE(review): where `params.datatype` is validated against `UCP_DATATYPE_IOV` is not
* visible from this declaration — confirm before relying on an error being raised.
*
* @param[in] iov vector of IOV segments to be sent (taken by value).
* @param[in] params active message send parameters. Datatype must be
* `UCP_DATATYPE_IOV`.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified. Defaults to `false`.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* Defaults to `nullptr` (no callback).
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
* Defaults to `nullptr`.
*
* @returns Request to be subsequently checked for the completion and its state. The
* result is marked `[[nodiscard]]` and should not be discarded.
*/
[[nodiscard]] std::shared_ptr<Request> amSend(
std::vector<ucp_dt_iov_t> iov,
const AmSendParams& params,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Enqueue an active message receive operation.
*
Expand Down
5 changes: 4 additions & 1 deletion cpp/include/ucxx/internal/request_am.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

#include <ucp/api/ucp.h>

Expand Down Expand Up @@ -57,12 +58,14 @@ class RecvAmMessage {
* @param[in] request request to be later notified/delivered to user.
* @param[in] buffer buffer containing the received data
* @param[in] receiverCallback receiver callback to execute when request completes.
* @param[in] userHeader user-defined header associated with the received message.
*/
RecvAmMessage(internal::AmData* amData,
ucp_ep_h ep,
std::shared_ptr<RequestAm> request,
std::shared_ptr<Buffer> buffer,
AmReceiverCallbackType receiverCallback = AmReceiverCallbackType());
AmReceiverCallbackType receiverCallback = AmReceiverCallbackType(),
std::vector<std::byte> userHeader = {});

/**
* @brief Set the UCP request.
Expand Down
13 changes: 12 additions & 1 deletion cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once
Expand Down Expand Up @@ -224,6 +224,17 @@ class Request : public Component {
* @return The received buffer (if applicable) or `nullptr`.
*/
[[nodiscard]] virtual std::shared_ptr<Buffer> getRecvBuffer();

/**
* @brief Get the received user header.
*
* Returns the user-defined header bytes for applicable derived classes (e.g., `RequestAm`
* receive operations). The bytes are returned in a `std::string` used as a byte
* container. In all other cases, including the base-class implementation, an empty
* string is returned.
*
* @return The received user header (if applicable) or an empty string.
*/
[[nodiscard]] virtual std::string getRecvHeader();
};

} // namespace ucxx
11 changes: 8 additions & 3 deletions cpp/include/ucxx/request_am.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <ucp/api/ucp.h>

Expand All @@ -31,8 +33,9 @@ class RequestAm : public Request {
private:
friend class internal::RecvAmMessage;

std::string _header{}; ///< Retain copy of header for send requests as workaround for
///< https://github.com/openucx/ucx/issues/10424
std::vector<std::byte> _header{}; ///< Retain copy of header bytes for send requests as
///< workaround for
///< https://github.com/openucx/ucx/issues/10424

/**
* @brief Private constructor of `ucxx::RequestAm`.
Expand Down Expand Up @@ -161,6 +164,8 @@ class RequestAm : public Request {
const ucp_am_recv_param_t* param);

[[nodiscard]] std::shared_ptr<Buffer> getRecvBuffer() override;

/**
* @brief Get the received user header.
*
* Overrides `Request::getRecvHeader()`.
*
* NOTE(review): presumably returns the user header bytes delivered with a received
* active message and an empty string otherwise; the definition is not visible from this
* header — confirm against the implementation.
*
* @return The received user header (if applicable) or an empty string.
*/
[[nodiscard]] std::string getRecvHeader() override;
};

} // namespace ucxx
33 changes: 25 additions & 8 deletions cpp/include/ucxx/request_data.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once

#include <memory>
#include <optional>
#include <string>
#include <variant>
#include <vector>

Expand All @@ -28,11 +29,19 @@ namespace data {
*/
class AmSend {
public:
const void* const _buffer{nullptr}; ///< The raw pointer where data to be sent is stored.
const size_t _length{0}; ///< The length of the message.
const void* const _buffer{nullptr}; ///< The raw pointer where data to be sent is stored.
const size_t _length{0}; ///< Message length in bytes (contiguous datatype only).
const std::vector<ucp_dt_iov_t> _iov{}; ///< Segments for IOV datatype.
const size_t _count{0}; ///< Count passed to `ucp_am_send_nbx`: byte count
///< for contiguous, number of IOV segments for IOV.
const uint32_t _flags{UCP_AM_SEND_FLAG_REPLY}; ///< UCP AM send flags.
const ucp_datatype_t _datatype{ucp_dt_make_contig(1)}; ///< UCP datatype.
const ucs_memory_type_t _memoryType{UCS_MEMORY_TYPE_HOST}; ///< Memory type used on the operation
const AmSendMemoryTypePolicy _memoryTypePolicy{
AmSendMemoryTypePolicy::FallbackToHost}; ///< Receiver allocation policy.
const std::optional<AmReceiverCallbackInfo> _receiverCallbackInfo{
std::nullopt}; ///< Owner name and unique identifier of the receiver callback.
const std::vector<std::byte> _userHeader{}; ///< Opaque user-defined header bytes.

/**
* @brief Constructor for Active Message-specific send data.
Expand All @@ -41,14 +50,21 @@ class AmSend {
*
* @param[in] buffer a raw pointer to the data to be sent.
* @param[in] length the size in bytes of the message to be sent.
* @param[in] memoryType the memory type of the buffer.
* @param[in] receiverCallbackInfo the owner name and unique identifier of the receiver
callback.
* @param[in] params send parameters controlling datatype/flags/policies.
*/
explicit AmSend(const decltype(_buffer) buffer,
const decltype(_length) length,
const decltype(_memoryType) memoryType = UCS_MEMORY_TYPE_HOST,
const decltype(_receiverCallbackInfo) receiverCallbackInfo = std::nullopt);
const AmSendParams& params = AmSendParams{});

/**
* @brief Constructor for Active Message-specific send data using IOV datatype.
*
* Construct an object containing Active Message-specific send data for `UCP_DATATYPE_IOV`.
*
* @param[in] iov vector of IOV segments to send.
* @param[in] params send parameters controlling datatype/flags/policies.
*/
explicit AmSend(decltype(_iov) iov, const AmSendParams& params = AmSendParams{});

AmSend() = delete;
};
Expand All @@ -62,6 +78,7 @@ class AmSend {
class AmReceive {
public:
std::shared_ptr<::ucxx::Buffer> _buffer{nullptr}; ///< The AM received message buffer
std::vector<std::byte> _userHeader{}; ///< User-defined header bytes from the sender.

/**
* @brief Constructor for Active Message-specific receive data.
Expand Down
69 changes: 66 additions & 3 deletions cpp/include/ucxx/typedefs.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once

#include <atomic>
#include <cstddef>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

#include <ucp/api/ucp.h>

Expand Down Expand Up @@ -154,8 +159,8 @@ typedef uint64_t AmReceiverCallbackIdType;
*/
class AmReceiverCallbackInfo {
public:
const AmReceiverCallbackOwnerType owner; ///< The owner name of the callback
const AmReceiverCallbackIdType id; ///< The unique identifier of the callback
AmReceiverCallbackOwnerType owner; ///< The owner name of the callback
AmReceiverCallbackIdType id; ///< The unique identifier of the callback

AmReceiverCallbackInfo() = delete;

Expand All @@ -168,6 +173,64 @@ class AmReceiverCallbackInfo {
AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id);
};

/**
 * @brief Receive-buffer allocation policy for Active Messages.
 *
 * Selects what happens on the receiving side when no allocator is registered for the
 * memory type announced by the sender: either allocate host memory instead (permissive)
 * or report an unsupported error (strict).
 */
enum class AmSendMemoryTypePolicy {
  /// Permissive: without an allocator for the memory type, fall back to host memory.
  FallbackToHost = 0,
  /// Strict: without an allocator for the memory type, fail with an unsupported error.
  ErrorOnUnsupported = 1,
};

/**
 * @brief Parameters controlling Active Message send behavior.
 *
 * This object is used by the extended Active Message API to expose UCX send knobs without
 * breaking existing callers. It is a plain aggregate: every member has a default, so
 * `AmSendParams{}` describes a host-memory, contiguous send with the reply flag set.
 */
struct AmSendParams {
  uint32_t flags{UCP_AM_SEND_FLAG_REPLY};              ///< UCP AM send flags.
  ucp_datatype_t datatype{ucp_dt_make_contig(1)};      ///< Datatype used by `ucp_am_send_nbx`.
  ucs_memory_type_t memoryType{UCS_MEMORY_TYPE_HOST};  ///< Sender memory type hint.
  AmSendMemoryTypePolicy memoryTypePolicy{
    AmSendMemoryTypePolicy::FallbackToHost};  ///< Receiver allocation policy.
  std::optional<AmReceiverCallbackInfo> receiverCallbackInfo{
    std::nullopt};                      ///< Optional receiver callback metadata.
  std::vector<std::byte> userHeader{};  ///< Opaque user-defined header bytes. This is serialized
                                        ///< into the AM header parameter of `ucp_am_send_nbx`,
                                        ///< which is subject to transport-level size limits. For
                                        ///< TCP, the default segment size is ~8 KiB
                                        ///< (`UCX_TCP_TX_SEG_SIZE` / `UCX_TCP_RX_SEG_SIZE`).
                                        ///< Headers that exceed the transport limit will cause a
                                        ///< fatal UCX error. Keep user headers small
                                        ///< (recommended < 4 KiB) or increase segment size env
                                        ///< vars as needed.

  /**
   * @brief Set opaque user header bytes from raw pointer.
   *
   * Replaces the current contents of `userHeader` with a copy of `size` bytes read from
   * `data`. The copy is built in a temporary vector before being assigned, so `data` may
   * point into the current `userHeader` storage, and `userHeader` is left untouched if
   * the allocation throws.
   *
   * @param[in] data  pointer to input bytes, may be `nullptr` iff `size == 0`.
   * @param[in] size  number of bytes in input.
   *
   * @throws std::invalid_argument if `data` is `nullptr` while `size` is non-zero.
   */
  void setUserHeader(const void* data, size_t size)
  {
    if (size > 0 && data == nullptr)
      throw std::invalid_argument(
        "AmSendParams::setUserHeader received null data with non-zero size");

    if (size == 0) {
      userHeader.clear();
      return;
    }

    // Copy first, then replace: resizing `userHeader` in place could reallocate and
    // invalidate `data` when the caller passes a pointer into the existing header.
    const auto* first = static_cast<const std::byte*>(data);
    userHeader        = std::vector<std::byte>(first, first + size);
  }

  /**
   * @brief Convenience overload to set user header from string-like views.
   *
   * Forwards the view's pointer and size to the raw-pointer overload; the bytes are
   * copied, so the viewed storage does not need to outlive this call.
   *
   * @param[in] data view of opaque bytes.
   */
  void setUserHeader(std::string_view data) { setUserHeader(data.data(), data.size()); }
};

/**
* @brief Serialized form of a remote key.
*
Expand Down
39 changes: 33 additions & 6 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,14 +449,41 @@ std::shared_ptr<Request> Endpoint::amSend(
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
{
auto params = AmSendParams{};
params.memoryType = memoryType;
params.receiverCallbackInfo = receiverCallbackInfo;

return amSend(buffer, length, params, enablePythonFuture, callbackFunction, callbackData);
}

std::shared_ptr<Request> Endpoint::amSend(const void* const buffer,
const size_t length,
const AmSendParams& params,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
{
auto endpoint = std::dynamic_pointer_cast<Endpoint>(shared_from_this());
return registerInflightRequest(
createRequestAm(endpoint,
data::AmSend(buffer, length, memoryType, receiverCallbackInfo),
enablePythonFuture,
callbackFunction,
callbackData));
return registerInflightRequest(createRequestAm(endpoint,
data::AmSend(buffer, length, params),
enablePythonFuture,
callbackFunction,
callbackData));
}

std::shared_ptr<Request> Endpoint::amSend(std::vector<ucp_dt_iov_t> iov,
const AmSendParams& params,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
{
auto endpoint = std::dynamic_pointer_cast<Endpoint>(shared_from_this());
return registerInflightRequest(createRequestAm(endpoint,
data::AmSend(std::move(iov), params),
enablePythonFuture,
callbackFunction,
callbackData));
}

std::shared_ptr<Request> Endpoint::amRecv(const bool enablePythonFuture,
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/internal/request_am.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <ucxx/typedefs.h>

#include <memory>
#include <utility>
#include <vector>

namespace ucxx {

Expand All @@ -18,11 +20,15 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData,
ucp_ep_h ep,
std::shared_ptr<RequestAm> request,
std::shared_ptr<Buffer> buffer,
AmReceiverCallbackType receiverCallback)
AmReceiverCallbackType receiverCallback,
std::vector<std::byte> userHeader)
: _amData(amData), _ep(ep), _request(request)
{
std::visit(data::dispatch{
[this, buffer](data::AmReceive& amReceive) { amReceive._buffer = buffer; },
[this, buffer, &userHeader](data::AmReceive& amReceive) {
amReceive._buffer = buffer;
amReceive._userHeader = std::move(userHeader);
},
[](auto) { throw std::runtime_error("Unreachable"); },
},
_request->_requestData);
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,4 +242,6 @@ const std::string& Request::getOwnerString() const { return _ownerString; }

/// Base-class default: a generic request holds no receive buffer, so return a null pointer.
std::shared_ptr<Buffer> Request::getRecvBuffer()
{
  return std::shared_ptr<Buffer>{};
}

/// Base-class default: a generic request carries no user header, so return an empty string.
std::string Request::getRecvHeader()
{
  return std::string{};
}

} // namespace ucxx
Loading
Loading