@@ -295,29 +295,71 @@ struct sub_group {
295295 PI_INVALID_DEVICE);
296296#endif
297297 }
298-
298+ #ifdef __SYCL_DEVICE_ONLY__
299+ #ifdef __NVPTX__
299300 template <int N, typename T, access::address_space Space>
300301 sycl::detail::enable_if_t <
301- sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
302- N != 1 ,
302+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
303303 vec<T, N>>
304304 load (const multi_ptr<T, Space> src) const {
305- #ifdef __SYCL_DEVICE_ONLY__
306- #ifdef __NVPTX__
307305 vec<T, N> res;
308306 for (int i = 0 ; i < N; ++i) {
309307 res[i] = *(src.get () + i * get_max_local_range ()[0 ] + get_local_id ()[0 ]);
310308 }
311309 return res;
312- #else
310+ }
311+ #else // __NVPTX__
312+ template <int N, typename T, access::address_space Space>
313+ sycl::detail::enable_if_t <
314+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
315+ N != 1 && N != 3 && N != 16 ,
316+ vec<T, N>>
317+ load (const multi_ptr<T, Space> src) const {
313318 return sycl::detail::sub_group::load<N, T>(src);
314- #endif // __NVPTX__
315- #else
319+ }
320+
321+ template <int N, typename T, access::address_space Space>
322+ sycl::detail::enable_if_t <
323+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
324+ N == 16 ,
325+ vec<T, 16 >>
326+ load (const multi_ptr<T, Space> src) const {
327+ return {sycl::detail::sub_group::load<8 , T>(src),
328+ sycl::detail::sub_group::load<8 , T>(src +
329+ 8 * get_max_local_range ()[0 ])};
330+ }
331+
332+ template <int N, typename T, access::address_space Space>
333+ sycl::detail::enable_if_t <
334+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
335+ N == 3 ,
336+ vec<T, 3 >>
337+ load (const multi_ptr<T, Space> src) const {
338+ return {
339+ sycl::detail::sub_group::load<1 , T>(src),
340+ sycl::detail::sub_group::load<2 , T>(src + get_max_local_range ()[0 ])};
341+ }
342+
343+ template <int N, typename T, access::address_space Space>
344+ sycl::detail::enable_if_t <
345+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
346+ N == 1 ,
347+ vec<T, 1 >>
348+ load (const multi_ptr<T, Space> src) const {
349+ return sycl::detail::sub_group::load (src);
350+ }
351+ #endif // ___NVPTX___
352+ #else // __SYCL_DEVICE_ONLY__
353+ template <int N, typename T, access::address_space Space>
354+ sycl::detail::enable_if_t <
355+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
356+ vec<T, N>>
357+ load (const multi_ptr<T, Space> src) const {
316358 (void )src;
317359 throw runtime_error (" Sub-groups are not supported on host device." ,
318360 PI_INVALID_DEVICE);
319- #endif
320361 }
362+ #endif // __SYCL_DEVICE_ONLY__
321363
322364 template <int N, typename T, access::address_space Space>
323365 sycl::detail::enable_if_t <
@@ -337,25 +379,6 @@ struct sub_group {
337379#endif
338380 }
339381
340- template <int N, typename T, access::address_space Space>
341- sycl::detail::enable_if_t <
342- sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
343- N == 1 ,
344- vec<T, 1 >>
345- load (const multi_ptr<T, Space> src) const {
346- #ifdef __SYCL_DEVICE_ONLY__
347- #ifdef __NVPTX__
348- return src.get ()[get_local_id ()[0 ]];
349- #else
350- return sycl::detail::sub_group::load (src);
351- #endif // __NVPTX__
352- #else
353- (void )src;
354- throw runtime_error (" Sub-groups are not supported on host device." ,
355- PI_INVALID_DEVICE);
356- #endif
357- }
358-
359382#ifdef __SYCL_DEVICE_ONLY__
360383 // Method for decorated pointer
361384 template <typename T>
@@ -437,45 +460,63 @@ struct sub_group {
437460#endif
438461 }
439462
463+ #ifdef __SYCL_DEVICE_ONLY__
464+ #ifdef __NVPTX__
465+ template <int N, typename T, access::address_space Space>
466+ sycl::detail::enable_if_t <
467+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value>
468+ store (multi_ptr<T, Space> dst, const vec<T, N> &x) const {
469+ for (int i = 0 ; i < N; ++i) {
470+ *(dst.get () + i * get_max_local_range ()[0 ] + get_local_id ()[0 ]) = x[i];
471+ }
472+ }
473+ #else // __NVPTX__
474+ template <int N, typename T, access::address_space Space>
475+ sycl::detail::enable_if_t <
476+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
477+ N != 1 && N != 3 && N != 16 >
478+ store (multi_ptr<T, Space> dst, const vec<T, N> &x) const {
479+ sycl::detail::sub_group::store (dst, x);
480+ }
481+
440482 template <int N, typename T, access::address_space Space>
441483 sycl::detail::enable_if_t <
442484 sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
443485 N == 1 >
444486 store (multi_ptr<T, Space> dst, const vec<T, 1 > &x) const {
445- #ifdef __SYCL_DEVICE_ONLY__
446- #ifdef __NVPTX__
447- dst.get ()[get_local_id ()[0 ]] = x[0 ];
448- #else
449- store<T, Space>(dst, x);
450- #endif // __NVPTX__
451- #else
452- (void )dst;
453- (void )x;
454- throw runtime_error (" Sub-groups are not supported on host device." ,
455- PI_INVALID_DEVICE);
456- #endif
487+ sycl::detail::sub_group::store (dst, x);
457488 }
458489
459490 template <int N, typename T, access::address_space Space>
460491 sycl::detail::enable_if_t <
461492 sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
462- N != 1 >
463- store (multi_ptr<T, Space> dst, const vec<T, N> &x) const {
464- #ifdef __SYCL_DEVICE_ONLY__
465- #ifdef __NVPTX__
466- for (int i = 0 ; i < N; ++i) {
467- *(dst.get () + i * get_max_local_range ()[0 ] + get_local_id ()[0 ]) = x[i];
468- }
469- #else
470- sycl::detail::sub_group::store (dst, x);
493+ N == 3 >
494+ store (multi_ptr<T, Space> dst, const vec<T, 3 > &x) const {
495+ store<1 , T, Space>(dst, x.s0 ());
496+ store<2 , T, Space>(dst + get_max_local_range ()[0 ], {x.s1 (), x.s2 ()});
497+ }
498+
499+ template <int N, typename T, access::address_space Space>
500+ sycl::detail::enable_if_t <
501+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
502+ N == 16 >
503+ store (multi_ptr<T, Space> dst, const vec<T, 16 > &x) const {
504+ store<8 , T, Space>(dst, x.lo ());
505+ store<8 , T, Space>(dst + 8 * get_max_local_range ()[0 ], x.hi ());
506+ }
507+
471508#endif // __NVPTX__
472- #else
509+ #else // __SYCL_DEVICE_ONLY__
510+ template <int N, typename T, access::address_space Space>
511+ sycl::detail::enable_if_t <
512+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value>
513+ store (multi_ptr<T, Space> dst, const vec<T, N> &x) const {
473514 (void )dst;
474515 (void )x;
475516 throw runtime_error (" Sub-groups are not supported on host device." ,
476517 PI_INVALID_DEVICE);
477- #endif
478518 }
519+ #endif // __SYCL_DEVICE_ONLY__
479520
480521 template <int N, typename T, access::address_space Space>
481522 sycl::detail::enable_if_t <
0 commit comments