Skip to content

Commit c80fa69

Browse files
committed
use wrap namespace, clang-format and add comments
1 parent 49e9f54 commit c80fa69

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

xla/stream_executor/rocm/rocm_driver_wrapper.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace wrap {
5151
static FuncPtrT loaded = []() -> FuncPtrT { \
5252
static const char *kName = TO_STR(hipSymbolName); \
5353
void *f; \
54-
auto s = tsl::Env::Default()->GetSymbolFromLibrary( \
54+
auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \
5555
tsl::internal::CachedDsoLoader::GetHipDsoHandle().value(), kName, \
5656
&f); \
5757
CHECK(s.ok()) << "could not find " << kName \
@@ -100,6 +100,7 @@ namespace wrap {
100100
__macro(hipGetDeviceCount) \
101101
__macro(hipGetDeviceProperties) \
102102
__macro(hipGetErrorString) \
103+
__macro(hipGetLastError) \
103104
__macro(hipGraphAddKernelNode) \
104105
__macro(hipGraphAddChildGraphNode) \
105106
__macro(hipGraphAddEmptyNode) \

xla/stream_executor/rocm/rocm_executor.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,15 @@ absl::Status EnablePeerAccess(Context* from, Context* to) {
379379
wrap::hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */);
380380

381381
if (result == hipErrorPeerAccessAlreadyEnabled) {
382-
hipGetLastError();
382+
// hipGetLastError is used to reset per thread error state,
383+
// as hipGetLastError would get the recent error code since rocm7 even the
384+
// last call is successful.
385+
(void)wrap::hipGetLastError();
383386
} else if (result != hipSuccess) {
384387
return absl::InternalError(
385388
absl::StrFormat("failed to enable peer access from %d to %d: %s",
386389
from->device_ordinal(), to->device_ordinal(),
387-
hipGetErrorString(result)));
390+
wrap::hipGetErrorString(result)));
388391
}
389392

390393
return absl::OkStatus();

0 commit comments

Comments
 (0)