Skip to content

Commit 3558c0b

Browse files
committed
Use std::variant in constructors
1 parent d4d44bb commit 3558c0b

11 files changed

Lines changed: 91 additions & 76 deletions

cpp/include/ucxx/constructors.h

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,28 @@ std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
5656
const bool enableFuture);
5757

5858
// Transfers
59-
std::shared_ptr<RequestAm> createRequestAm(std::shared_ptr<Endpoint> endpoint,
60-
const data::RequestData requestData,
61-
const bool enablePythonFuture,
62-
RequestCallbackUserFunction callbackFunction,
63-
RequestCallbackUserData callbackData);
64-
65-
std::shared_ptr<RequestStream> createRequestStream(std::shared_ptr<Endpoint> endpoint,
66-
const data::RequestData requestData,
67-
const bool enablePythonFuture);
68-
69-
std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
70-
const data::RequestData requestData,
71-
const bool enablePythonFuture,
72-
RequestCallbackUserFunction callbackFunction,
73-
RequestCallbackUserData callbackData);
74-
75-
std::shared_ptr<RequestTagMulti> createRequestTagMulti(std::shared_ptr<Endpoint> endpoint,
76-
const data::RequestData requestData,
77-
const bool enablePythonFuture);
59+
std::shared_ptr<RequestAm> createRequestAm(
60+
std::shared_ptr<Endpoint> endpoint,
61+
const std::variant<data::AmSend, data::AmReceive> requestData,
62+
const bool enablePythonFuture,
63+
RequestCallbackUserFunction callbackFunction,
64+
RequestCallbackUserData callbackData);
65+
66+
std::shared_ptr<RequestStream> createRequestStream(
67+
std::shared_ptr<Endpoint> endpoint,
68+
const std::variant<data::StreamSend, data::StreamReceive> requestData,
69+
const bool enablePythonFuture);
70+
71+
std::shared_ptr<RequestTag> createRequestTag(
72+
std::shared_ptr<Component> endpointOrWorker,
73+
const std::variant<data::TagSend, data::TagReceive> requestData,
74+
const bool enablePythonFuture,
75+
RequestCallbackUserFunction callbackFunction,
76+
RequestCallbackUserData callbackData);
77+
78+
std::shared_ptr<RequestTagMulti> createRequestTagMulti(
79+
std::shared_ptr<Endpoint> endpoint,
80+
const std::variant<data::TagMultiSend, data::TagMultiReceive> requestData,
81+
const bool enablePythonFuture);
7882

7983
} // namespace ucxx

