Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/SOURCES
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ workitem/get_global_size.cl
workitem/get_group_id.cl
workitem/get_local_id.cl
workitem/get_local_size.cl
workitem/get_max_sub_group_size.cl
workitem/get_num_groups.cl
workitem/get_num_sub_groups.cl
workitem/get_sub_group_id.cl
workitem/get_sub_group_local_id.cl
workitem/get_sub_group_size.cl
images/image_helpers.ll
images/image.cl
15 changes: 15 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_max_sub_group_size.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>

// Returns the value used for SubgroupMaxSize on this target: a fixed 32.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize() {
  // FIXME: warpsize is defined by NVVM IR but doesn't compile if used here,
  // so the value is hard-coded rather than obtained from
  // __nvvm_read_ptx_sreg_warpsize().
  const uint max_sub_group_size = 32;
  return max_sub_group_size;
}
21 changes: 21 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_num_sub_groups.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize();

// Returns how many sub-groups are needed to cover the work-group: the product
// of the three work-group dimensions divided by __spirv_SubgroupMaxSize(),
// rounded up.
_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups() {
  // sreg.nwarpid returns number of warp identifiers, not number of warps,
  // so the count is derived from the work-group dimensions instead.
  // see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
  size_t dim_x = __spirv_WorkgroupSize_x();
  size_t dim_y = __spirv_WorkgroupSize_y();
  size_t dim_z = __spirv_WorkgroupSize_z();
  uint max_size = __spirv_SubgroupMaxSize();
  uint work_items = dim_z * dim_y * dim_x;
  // Ceiling division: a partially filled trailing sub-group still counts.
  uint rounded_up = work_items + max_size - 1;
  return rounded_up / max_size;
}
23 changes: 23 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_id.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize();

// Returns the index of the calling work-item's sub-group within its
// work-group: the work-item's linearized local id divided by the value of
// __spirv_SubgroupMaxSize().
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId() {
  // sreg.warpid is volatile and doesn't represent virtual warp index
  // see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
  size_t id_x = __spirv_LocalInvocationId_x();
  size_t id_y = __spirv_LocalInvocationId_y();
  size_t id_z = __spirv_LocalInvocationId_z();
  // Only the x and y extents are needed to linearize a 3-D local id; the z
  // extent does not appear in the formula, so it is not queried.
  size_t size_x = __spirv_WorkgroupSize_x();
  size_t size_y = __spirv_WorkgroupSize_y();
  uint sg_size = __spirv_SubgroupMaxSize();
  // x varies fastest, then y, then z.
  size_t linear_id = id_z * size_y * size_x + id_y * size_x + id_x;
  return linear_id / sg_size;
}
13 changes: 13 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_local_id.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>

_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupLocalInvocationId() {
return __nvvm_read_ptx_sreg_laneid();
}
26 changes: 26 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_size.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId();
_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups();
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize();

// Returns the size of the calling work-item's sub-group. Every sub-group
// except the last one reports __spirv_SubgroupMaxSize(); the last sub-group
// reports whatever remains of the work-group after the preceding ones.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupSize() {
  if (__spirv_SubgroupId() == __spirv_NumSubgroups() - 1) {
    // Last sub-group: total work-items minus those covered by the
    // preceding, maximum-sized sub-groups.
    size_t dim_x = __spirv_WorkgroupSize_x();
    size_t dim_y = __spirv_WorkgroupSize_y();
    size_t dim_z = __spirv_WorkgroupSize_z();
    uint work_items = dim_z * dim_y * dim_x;
    uint full_groups = __spirv_NumSubgroups() - 1;
    uint covered = __spirv_SubgroupMaxSize() * full_groups;
    return work_items - covered;
  }
  return __spirv_SubgroupMaxSize();
}
35 changes: 28 additions & 7 deletions sycl/include/CL/__spirv/spirv_vars.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ SYCL_EXTERNAL size_t __spirv_LocalInvocationId_x();
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_y();
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_z();

SYCL_EXTERNAL uint32_t __spirv_SubgroupSize();
SYCL_EXTERNAL uint32_t __spirv_SubgroupMaxSize();
SYCL_EXTERNAL uint32_t __spirv_NumSubgroups();
SYCL_EXTERNAL uint32_t __spirv_SubgroupId();
SYCL_EXTERNAL uint32_t __spirv_SubgroupLocalInvocationId();

#else // __SYCL_NVPTX__

typedef size_t size_t_vec __attribute__((ext_vector_type(3)));
Expand All @@ -56,6 +62,12 @@ __SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInLocalInvocationId;
__SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInWorkgroupId;
__SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInGlobalOffset;

__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupSize;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupMaxSize;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId;

SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_x() {
return __spirv_BuiltInGlobalInvocationId.x;
}
Expand Down Expand Up @@ -126,14 +138,23 @@ SYCL_EXTERNAL inline size_t __spirv_LocalInvocationId_z() {
return __spirv_BuiltInLocalInvocationId.z;
}

#endif // __SYCL_NVPTX__
// Function-style accessor: returns the current value of the
// __spirv_BuiltInSubgroupSize built-in variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupSize() {
return __spirv_BuiltInSubgroupSize;
}
// Function-style accessor: returns the current value of the
// __spirv_BuiltInSubgroupMaxSize built-in variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupMaxSize() {
return __spirv_BuiltInSubgroupMaxSize;
}
// Function-style accessor: returns the current value of the
// __spirv_BuiltInNumSubgroups built-in variable.
SYCL_EXTERNAL inline uint32_t __spirv_NumSubgroups() {
return __spirv_BuiltInNumSubgroups;
}
// Function-style accessor: returns the current value of the
// __spirv_BuiltInSubgroupId built-in variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupId() {
return __spirv_BuiltInSubgroupId;
}
// Function-style accessor: returns the current value of the
// __spirv_BuiltInSubgroupLocalInvocationId built-in variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupLocalInvocationId() {
return __spirv_BuiltInSubgroupLocalInvocationId;
}

__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupSize;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupMaxSize;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumEnqueuedSubgroups;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId;
#endif // __SYCL_NVPTX__

#undef __SPIRV_VAR_QUALIFIERS

Expand Down
10 changes: 5 additions & 5 deletions sycl/include/CL/sycl/ONEAPI/sub_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct sub_group {

id_type get_local_id() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInSubgroupLocalInvocationId;
return __spirv_SubgroupLocalInvocationId();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -127,7 +127,7 @@ struct sub_group {

range_type get_local_range() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInSubgroupSize;
return __spirv_SubgroupSize();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -136,7 +136,7 @@ struct sub_group {

range_type get_max_local_range() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInSubgroupMaxSize;
return __spirv_SubgroupMaxSize();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -145,7 +145,7 @@ struct sub_group {

id_type get_group_id() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInSubgroupId;
return __spirv_SubgroupId();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -163,7 +163,7 @@ struct sub_group {

range_type get_group_range() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInNumSubgroups;
return __spirv_NumSubgroups();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand Down
5 changes: 1 addition & 4 deletions sycl/test/sub_group/common.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// UNSUPPORTED: cuda
// CUDA compilation and runtime do not yet support sub-groups.
//
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
Expand Down Expand Up @@ -70,7 +67,7 @@ void check(queue &Queue, unsigned int G, unsigned int L) {
}
int main() {
queue Queue;
if (!core_sg_supported(Queue.get_device())) {
if (Queue.get_device().is_host()) {
std::cout << "Skipping test\n";
return 0;
}
Expand Down