2020#include < cuda.h>
2121#endif // USE_PI_CUDA
2222
23+ #ifdef USE_PI_ROCM
24+ #include < hip/hip_runtime.h>
25+ #endif // USE_PI_ROCM
26+
2327#include < algorithm>
2428#include < cstdlib>
2529#include < cstring>
@@ -32,7 +36,7 @@ static const std::string help =
3236 " Help\n "
3337 " Example: ./get_device_count_by_type cpu opencl\n "
3438 " Supported device types: cpu/gpu/accelerator/default/all\n "
35- " Supported backends: PI_CUDA/PI_OPENCL/PI_LEVEL_ZERO \n "
39+ " Supported backends: PI_CUDA/PI_ROCM/ PI_OPENCL/PI_LEVEL_ZERO \n "
3640 " Output format: <number_of_devices>:<additional_Information>" ;
3741
3842// Return the string with all characters translated to lower case.
@@ -224,6 +228,49 @@ static bool queryCUDA(cl_device_type deviceType, cl_uint &deviceCount,
224228#endif
225229}
226230
231+ static bool queryROCm (cl_device_type deviceType, cl_uint &deviceCount,
232+ std::string &msg) {
233+ deviceCount = 0u ;
234+ #ifdef USE_PI_ROCM
235+ switch (deviceType) {
236+ case CL_DEVICE_TYPE_DEFAULT: // Fall through.
237+ case CL_DEVICE_TYPE_ALL: // Fall through.
238+ case CL_DEVICE_TYPE_GPU: {
239+ int count = 0 ;
240+ hipError_t err = hipGetDeviceCount (&count);
241+ if (err != hipSuccess || count < 0 ) {
242+ msg = " ERROR: ROCm error querying device count" ;
243+ return false ;
244+ }
245+ if (count < 1 ) {
246+ msg = " ERROR: ROCm no device found" ;
247+ return false ;
248+ }
249+ deviceCount = static_cast <cl_uint>(count);
250+ #if defined(__HIP_PLATFORM_AMD__)
251+ msg = " rocm-amd " ;
252+ #elif defined(__HIP_PLATFORM_NVIDIA__)
253+ msg = " rocm-nvidia " ;
254+ #else
255+ #error ("Must define one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__");
256+ #endif
257+ msg += deviceTypeToString (deviceType);
258+ return true ;
259+ } break ;
260+ default :
261+ msg = " WARNING: ROCm unsupported device type " ;
262+ msg += deviceTypeToString (deviceType);
263+ return true ;
264+ }
265+ #else
266+ (void )deviceType;
267+ msg = " ERROR: ROCm not supported" ;
268+ deviceCount = 0u ;
269+
270+ return false ;
271+ #endif
272+ }
273+
227274int main (int argc, char *argv[]) {
228275 if (argc < 3 ) {
229276 std::cout << " 0:ERROR: Please set a device type and backend to find"
@@ -264,6 +311,8 @@ int main(int argc, char *argv[]) {
264311 querySuccess = queryLevelZero (deviceType, deviceCount, msg);
265312 } else if (backend == " cuda" || backend == " pi_cuda" ) {
266313 querySuccess = queryCUDA (deviceType, deviceCount, msg);
314+ } else if (backend == " rocm" || backend == " pi_rocm" ) {
315+ querySuccess = queryROCm (deviceType, deviceCount, msg);
267316 } else {
268317 msg = " ERROR: Unknown backend " + backend + " \n " + help + " \n " ;
269318 }
0 commit comments