Skip to content

Commit bf0971a

Browse files
amd-songpiao authored and
draganmladjenovic committed
added rocm7 support to EnablePeerAccess (#347)
* added rocm7 support to EnablePeerAccess
* use wrap namespace, clang-format and add comments
1 parent 84eb453 commit bf0971a

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

xla/stream_executor/rocm/rocm_driver_wrapper.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ namespace wrap {
5151
static FuncPtrT loaded = []() -> FuncPtrT { \
5252
static const char *kName = TO_STR(hipSymbolName); \
5353
void *f; \
54-
auto s = tsl::Env::Default()->GetSymbolFromLibrary( \
54+
auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \
5555
tsl::internal::CachedDsoLoader::GetHipDsoHandle().value(), kName, \
5656
&f); \
5757
CHECK(s.ok()) << "could not find " << kName \
@@ -100,6 +100,7 @@ namespace wrap {
100100
__macro(hipGetDeviceCount) \
101101
__macro(hipGetDeviceProperties) \
102102
__macro(hipGetErrorString) \
103+
__macro(hipGetLastError) \
103104
__macro(hipGraphAddKernelNode) \
104105
__macro(hipGraphAddChildGraphNode) \
105106
__macro(hipGraphAddEmptyNode) \

xla/stream_executor/rocm/rocm_executor.cc

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -378,11 +378,16 @@ absl::Status EnablePeerAccess(Context* from, Context* to) {
378378
hipError_t result =
379379
wrap::hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */);
380380

381-
if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) {
381+
if (result == hipErrorPeerAccessAlreadyEnabled) {
382+
// Since ROCm 7, hipGetLastError returns the most recent error code even
383+
// when the last call was successful, so call it here to reset the
384+
// per-thread error state.
385+
(void)wrap::hipGetLastError();
386+
} else if (result != hipSuccess) {
382387
return absl::InternalError(
383388
absl::StrFormat("failed to enable peer access from %d to %d: %s",
384389
from->device_ordinal(), to->device_ordinal(),
385-
ToString(result).c_str()));
390+
wrap::hipGetErrorString(result)));
386391
}
387392

388393
return absl::OkStatus();

0 commit comments

Comments
 (0)