Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 6 additions & 22 deletions src/lapack/backends/rocsolver/rocsolver_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
#define _ROCSOLVER_HELPER_HPP_

#include <CL/sycl.hpp>
#include <rocblas.h>
#include <rocsolver.h>
#include <rocblas/rocblas.h>
#include <rocsolver/rocsolver.h>
#include <hip/hip_runtime.h>
#include <complex>

Expand Down Expand Up @@ -82,15 +82,7 @@ void overflow_check(Index index, Next... indices) {
class rocsolver_error : virtual public std::runtime_error {
protected:
inline const char *rocsolver_error_map(rocblas_status error) {
switch (error) {
case rocblas_status_success: return "ROCBLAS_STATUS_SUCCESS";

case rocblas_status_invalid_value: return "ROCBLAS_STATUS_INVALID_VALUE";

case rocblas_status_internal_error: return "ROCBLAS_STATUS_INTERNAL_ERROR";

default: return "<unknown>";
}
return rocblas_status_to_string(error);
}

int error_number; ///< Error number
Expand Down Expand Up @@ -120,16 +112,7 @@ class rocsolver_error : virtual public std::runtime_error {
class hip_error : virtual public std::runtime_error {
protected:
inline const char *hip_error_map(hipError_t result) {
switch (result) {
case HIP_SUCCESS: return "HIP_SUCCESS";
case hipErrorNotInitialized: return "hipErrorNotInitialized";
case hipErrorInvalidContext: return "hipErrorInvalidContext";
case hipErrorInvalidDevice: return "hipErrorInvalidDevice";
case hipErrorInvalidValue: return "hipErrorInvalidValue";
case hipErrorMemoryAllocation: return "hipErrorMemoryAllocation";
case hipErrorLaunchOutOfResources: return "hipErrorLaunchOutOfResources";
default: return "<unknown>";
}
return hipGetErrorName(result);
}
int error_number; ///< error number
public:
Expand Down Expand Up @@ -271,14 +254,15 @@ inline int get_rocsolver_devinfo(sycl::queue &queue, sycl::buffer<int> &devInfo)

inline int get_rocsolver_devinfo(sycl::queue &queue, const int *devInfo) {
int dev_info_;
queue.wait();
queue.memcpy(&dev_info_, devInfo, sizeof(int));
queue.wait();
return dev_info_;
}

template <typename DEVINFO_T>
inline void lapack_info_check(sycl::queue &queue, DEVINFO_T devinfo, const char *func_name,
const char *cufunc_name) {
queue.wait();
const int devinfo_ = get_rocsolver_devinfo(queue, devinfo);
if (devinfo_ > 0)
throw oneapi::mkl::lapack::computation_error(
Expand Down
Loading