@@ -181,8 +181,20 @@ class RingTransmit : public Proto<T, SubGroupSize> {
181181 retry |= recvMessages (messages, localScatterSink[peer][slot][wireId], flag);
182182 } while (sycl::any_of_group (sycl::ext::oneapi::this_work_item::get_sub_group (), retry));
183183
184- shuffleData (v);
185- accumMessages (v, messages);
184+ if constexpr (sizeof (T) > sizeof (flag)) {
185+ // restore data, accumulate, shuffle
186+ restoreData (messages);
187+ accumMessages (v, messages);
188+ shuffleData (v);
189+ }
190+ else {
191+ // sizeof (T) <= sizeof (flag): data size is no larger than the flag,
192+ // no overlap between datatype and flag,
193+ // so we can accumulate the shuffled data:
194+ // fewer ops to perform, faster
195+ shuffleData (v);
196+ accumMessages (v, messages);
197+ }
186198 insertFlags (v, flag);
187199
188200 sendMessages (scatterSink[peer][slot][wireId], v);
@@ -208,8 +220,20 @@ class RingTransmit : public Proto<T, SubGroupSize> {
208220 retry |= recvMessages (messages, localScatterSink[peer][slot][wireId], flag);
209221 } while (sycl::any_of_group (sycl::ext::oneapi::this_work_item::get_sub_group (), retry));
210222
211- shuffleData (v);
212- accumMessages (v, messages);
223+ if constexpr (sizeof (T) > sizeof (flag)) {
224+ // restore data, accumulate, shuffle
225+ restoreData (messages);
226+ accumMessages (v, messages);
227+ shuffleData (v);
228+ }
229+ else {
230+ // sizeof (T) <= sizeof (flag): data size is no larger than the flag,
231+ // no overlap between datatype and flag,
232+ // so we can accumulate the shuffled data:
233+ // fewer ops to perform, faster
234+ shuffleData (v);
235+ accumMessages (v, messages);
236+ }
213237
214238 insertFlags (v, flag);
215239 sendMessages (gatherSink[peer][slot][wireId], v);
@@ -294,8 +318,20 @@ class RingTransmit : public Proto<T, SubGroupSize> {
294318 retry |= recvMessages (messages, localScatterSink[peer][slot][wireId], flag);
295319 } while (sycl::any_of_group (sycl::ext::oneapi::this_work_item::get_sub_group (), retry));
296320
297- shuffleData (v);
298- accumMessages (v, messages);
321+ if constexpr (sizeof (T) > sizeof (flag)) {
322+ // restore data, accumulate, shuffle
323+ restoreData (messages);
324+ accumMessages (v, messages);
325+ shuffleData (v);
326+ }
327+ else {
328+ // sizeof (T) <= sizeof (flag): data size is no larger than the flag,
329+ // no overlap between datatype and flag,
330+ // so we can accumulate the shuffled data:
331+ // fewer ops to perform, faster
332+ shuffleData (v);
333+ accumMessages (v, messages);
334+ }
299335
300336 insertFlags (v, flag);
301337 restoreData (v);
0 commit comments