Reorganize delayed submission with operation-specific data and expose tag mask #121
Conversation
wence-
left a comment
There was a problem hiding this comment.
Broadly speaking I think this looks good. I just wonder if there are some ways we can design a more type-safe interface to avoid accidentally constructing invalid send/recv calls and/or invalid delayed submissions.
cpp/include/ucxx/endpoint.h
Outdated
| ucp_tag_t tag, | ||
| ucp_tag_t tagMask, |
There was a problem hiding this comment.
suggestion Can we use a strong type for the tag and tagmask? Right now it is easy to accidentally get the tag and the mask in the wrong order (since the compiler will happily do integral conversions this also applies to the length).
Concretely, if we use the opaque enum pattern (an unscoped enum with a fixed underlying type):
enum Tag : ucp_tag_t {};
enum Mask : ucp_tag_t {};
auto tagRecv(void *buf, size_t length, Tag tag, Mask mask, ...)
{
// auto conversion from Tag and Mask to their underlying type works.
return ucp_tag_recv_nb(buf, length, tag, mask);
}
// compile error, no implicit conversion from integral type to tag/mask
tagRecv(buf, length, 10, 20);
// compile error with mismatching argument order.
tagRecv(buf, length, Mask{20}, Tag{20});
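A minimal, compilable sketch of the pattern above. It assumes `ucp_tag_t` is a 64-bit unsigned integer (a stand-in typedef is used so no UCX headers are needed), and `matchBits` is a hypothetical helper, not a UCXX or UCX function.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Assumption: stand-in for UCX's ucp_tag_t so the sketch builds without UCX.
using ucp_tag_t = std::uint64_t;

// Opaque enums with a fixed underlying type: distinct types, same representation.
enum Tag : ucp_tag_t {};
enum TagMask : ucp_tag_t {};

// Hypothetical helper: the bits of `tag` that take part in matching under `mask`.
inline ucp_tag_t matchBits(Tag tag, TagMask mask)
{
  // Unscoped enums convert implicitly to their underlying type.
  ucp_tag_t t = tag;
  ucp_tag_t m = mask;
  return t & m;
}

// Integers do not convert implicitly to either enum.
static_assert(!std::is_convertible_v<int, Tag>);
static_assert(!std::is_convertible_v<ucp_tag_t, TagMask>);
// The two enums do not convert to each other, so swapped arguments fail to compile.
static_assert(!std::is_convertible_v<TagMask, Tag>);
// Conversion back to the underlying type stays implicit.
static_assert(std::is_convertible_v<Tag, ucp_tag_t>);
```

Callers construct values explicitly, for example `matchBits(Tag{0xF0}, TagMask{0x3C})`, which keeps the intent visible at the call site.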
There was a problem hiding this comment.
This is an excellent suggestion. I've been meaning to work on something like that, starting with the bool send we have in requests which is probably the worst such case in UCXX currently, but tags should do the same.
This is now done via d706c32. As a bonus, I attempted to do the same on the Python side via 9dac813. I'm not confident this is the best approach for Python yet, WDYT?
cpp/include/ucxx/request.h
Outdated
|
|
||
| namespace ucxx { | ||
|
|
||
| static constexpr ucp_tag_t TagMaskFull = UINT64_MAX; |
There was a problem hiding this comment.
This would then be static constexpr Mask FULL{UINT64_MAX} (modulo naming).
cpp/benchmarks/perftest.cpp
Outdated
| std::vector<std::shared_ptr<ucxx::Request>> requests = { | ||
| endpoint->tagSend((*bufferMap)[SEND].data(), app_context.message_size, (*tagMap)[SEND]), | ||
| endpoint->tagRecv((*bufferMap)[RECV].data(), app_context.message_size, (*tagMap)[RECV])}; | ||
| endpoint->tagRecv((*bufferMap)[RECV].data(), app_context.message_size, (*tagMap)[RECV], -1)}; |
There was a problem hiding this comment.
Prefer to use the TagMaskFull symbol rather than the magic number -1.
cpp/benchmarks/perftest.cpp
Outdated
| (*wireupBufferMap)[RECV].size() * sizeof(int), | ||
| (*tagMap)[RECV])); | ||
| (*tagMap)[RECV], | ||
| -1)); |
There was a problem hiding this comment.
Likewise here and throughout.
cpp/src/request_tag.cpp
Outdated
| _delayedSubmission->_buffer, | ||
| _delayedSubmission->_length, | ||
| _delayedSubmission->_tag, | ||
| *_delayedSubmission->_data._tag, |
There was a problem hiding this comment.
question: Is anyone checking that _data._tag.has_value()? If not then this is potentially UB if we dereference std::nullopt.
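To illustrate the concern, here is a small sketch of a checked accessor. `SubmissionData` and `requireTag` are hypothetical names that only mirror the optional member under review, and `ucp_tag_t` is a stand-in typedef.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>

// Assumption: stand-in for UCX's ucp_tag_t.
using ucp_tag_t = std::uint64_t;

// Hypothetical holder mirroring the optional tag member.
struct SubmissionData {
  std::optional<ucp_tag_t> _tag;
};

// Returns the tag or throws; never dereferences an empty optional.
inline ucp_tag_t requireTag(const SubmissionData& data)
{
  if (!data._tag.has_value()) {
    throw std::logic_error("tag operation submitted without a tag");
  }
  return *data._tag;  // safe: has_value() was checked above
}
```

Calling `data._tag.value()` is another option: it throws `std::bad_optional_access` on an empty optional instead of invoking undefined behaviour.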
| const std::optional<ucs_memory_type_t> _memoryType; ///< Memory type used on the operation | ||
| const std::optional<ucp_tag_t> _tag; ///< Tag to match | ||
| const std::optional<ucp_tag_t> _tagMask; ///< Tag mask to use | ||
| }; |
There was a problem hiding this comment.
issue/question:
I don't think this interface for wrapping up the packet of delayed submission data is very type safe (or memory efficient).
The operationType is, AIUI, a tag that says whether we're doing a tag send/recv or an active message.
If we're using active messages, then it is invalid to provide a tag and tagmask argument. But one must provide a memory type. Conversely, for tag or tagmulti, one must provide both a tag and a tagmask, but must not provide memory type.
Further, if the operation type is Undefined then we don't have any information and the object is unusable.
If I understand things correctly, then the data is one of (I'm ignoring Stream for now, but I think the pattern will still apply):
struct AM { ucs_memory_type_t mem_type; };
// If these two are the same for all purposes they don't need to be different structs.
struct Tag { ucp_tag_t tag; ucp_tag_t mask; };
struct TagMulti {ucp_tag_t tag; ucp_tag_t mask; };
And then this data is a std::variant<AM, Tag, TagMulti> data;
So you could imagine something like:
struct One {
int x;
};
struct Two {
int x;
float y;
};
struct Data {
std::variant<One, Two> var;
Data(decltype(var) v) : var{v} {};
One one() { return std::get<One>(var); }
Two two() { return std::get<Two>(var); }
};
int main(void) {
auto one = Data(One{1});
auto two = Data(Two{1, 2});
// throws std::bad_variant_access since two does not contain a One
std::cout << two.one().x << std::endl;
return 0;
}
The constructor of the data must know whether they are a tag/tagmulti/am or stream person so they can plonk the correct concrete struct in, and then they can also attempt to pull the correct variant out, and catch if the wrong type arrives. WDYT?
(It's not as ergonomic as Rust enums + pattern matching 😞 )
There was a problem hiding this comment.
Another great suggestion, in fact I initially tried getting something like that with std::variant but I failed to see the scenario you proposed (probably due to the lack of getters, but I don't recall for sure). Now done via 5037e8f.
cpp/src/delayed_submission.cpp
Outdated
| } else if (_operationType != DelayedSubmissionOperationType::AM) { | ||
| if (_memoryType) throw std::runtime_error("Specifying memoryType value requires AM operation."); | ||
| } | ||
| } |
There was a problem hiding this comment.
This does the "empty space" validation: you're not allowed to provide this combination of things. Does it do the "positive space" validation? "You have provided the correct set of things". e.g. it seems that if the operation is AM, then I should be on the hook to provide a memoryType that is not std::nullopt, but that is not enforced.
There was a problem hiding this comment.
I think this is also not entirely relevant anymore or implicitly addressed by 5037e8f, but please let me know if you think we should improve more still.
There was a problem hiding this comment.
OK, I think I understand this a bit better now, I think the valid objects for a DelayedSubmission are as follows:
struct AMSend {
void *buffer; // could wrap this pair of (buffer, len) in a struct too if you really wanted to
size_t len; // but not sure it's worth it
ucs_memory_type_t memtype;
};
struct AMRecv {
void *buffer;
size_t len;
};
struct TagSend { // for both tag and tagmulti
void *buffer;
size_t len;
ucp_tag_t tag;
};
struct TagRecv {
void *buffer;
size_t len;
ucp_tag_t tag; // we can use enums for this too instead, but it doesn't change the model.
ucp_tag_t mask;
};
Now a delayed submission is:
struct DelayedSubmission {
using type_t = std::variant<std::monostate, AMSend, AMRecv, TagSend, TagRecv>;
type_t info{};
DelayedSubmission() = delete;
DelayedSubmission(type_t info) : info{info} {};
};
...
// Now a requestAM construction for a delayed submission looks like
std::make_shared<DelayedSubmission>(AMSend{buffer, len, memtype});
// For a receive
std::make_shared<DelayedSubmission>(AMRecv{buffer, len});
// For a tag (or tagmulti) send
std::make_shared<DelayedSubmission>(TagSend{buffer, len, tag});
// For a tag (or tagmulti) recv
std::make_shared<DelayedSubmission>(TagRecv{buffer, len, tag, mask});
And we can have a tighter type of RequestTag:
RequestTag(endpoint, std::variant<TagSend, TagRecv> info, ...)
...
My idea here (not really mine, I just copy from elsewhere) is to bundle up the things that we need for each type of send/receive operation, and make it "impossible" to provide a badly formed set of information. That way we don't need to do validation of the arguments in case someone provided an invalid combo.
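A compilable sketch of the "tighter type" idea, under stated assumptions: `ucp_tag_t` is a stand-in typedef, `memtype` is an `int` placeholder for `ucs_memory_type_t`, and `TagRequestData` and `hasMask` are hypothetical names. It shows that a variant limited to tag payloads rejects an active-message payload at compile time.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <variant>

// Assumption: stand-in for UCX's ucp_tag_t.
using ucp_tag_t = std::uint64_t;

struct AMSend  { void* buffer; std::size_t len; int memtype; };  // int stands in for ucs_memory_type_t
struct TagSend { void* buffer; std::size_t len; ucp_tag_t tag; };
struct TagRecv { void* buffer; std::size_t len; ucp_tag_t tag; ucp_tag_t mask; };

// A tag request accepts only tag payloads.
using TagRequestData = std::variant<TagSend, TagRecv>;

// Hypothetical consumer: only receives carry a mask.
inline bool hasMask(const TagRequestData& data)
{
  return std::holds_alternative<TagRecv>(data);
}

// Well-formed payloads are accepted; an AM payload cannot even be converted.
static_assert(std::is_constructible_v<TagRequestData, TagRecv>);
static_assert(!std::is_constructible_v<TagRequestData, AMSend>);
```

Because the invalid combination does not compile, no run-time validation of "tag given but operation is AM" is needed for this type.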
cpp/src/request_tag.cpp
Outdated
| ? DelayedSubmissionData(DelayedSubmissionOperationType::Tag, | ||
| transferDirection, | ||
| DelayedSubmissionTagSend(buffer, length, tag)) | ||
| : DelayedSubmissionData(DelayedSubmissionOperationType::Tag, | ||
| transferDirection, | ||
| DelayedSubmissionTagReceive(buffer, length, tag, tagMask))), |
There was a problem hiding this comment.
There should be no need for the DelayedSubmissionOperationType and transferDirection arguments any more, because that information is encoded in the particular type in the variant, so this is sending in redundant information which may not match.
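As a sketch of why the extra arguments are redundant, the transfer direction can be computed from the variant itself. `Direction`, `RequestData`, and `directionOf` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>

// Assumption: stand-in for UCX's ucp_tag_t.
using ucp_tag_t = std::uint64_t;

struct TagSend { ucp_tag_t tag; };
struct TagRecv { ucp_tag_t tag; ucp_tag_t mask; };
using RequestData = std::variant<TagSend, TagRecv>;

enum class Direction { Send, Receive };

// The active alternative already says which direction this is,
// so no separate direction argument can disagree with it.
inline Direction directionOf(const RequestData& data)
{
  return std::holds_alternative<TagSend>(data) ? Direction::Send : Direction::Receive;
}
```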
cpp/src/request_tag.cpp
Outdated
| &param); | ||
| auto tagSend = _delayedSubmission->_data.getTagSend(); | ||
| request = ucp_tag_send_nbx( | ||
| _endpoint->getHandle(), tagSend._buffer, tagSend._length, tagSend._tag, &param); |
There was a problem hiding this comment.
Because the _data is now a tagged union, it has the information needed to do the dispatch.
There are a few ways to do this:
auto data = delayedSubmission->data;
if (std::holds_alternative<TagSend>(data)) {
auto info = data.getTagSend();
...
} else if (std::holds_alternative<TagRecv>(data)) {
auto info = data.getTagRecv();
...
} else {
throw ucxx::logic_error("Invalid delayed submission for tag send/recv"); // or whatever the message...
}
Note the problem here is that we must unwrap things manually. However, we can instead use std::visit with operator overloading:
struct Dispatch {
void operator()(TagSend const &ts) { ucp_tag_send_nbx(...); };
void operator()(TagRecv const &tr) { ucp_tag_recv_nbx(...); };
template<typename T>
void operator()(T const &) { throw std::logic_error("Unreachable!"); }
};
...
std::visit(Dispatch{}, data);
Or, my favourite since you don't need to build a struct, but needs some template magic https://www.cppstories.com/2019/02/2lines3featuresoverload.html/
template <class... Ts> struct overload : Ts... {
using Ts::operator()...;
};
template <class... Ts> overload(Ts...) -> overload<Ts...>;
void RequestTag::request()
{
ucp_request_param_t param = ...;
void * request = std::visit(overload {
[&param](TagSend const &ts) -> void * { param.cb.send = tagSendCallback; return ucp_tag_send_nbx(...); },
[&param](TagRecv const &tr) -> void * { param.cb.recv = tagRecvCallback; return ucp_tag_recv_nbx(...); },
[](auto const &) -> void * { throw std::logic_error("Unreachable!"); }
}, delayedSubmission->data);
}
Basically, this is poor-man's pattern matching with pretty ugly syntax. But, it becomes a compile error if you don't handle all variant types, good! It also keeps everything in one place.
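For reference, a self-contained version of the overload pattern that compiles without UCX. The `submit` function is a hypothetical placeholder that returns a description instead of calling `ucp_tag_send_nbx` or `ucp_tag_recv_nbx`, and `ucp_tag_t` is a stand-in typedef.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>

// Assumption: stand-in for UCX's ucp_tag_t.
using ucp_tag_t = std::uint64_t;

struct TagSend { ucp_tag_t tag; };
struct TagRecv { ucp_tag_t tag; ucp_tag_t mask; };
using TagData = std::variant<TagSend, TagRecv>;

// Inherit every lambda's call operator into one overload set.
template <class... Ts> struct overload : Ts... { using Ts::operator()...; };
template <class... Ts> overload(Ts...) -> overload<Ts...>;

// Placeholder for the real submission: describes what would be sent.
inline std::string submit(const TagData& data)
{
  // There is no catch-all lambda here, so removing either handler
  // makes std::visit fail to compile.
  return std::visit(overload{
    [](const TagSend& ts) { return "send tag=" + std::to_string(ts.tag); },
    [](const TagRecv& tr) {
      return "recv tag=" + std::to_string(tr.tag) + " mask=" + std::to_string(tr.mask);
    }
  }, data);
}
```

Note that the exhaustiveness check only applies when no generic `[](auto const&)` handler is present, since a generic lambda accepts every alternative.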
cpp/src/request_tag.cpp
Outdated
| try { | ||
| auto tagSend = _delayedSubmission->_data.getTagSend(); | ||
| log(tagSend._buffer, tagSend._length, tagSend._tag, TagMaskFull); | ||
| } catch (const std::bad_variant_access& e) { | ||
| try { | ||
| auto tagReceive = _delayedSubmission->_data.getTagReceive(); | ||
| log(tagReceive._buffer, tagReceive._length, tagReceive._tag, tagReceive._tagMask); | ||
| } catch (const std::bad_variant_access& e) { | ||
| ucxx_error("Impossible to get transfer data."); | ||
| } | ||
| } |
There was a problem hiding this comment.
This pattern could then also be wrapped up in the same std::visit pattern.
To account for stronger typing, this change removes the `DelayedSubmission` class in favor of new request-specific data types. The types are then combined with `std::variant` and utilize `std::visit` to choose the type to work on and gather/process request-specific data.
cpp/src/request_am.cpp
Outdated
| "Receiving active messages must be handled by the worker's callback"); | ||
| } | ||
| std::visit(data::dispatch{ | ||
| [this](const data::AmSend& amSend) { |
There was a problem hiding this comment.
Here we must get data::AmSend by reference because we pass amSend._memoryType by reference to ucp_am_send_nbx, which may not complete synchronously. I'm not sure what the safer and cleaner way would be, but perhaps we should always visit objects by reference instead of by value as is done elsewhere (one notable exception is https://github.com/rapidsai/ucxx/pull/121/files#diff-637c52f56b253c39f497412ac96ffbb95ba303c20f6528f9f66e902601a1338dR21, where we need to modify the object) and make sure they're const where we don't need to modify them?
There was a problem hiding this comment.
This is because we need the header buffer to live for as long as the lifetime of the request returned by ucp_am_send_nbx I think?
If we set .flags = UCP_AM_SEND_FLAG_REPLY | UCP_AM_SEND_FLAG_COPY_HEADER then UCX will memcpy the header and it will only need to be live until the call to ucp_am_send_nbx completes.
That might be a cleaner way to handle things.
There was a problem hiding this comment.
Yes, thanks for pointing that out, I hadn't checked whether there was a flag to do that. In the future that might not be the case anymore, though. I chose ucs_memory_type_t to act as the header here for simplicity; ideally we would switch to a type that the user can customize, so that active messages are no longer limited to sending/receiving a single hardcoded buffer type and the user is free to transfer complex types too. One "user" that could benefit from that is multi-buffer transfers: currently they're only supported by tag, but I think we could make a "multi-buffer type" that we could send and receive via a single active message UCX call (I haven't thought too much about it, so complications may exist that make my idea not applicable).
There was a problem hiding this comment.
Adding UCP_AM_SEND_FLAG_COPY_HEADER in fc63a0e .
| std::visit(data::dispatch{ | ||
| [this, buffer](data::AmReceive& amReceive) { amReceive._buffer = buffer; }, | ||
| [](auto arg) { throw std::runtime_error("Unreachable"); }, | ||
| }, | ||
| _request->_requestData); |
There was a problem hiding this comment.
It might (?) be cleaner in this case to write:
if (std::holds_alternative<data::AmReceive>(_request->_requestData)) {
...
} else {
throw std::runtime_error("Unreachable!");
}
?
There was a problem hiding this comment.
No objections either way from me, but I'd prefer if we can keep things more or less consistent everywhere. See #121 (comment).
There was a problem hiding this comment.
Let's leave it as is then, the consistency angle makes sense.
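For comparison only (the thread settled on keeping `std::visit` for consistency), a sketch of the single-alternative check using `std::get_if`, which combines the test and the access. The struct layouts and `setReceiveBuffer` are hypothetical.

```cpp
#include <cassert>
#include <variant>

// Hypothetical, simplified payloads.
struct AmSend    { const void* buffer; };
struct AmReceive { void* buffer; };
using RequestData = std::variant<std::monostate, AmSend, AmReceive>;

// Sets the receive buffer; returns false when data holds anything else.
inline bool setReceiveBuffer(RequestData& data, void* buffer)
{
  // get_if returns nullptr unless AmReceive is the active alternative.
  if (auto* amReceive = std::get_if<AmReceive>(&data)) {
    amReceive->buffer = buffer;
    return true;
  }
  return false;
}
```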
wence-
left a comment
There was a problem hiding this comment.
Overall this looks great! Sorry that it is so much typing to adapt.
I have a few suggestions around cleaning up some of the places where we dispatch.
I think part of the verbosity is because we're now trying to glue together two slightly disparate typing "styles": inheritance vs. interfaces. However, I would like to play around a bit with the proposed new code to see if I can figure out a tidier way to do things going forward: I don't think that it makes sense to hold this PR up for that.
cpp/src/request_tag.cpp
Outdated
| return; | ||
| } | ||
| bool terminate = false; | ||
| std::visit(data::dispatch{ |
There was a problem hiding this comment.
suggestion, we're allowed to return from std::visit so:
bool terminate = std::visit(..., request_data);
if (terminate) return;
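A compilable sketch of returning a value from `std::visit` instead of writing to a captured flag. The `done` members and the meaning of `shouldTerminate` are assumptions made for illustration; `dispatch` is redefined locally so the block stands alone.

```cpp
#include <cassert>
#include <variant>

// Hypothetical payloads carrying a completion flag.
struct TagSend    { bool done; };
struct TagReceive { bool done; };
using RequestData = std::variant<std::monostate, TagSend, TagReceive>;

template <class... Ts> struct dispatch : Ts... { using Ts::operator()...; };
template <class... Ts> dispatch(Ts...) -> dispatch<Ts...>;

// Every handler returns bool, so std::visit itself returns bool.
inline bool shouldTerminate(const RequestData& data)
{
  return std::visit(dispatch{
    [](const TagSend& s) { return s.done; },
    [](const TagReceive& r) { return r.done; },
    [](std::monostate) { return true; }  // assumption: nothing to submit
  }, data);
}
```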
cpp/src/request_tag_multi.cpp
Outdated
| }, | ||
| [](auto arg) { | ||
| throw std::runtime_error("Unreachable"); | ||
| return std::shared_ptr<RequestTagMulti>(nullptr); |
There was a problem hiding this comment.
question: I guess this is so return type deduction works. Is it cleaner to write:
[](auto arg) -> decltype(req) { throw std::runtime_error("Unreachable!"); }
?
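A sketch of the trailing-return-type variant, using hypothetical `Send`, `Receive`, and `Request` types. The throwing handler declares its return type explicitly, so no dummy `return` statement is needed after the `throw`.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <variant>

// Hypothetical types for illustration.
struct Send {};
struct Receive {};
struct Request { bool isSend; };
using RequestData = std::variant<std::monostate, Send, Receive>;

template <class... Ts> struct dispatch : Ts... { using Ts::operator()...; };
template <class... Ts> dispatch(Ts...) -> dispatch<Ts...>;

inline std::shared_ptr<Request> makeRequest(const RequestData& data)
{
  using Result = std::shared_ptr<Request>;
  return std::visit(dispatch{
    [](Send) { return std::make_shared<Request>(Request{true}); },
    [](Receive) { return std::make_shared<Request>(Request{false}); },
    // Without "-> Result" this lambda would deduce void, and std::visit
    // requires every handler to return the same type.
    [](auto) -> Result { throw std::runtime_error("Unreachable"); }
  }, data);
}
```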
cpp/src/request_tag_multi.cpp
Outdated
| std::visit( | ||
| data::dispatch{ | ||
| [&tagPair](data::TagMultiReceive tagMultiReceive) { | ||
| tagPair = std::make_pair(tagMultiReceive._tag, tagMultiReceive._tagMask); | ||
| }, | ||
| [&methodName](auto arg) { | ||
| throw std::runtime_error(methodName + "() can only be called by a receive request."); | ||
| }, | ||
| }, | ||
| requestData); | ||
|
|
||
| return tagPair; |
There was a problem hiding this comment.
suggestion
Would a try/catch here be cleaner?
try {
auto data = std::get<data::TagMultiReceive>(requestData);
return {data._tag, data._tagMask};
} catch (const std::bad_variant_access&) {
throw std::runtime_error(...);
}
WDYT?
There was a problem hiding this comment.
Perhaps slightly, but we would then break consistency, xref #121 (comment) .
cpp/src/request_tag_multi.cpp
Outdated
| std::pair<Tag, TagMask> tagPair; | ||
| std::visit(data::dispatch{ | ||
| [&tagPair](data::TagMultiSend tagMultiSend) { | ||
| tagPair = std::make_pair(tagMultiSend._tag, TagMaskFull); | ||
| }, | ||
| [&tagPair](data::TagMultiReceive tagMultiReceive) { | ||
| tagPair = std::make_pair(tagMultiReceive._tag, tagMultiReceive._tagMask); | ||
| }, | ||
| [](auto arg) { throw std::runtime_error("Unreachable"); }, | ||
| }, | ||
| _requestData); |
There was a problem hiding this comment.
suggestion:
std::pair<Tag, TagMask> tagPair = std::visit(data::dispatch {
[](data::TagMultiSend tag) -> decltype(tagPair) { return {tag._tag, TagMaskFull}; },
[](data::TagMultiReceive tag) -> decltype(tagPair) { return {tag._tag, tag._tagMask}; },
[](auto) -> decltype(tagPair) { throw std::runtime_error("Unreachable"); },
}, _requestData);
Uses returns and also the automatic initializer-list construction of std::pair.
cpp/src/request_tag_multi.cpp
Outdated
| tagMultiSend._tag, | ||
| _isFilled); | ||
| }, | ||
| [](auto arg) { throw std::runtime_error("send() can only be called by a sendrequest."); }, |
There was a problem hiding this comment.
suggestion (throughout), where we don't need the argument name, just write (auto)
cpp/src/request_tag.cpp
Outdated
| const data::RequestData requestData, | ||
| const std::string operationName, |
There was a problem hiding this comment.
Can we construct the operation name from the request info:
: Request(endpointOrWorker, request, std::visit(dispatch { [](TagSend) { return "tagSend"; }, [](TagRecv) { return "tagRecv"; } }, request), ...) {}
| std::shared_ptr<RequestTag> req = | ||
| std::visit(data::dispatch{ | ||
| [&endpointOrWorker, &enablePythonFuture, &callbackFunction, &callbackData]( | ||
| data::TagSend tagSend) { | ||
| return std::shared_ptr<RequestTag>(new RequestTag(endpointOrWorker, | ||
| tagSend, | ||
| "tagSend", | ||
| enablePythonFuture, | ||
| callbackFunction, | ||
| callbackData)); | ||
| }, | ||
| [&endpointOrWorker, &enablePythonFuture, &callbackFunction, &callbackData]( | ||
| data::TagReceive tagReceive) { | ||
| return std::shared_ptr<RequestTag>(new RequestTag(endpointOrWorker, | ||
| tagReceive, | ||
| "tagRecv", | ||
| enablePythonFuture, | ||
| callbackFunction, | ||
| callbackData)); | ||
| }, | ||
| [](auto arg) { | ||
| throw std::runtime_error("Unreachable"); | ||
| return std::shared_ptr<RequestTag>(nullptr); | ||
| }, | ||
| }, | ||
| requestData); |
There was a problem hiding this comment.
And then this code wouldn't need to exist, because the only part that is different between the cases is "tagRecv" vs "tagSend".
We could just write:
auto req = std::shared_ptr<RequestTag>(new RequestTag(endpointOrWorker, requestData, enablePython...));
WDYT?
There was a problem hiding this comment.
Maybe I'm missing something, but how are you suggesting we set the "tagRecv"/"tagSend" value here?
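One possible reading of the suggestion, sketched with hypothetical simplified types: the name is computed from the variant in the constructor's member initializer list, so the caller never passes it. This is an illustration, not the code that was merged.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <variant>

// Hypothetical, simplified payloads.
struct TagSend {};
struct TagReceive {};
using RequestData = std::variant<TagSend, TagReceive>;

template <class... Ts> struct dispatch : Ts... { using Ts::operator()...; };
template <class... Ts> dispatch(Ts...) -> dispatch<Ts...>;

// Maps the active alternative to its operation name.
inline std::string operationName(const RequestData& data)
{
  return std::visit(dispatch{
    [](const TagSend&) { return std::string{"tagSend"}; },
    [](const TagReceive&) { return std::string{"tagRecv"}; }
  }, data);
}

// Hypothetical request type: `data` is declared (and initialized) before
// `name`, so `name` can be derived from it.
struct RequestTag {
  RequestData data;
  std::string name;
  explicit RequestTag(RequestData d) : data{std::move(d)}, name{operationName(data)} {}
};
```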
Additionally set default values.
cpp/include/ucxx/request_data.h
Outdated
| template <class... Ts> | ||
| dispatch(Ts...) -> dispatch<Ts...>; | ||
| template <class... Ts> | ||
| dispatch(Ts&...) -> dispatch<Ts...>; |
There was a problem hiding this comment.
I admit to not knowing a lot of template deduction hint magic. Did you need Ts... and Ts&... to handle overloads that had a combination of "by reference" and "by value" params?
There was a problem hiding this comment.
You're right, that shouldn't be needed. I'm sure at one point I had build errors but that was probably me having done something wrong. Removed now in 05fcb93 and it still compiles locally.
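A small sketch supporting this: a single deduction guide deduces the lambda closure types, not their parameter types, so handlers taking `const&` and non-const `&` parameters can be mixed freely. The payload structs and `touch` are hypothetical.

```cpp
#include <cassert>
#include <variant>

// Hypothetical, simplified payloads.
struct AmSend    { int memoryType; };
struct AmReceive { void* buffer; };
using RequestData = std::variant<AmSend, AmReceive>;

template <class... Ts> struct dispatch : Ts... { using Ts::operator()...; };
// One guide is enough; it never inspects the handlers' parameter lists.
template <class... Ts> dispatch(Ts...) -> dispatch<Ts...>;

// Mixes a read-only handler with one that mutates the alternative in place.
inline int touch(RequestData& data, void* buffer)
{
  return std::visit(dispatch{
    [](const AmSend& s) { return s.memoryType; },
    [buffer](AmReceive& r) { r.buffer = buffer; return -1; }
  }, data);
}
```

In C++20 the guide can be dropped entirely, since class template argument deduction also works for aggregates.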
cpp/src/delayed_submission.cpp
Outdated
| @@ -1,4 +1,4 @@ | |||
| /** | |||
| /**: | |||
wence-
left a comment
There was a problem hiding this comment.
Thanks Peter, looks good!
Thanks @wence-!
/merge
Add new class to organize operation-specific data, making members optional depending on what a specific transfer operation requires and checking their validity at construction time.
Expose tag mask to tag and multi-buffer tag APIs.