Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/aspects.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ enum class aspect {
host_debuggable = 32,
ext_intel_gpu_hw_threads_per_eu = 33,
ext_oneapi_cuda_async_barrier = 34,
ext_oneapi_bfloat16 = 35,
};

} // namespace sycl
Expand Down
2 changes: 2 additions & 0 deletions sycl/include/CL/sycl/detail/pi.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ typedef enum {
PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES = 0x11000,
PI_DEVICE_INFO_GPU_HW_THREADS_PER_EU = 0x10112,
PI_DEVICE_INFO_BACKEND_VERSION = 0x10113,
// Returns true if the bfloat16 data type is supported by the device
PI_EXT_ONEAPI_DEVICE_INFO_BFLOAT16 = 0x1FFFF,
PI_EXT_ONEAPI_DEVICE_INFO_MAX_GLOBAL_WORK_GROUPS = 0x20000,
PI_EXT_ONEAPI_DEVICE_INFO_MAX_WORK_GROUPS_1D = 0x20001,
PI_EXT_ONEAPI_DEVICE_INFO_MAX_WORK_GROUPS_2D = 0x20002,
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/info/device_traits.def
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ __SYCL_PARAM_TRAITS_SPEC(device, atomic_memory_order_capabilities,
std::vector<cl::sycl::memory_order>)
__SYCL_PARAM_TRAITS_SPEC(device, atomic_memory_scope_capabilities,
std::vector<cl::sycl::memory_scope>)
__SYCL_PARAM_TRAITS_SPEC(device, ext_oneapi_bfloat16, bool)
__SYCL_PARAM_TRAITS_SPEC(device, max_read_image_args, pi_uint32)
__SYCL_PARAM_TRAITS_SPEC(device, max_write_image_args, pi_uint32)
__SYCL_PARAM_TRAITS_SPEC(device, image2d_max_width, size_t)
Expand Down
3 changes: 2 additions & 1 deletion sycl/include/CL/sycl/info/info_desc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ enum class device : cl_device_info {
ext_oneapi_max_work_groups_2d = PI_EXT_ONEAPI_DEVICE_INFO_MAX_WORK_GROUPS_2D,
ext_oneapi_max_work_groups_3d = PI_EXT_ONEAPI_DEVICE_INFO_MAX_WORK_GROUPS_3D,
atomic_memory_scope_capabilities =
PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES
PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES,
ext_oneapi_bfloat16 = PI_EXT_ONEAPI_DEVICE_INFO_BFLOAT16,
};

enum class device_type : pi_uint64 {
Expand Down
11 changes: 11 additions & 0 deletions sycl/plugins/cuda/pi_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,17 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name,
return getInfo(param_value_size, param_value, param_value_size_ret,
capabilities);
}
case PI_EXT_ONEAPI_DEVICE_INFO_BFLOAT16: {
  // bfloat16 support is derived from the device's major compute capability,
  // which is queried from the CUDA driver.
  int computeMajor = 0;
  const bool queryOk =
      cuDeviceGetAttribute(&computeMajor,
                           CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
                           device->get()) == CUDA_SUCCESS;
  cl::sycl::detail::pi::assertion(queryOk);

  // A major compute capability of 8 or higher is reported as supporting
  // bfloat16; anything lower is reported as unsupported.
  bool supportsBfloat16 = computeMajor >= 8;
  return getInfo(param_value_size, param_value, param_value_size_ret,
                 supportsBfloat16);
}
case PI_DEVICE_INFO_SUB_GROUP_SIZES_INTEL: {
// NVIDIA devices only support one sub-group size (the warp size)
int warpSize = 0;
Expand Down
1 change: 1 addition & 0 deletions sycl/plugins/hip/pi_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1669,6 +1669,7 @@ pi_result hip_piDeviceGetInfo(pi_device device, pi_device_info param_name,
case PI_DEVICE_INFO_GPU_EU_COUNT_PER_SUBSLICE:
case PI_DEVICE_INFO_GPU_HW_THREADS_PER_EU:
case PI_DEVICE_INFO_MAX_MEM_BANDWIDTH:
case PI_EXT_ONEAPI_DEVICE_INFO_BFLOAT16:
return PI_INVALID_VALUE;

default:
Expand Down
2 changes: 2 additions & 0 deletions sycl/plugins/level_zero/pi_level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2895,6 +2895,8 @@ pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
case PI_DEVICE_INFO_MAX_MEM_BANDWIDTH:
// Currently not supported in the Level Zero runtime
return PI_INVALID_VALUE;
case PI_EXT_ONEAPI_DEVICE_INFO_BFLOAT16:
return PI_INVALID_VALUE;

// TODO: Implement.
case PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES:
Expand Down
2 changes: 2 additions & 0 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ pi_result piDeviceGetInfo(pi_device device, pi_device_info paramName,
std::memcpy(paramValue, &result, sizeof(cl_bool));
return PI_SUCCESS;
}
case PI_EXT_ONEAPI_DEVICE_INFO_BFLOAT16:
return PI_INVALID_VALUE;
case PI_DEVICE_INFO_IMAGE_SRGB: {
cl_bool result = true;
std::memcpy(paramValue, &result, sizeof(cl_bool));
Expand Down
2 changes: 2 additions & 0 deletions sycl/source/detail/device_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ bool device_impl::has(aspect Aspect) const {
return has_extension("cl_khr_fp16");
case aspect::fp64:
return has_extension("cl_khr_fp64");
case aspect::ext_oneapi_bfloat16:
return get_info<info::device::ext_oneapi_bfloat16>();
case aspect::int64_base_atomics:
return has_extension("cl_khr_int64_base_atomics");
case aspect::int64_extended_atomics:
Expand Down
21 changes: 21 additions & 0 deletions sycl/source/detail/device_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,22 @@ struct get_device_info<std::vector<memory_scope>,
}
};

// Specialization for bf16
template <> struct get_device_info<bool, info::device::ext_oneapi_bfloat16> {
static bool get(RT::PiDevice dev, const plugin &Plugin) {

bool result = false;

RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piDeviceGetInfo>(
dev, pi::cast<RT::PiDeviceInfo>(info::device::ext_oneapi_bfloat16),
sizeof(result), &result, nullptr);
if (Err != PI_SUCCESS) {
return false;
}
return result;
}
};

// Specialization for exec_capabilities, OpenCL returns a bitfield
template <>
struct get_device_info<std::vector<info::execution_capability>,
Expand Down Expand Up @@ -769,6 +785,11 @@ get_device_info_host<info::device::atomic_memory_scope_capabilities>() {
memory_scope::work_group, memory_scope::device, memory_scope::system};
}

// Host device answer for the ext_oneapi_bfloat16 query: always reported as
// unsupported.
template <>
inline bool get_device_info_host<info::device::ext_oneapi_bfloat16>() {
  constexpr bool HostSupportsBfloat16 = false;
  return HostSupportsBfloat16;
}

template <>
inline cl_uint get_device_info_host<info::device::max_read_image_args>() {
// current value is the required minimum
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -4299,6 +4299,7 @@ _ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65809EEENS3_12param_traitsIS4_XT_
_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65810EEENS3_12param_traitsIS4_XT_EE11return_typeEv
_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE65811EEENS3_12param_traitsIS4_XT_EE11return_typeEv
_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE69632EEENS3_12param_traitsIS4_XT_EE11return_typeEv
_ZNK2cl4sycl6device8get_infoILNS0_4info6deviceE131071EEENS3_12param_traitsIS4_XT_EE11return_typeEv
_ZNK2cl4sycl6device9getNativeEv
_ZNK2cl4sycl6kernel11get_backendEv
_ZNK2cl4sycl6kernel11get_contextEv
Expand Down