cpp/include/ucxx/request_am.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class RequestAm : public Request {
5252
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
5353
*/
5454
RequestAm(std::shared_ptr<Component> endpointOrWorker,
55-
const data::RequestData requestData,
55+
const std::variant<data::AmSend, data::AmReceive> requestData,
5656
const std::string operationName,
5757
const bool enablePythonFuture = false,
5858
RequestCallbackUserFunction callbackFunction = nullptr,
@@ -83,11 +83,12 @@ class RequestAm : public Request {
8383
*
8484
* @returns The `shared_ptr<ucxx::RequestAm>` object
8585
*/
86-
friend std::shared_ptr<RequestAm> createRequestAm(std::shared_ptr<Endpoint> endpoint,
87-
const data::RequestData requestData,
88-
const bool enablePythonFuture,
89-
RequestCallbackUserFunction callbackFunction,
90-
RequestCallbackUserData callbackData);
86+
friend std::shared_ptr<RequestAm> createRequestAm(
87+
std::shared_ptr<Endpoint> endpoint,
88+
const std::variant<data::AmSend, data::AmReceive> requestData,
89+
const bool enablePythonFuture,
90+
RequestCallbackUserFunction callbackFunction,
91+
RequestCallbackUserData callbackData);
9192

9293
virtual void populateDelayedSubmission();
9394

cpp/include/ucxx/request_data.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,12 @@ dispatch(Ts...) -> dispatch<Ts...>;
197197
template <class... Ts>
198198
dispatch(Ts&...) -> dispatch<Ts...>;
199199

200+
template <class T>
201+
RequestData getRequestData(T t)
202+
{
203+
return std::visit([](auto arg) -> RequestData { return arg; }, t);
204+
}
205+
200206
} // namespace data
201207

202208
} // namespace ucxx

cpp/include/ucxx/request_stream.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class RequestStream : public Request {
3939
* subsequently notified.
4040
*/
4141
RequestStream(std::shared_ptr<Endpoint> endpoint,
42-
const data::RequestData requestData,
42+
const std::variant<data::StreamSend, data::StreamReceive> requestData,
4343
const std::string operationName,
4444
const bool enablePythonFuture = false);
4545

@@ -61,9 +61,10 @@ class RequestStream : public Request {
6161
*
6262
* @returns The `shared_ptr<ucxx::RequestStream>` object
6363
*/
64-
friend std::shared_ptr<RequestStream> createRequestStream(std::shared_ptr<Endpoint> endpoint,
65-
const data::RequestData requestData,
66-
const bool enablePythonFuture);
64+
friend std::shared_ptr<RequestStream> createRequestStream(
65+
std::shared_ptr<Endpoint> endpoint,
66+
const std::variant<data::StreamSend, data::StreamReceive> requestData,
67+
const bool enablePythonFuture);
6768

6869
virtual void populateDelayedSubmission();
6970

cpp/include/ucxx/request_tag.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class RequestTag : public Request {
4747
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
4848
*/
4949
RequestTag(std::shared_ptr<Component> endpointOrWorker,
50-
const data::RequestData requestData,
50+
const std::variant<data::TagSend, data::TagReceive> requestData,
5151
const std::string operationName,
5252
const bool enablePythonFuture = false,
5353
RequestCallbackUserFunction callbackFunction = nullptr,
@@ -78,11 +78,12 @@ class RequestTag : public Request {
7878
*
7979
* @returns The `shared_ptr<ucxx::RequestTag>` object
8080
*/
81-
friend std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
82-
const data::RequestData requestData,
83-
const bool enablePythonFuture,
84-
RequestCallbackUserFunction callbackFunction,
85-
RequestCallbackUserData callbackData);
81+
friend std::shared_ptr<RequestTag> createRequestTag(
82+
std::shared_ptr<Component> endpointOrWorker,
83+
const std::variant<data::TagSend, data::TagReceive> requestData,
84+
const bool enablePythonFuture,
85+
RequestCallbackUserFunction callbackFunction,
86+
RequestCallbackUserData callbackData);
8687

8788
virtual void populateDelayedSubmission();
8889

cpp/include/ucxx/request_tag_multi.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class RequestTagMulti : public Request {
7676
* subsequently notified.
7777
*/
7878
RequestTagMulti(std::shared_ptr<Endpoint> endpoint,
79-
const data::RequestData requestData,
79+
const std::variant<data::TagMultiSend, data::TagMultiReceive> requestData,
8080
const std::string operationName,
8181
const bool enablePythonFuture);
8282

@@ -152,9 +152,10 @@ class RequestTagMulti : public Request {
152152
*
153153
* @returns Request to be subsequently checked for the completion and its state.
154154
*/
155-
friend std::shared_ptr<RequestTagMulti> createRequestTagMulti(std::shared_ptr<Endpoint> endpoint,
156-
const data::RequestData requestData,
157-
const bool enablePythonFuture);
155+
friend std::shared_ptr<RequestTagMulti> createRequestTagMulti(
156+
std::shared_ptr<Endpoint> endpoint,
157+
const std::variant<data::TagMultiSend, data::TagMultiReceive> requestData,
158+
const bool enablePythonFuture);
158159

159160
/**
160161
* @brief `ucxx::RequestTagMulti` destructor.

cpp/include/ucxx/worker.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,12 @@ class Worker : public Component {
5555
std::shared_ptr<DelayedSubmissionCollection> _delayedSubmissionCollection{
5656
nullptr}; ///< Collection of enqueued delayed submissions
5757

58-
friend std::shared_ptr<RequestAm> createRequestAm(std::shared_ptr<Endpoint> endpoint,
59-
const data::RequestData requestData,
60-
const bool enablePythonFuture,
61-
RequestCallbackUserFunction callbackFunction,
62-
RequestCallbackUserData callbackData);
58+
friend std::shared_ptr<RequestAm> createRequestAm(
59+
std::shared_ptr<Endpoint> endpoint,
60+
const std::variant<data::AmSend, data::AmReceive> requestData,
61+
const bool enablePythonFuture,
62+
RequestCallbackUserFunction callbackFunction,
63+
RequestCallbackUserData callbackData);
6364

6465
protected:
6566
bool _enableFuture{

cpp/src/request_am.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
namespace ucxx {
1818

19-
std::shared_ptr<RequestAm> createRequestAm(std::shared_ptr<Endpoint> endpoint,
20-
const data::RequestData requestData,
21-
const bool enablePythonFuture = false,
22-
RequestCallbackUserFunction callbackFunction = nullptr,
23-
RequestCallbackUserData callbackData = nullptr)
19+
std::shared_ptr<RequestAm> createRequestAm(
20+
std::shared_ptr<Endpoint> endpoint,
21+
const std::variant<data::AmSend, data::AmReceive> requestData,
22+
const bool enablePythonFuture = false,
23+
RequestCallbackUserFunction callbackFunction = nullptr,
24+
RequestCallbackUserData callbackData = nullptr)
2425
{
2526
std::shared_ptr<RequestAm> req = std::visit(
2627
data::dispatch{
@@ -49,28 +50,26 @@ std::shared_ptr<RequestAm> createRequestAm(std::shared_ptr<Endpoint> endpoint,
4950
};
5051
return worker->getAmRecv(endpoint->getHandle(), createRequest);
5152
},
52-
[](auto) -> decltype(req) { throw std::runtime_error("Unreachable"); },
5353
},
5454
requestData);
5555

5656
return req;
5757
}
5858

5959
RequestAm::RequestAm(std::shared_ptr<Component> endpointOrWorker,
60-
const data::RequestData requestData,
60+
const std::variant<data::AmSend, data::AmReceive> requestData,
6161
const std::string operationName,
6262
const bool enablePythonFuture,
6363
RequestCallbackUserFunction callbackFunction,
6464
RequestCallbackUserData callbackData)
65-
: Request(endpointOrWorker, requestData, operationName, enablePythonFuture)
65+
: Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture)
6666
{
6767
std::visit(data::dispatch{
6868
[this](data::AmSend amSend) {
6969
if (_endpoint == nullptr)
7070
throw ucxx::Error("An endpoint is required to send active messages");
7171
},
7272
[](data::AmReceive amReceive) {},
73-
[](auto) { throw std::runtime_error("Unreachable"); },
7473
},
7574
requestData);
7675

cpp/src/request_stream.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
namespace ucxx {
1414
RequestStream::RequestStream(std::shared_ptr<Endpoint> endpoint,
15-
const data::RequestData requestData,
15+
const std::variant<data::StreamSend, data::StreamReceive> requestData,
1616
const std::string operationName,
1717
const bool enablePythonFuture)
18-
: Request(endpoint, requestData, operationName, enablePythonFuture)
18+
: Request(endpoint, data::getRequestData(requestData), operationName, enablePythonFuture)
1919
{
2020
std::visit(data::dispatch{
2121
[this](data::StreamSend streamSend) {
@@ -31,9 +31,10 @@ RequestStream::RequestStream(std::shared_ptr<Endpoint> endpoint,
3131
requestData);
3232
}
3333

34-
std::shared_ptr<RequestStream> createRequestStream(std::shared_ptr<Endpoint> endpoint,
35-
const data::RequestData requestData,
36-
const bool enablePythonFuture = false)
34+
std::shared_ptr<RequestStream> createRequestStream(
35+
std::shared_ptr<Endpoint> endpoint,
36+
const std::variant<data::StreamSend, data::StreamReceive> requestData,
37+
const bool enablePythonFuture = false)
3738
{
3839
std::shared_ptr<RequestStream> req =
3940
std::visit(data::dispatch{

cpp/src/request_tag.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414

1515
namespace ucxx {
1616

17-
std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
18-
const data::RequestData requestData,
19-
const bool enablePythonFuture = false,
20-
RequestCallbackUserFunction callbackFunction = nullptr,
21-
RequestCallbackUserData callbackData = nullptr)
17+
std::shared_ptr<RequestTag> createRequestTag(
18+
std::shared_ptr<Component> endpointOrWorker,
19+
const std::variant<data::TagSend, data::TagReceive> requestData,
20+
const bool enablePythonFuture = false,
21+
RequestCallbackUserFunction callbackFunction = nullptr,
22+
RequestCallbackUserData callbackData = nullptr)
2223
{
2324
std::shared_ptr<RequestTag> req =
2425
std::visit(data::dispatch{
@@ -40,7 +41,6 @@ std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpoint
4041
callbackFunction,
4142
callbackData));
4243
},
43-
[](auto) -> decltype(req) { throw std::runtime_error("Unreachable"); },
4444
},
4545
requestData);
4646

@@ -54,20 +54,19 @@ std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpoint
5454
}
5555

5656
RequestTag::RequestTag(std::shared_ptr<Component> endpointOrWorker,
57-
const data::RequestData requestData,
57+
const std::variant<data::TagSend, data::TagReceive> requestData,
5858
const std::string operationName,
5959
const bool enablePythonFuture,
6060
RequestCallbackUserFunction callbackFunction,
6161
RequestCallbackUserData callbackData)
62-
: Request(endpointOrWorker, requestData, operationName, enablePythonFuture)
62+
: Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture)
6363
{
6464
std::visit(data::dispatch{
6565
[this](data::TagSend tagSend) {
6666
if (_endpoint == nullptr)
6767
throw ucxx::Error("An endpoint is required to send tag messages");
6868
},
6969
[](data::TagReceive tagReceive) {},
70-
[](auto) { throw std::runtime_error("Unreachable"); },
7170
},
7271
requestData);
7372

0 commit comments

Comments
 (0)