Skip to content

Commit d9d5299

Browse files
authored
fix hard-coded device properties and remove reserved memory (#351)
1 parent ff16e5a commit d9d5299

File tree

2 files changed

+27
-58
lines changed

2 files changed

+27
-58
lines changed

xla/stream_executor/rocm/rocm_context.cc

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -60,45 +60,6 @@ hipCtx_t CurrentContext() {
6060
return current;
6161
}
6262

63-
// Returns the amount of memory reserved by ROCm libraries.
64-
bool GetReservedMemory(uint64_t* reserve) {
65-
hipDeviceProp_t props;
66-
hipDevice_t dev;
67-
hipError_t res = wrap::hipGetDevice(&dev);
68-
69-
if (res != hipSuccess) {
70-
LOG(FATAL) << "failed to query current device: " << ToString(res);
71-
return false;
72-
}
73-
res = wrap::hipGetDeviceProperties(&props, dev);
74-
if (res != hipSuccess) {
75-
LOG(ERROR) << "failed to query device properties: " << ToString(res);
76-
return false;
77-
}
78-
79-
std::string gcnArchName = props.gcnArchName;
80-
auto compute_capability = RocmComputeCapability(gcnArchName);
81-
// On gfx90a, we hide 1 GB of GPU memory (512MB for gfx908) from TF,
82-
// to allow for late allocations by internal ROCm libraries
83-
// (e.g. rocBLAS alone needs~200 MB to put its kernels as of ROCm 4.1)
84-
const uint64_t RESERVED_GFX908 = 1048576 * 512;
85-
const uint64_t RESERVED_GFX9_X = 1048576 * 1024;
86-
const uint64_t RESERVED_GFX10_X = 1048576 * 512;
87-
const uint64_t RESERVED_GFX11_X = 1048576 * 512;
88-
if (compute_capability.gfx9_mi100()) {
89-
*reserve = RESERVED_GFX908;
90-
} else if (compute_capability.gfx9_mi200_or_later()) {
91-
*reserve = RESERVED_GFX9_X;
92-
} else if (compute_capability.gfx10_rx68xx() ||
93-
compute_capability.gfx10_rx69xx()) {
94-
*reserve = RESERVED_GFX10_X;
95-
} else if (compute_capability.gfx11()) {
96-
*reserve = RESERVED_GFX11_X;
97-
}
98-
99-
return true;
100-
}
101-
10263
} // namespace
10364

10465
// Returns the singleton ContextMap.
@@ -126,12 +87,7 @@ bool RocmContext::GetDeviceTotalMemory(hipDevice_t device, uint64_t* result) {
12687
LOG(ERROR) << "failed to query total available memory: " << ToString(res);
12788
return false;
12889
}
129-
uint64_t reserve = 0;
130-
if (!GetReservedMemory(&reserve)) {
131-
LOG(ERROR) << "failed to reserved device memory for ROCm libraries";
132-
return false;
133-
}
134-
*result = value - reserve;
90+
*result = value;
13591
return true;
13692
}
13793

@@ -145,24 +101,17 @@ bool RocmContext::GetDeviceMemoryUsage(int64_t* free_out, int64_t* total_out) {
145101
return false;
146102
}
147103

148-
uint64_t reserve = 0;
149-
if (!GetReservedMemory(&reserve)) {
150-
LOG(ERROR) << "failed to reserved device memory for ROCm libraries";
151-
return false;
152-
}
153-
154104
VLOG(1) << "Device memory: " << total / 1048576 << " MB total, "
155-
<< free / 1048576 << " MB free, reserving " << reserve / 1048576
156-
<< " MB";
105+
<< free / 1048576 << " MB free";
157106

158107
// overflow check
159108
if (free > std::numeric_limits<int64_t>::max()) {
160109
LOG(ERROR) << "free memory (" << free << ") is overflow int64_t";
161110
return false;
162111
}
163112

164-
*free_out = free >= reserve ? free - reserve : 0;
165-
*total_out = total - reserve;
113+
*free_out = free;
114+
*total_out = total;
166115
return true;
167116
}
168117

xla/stream_executor/rocm/rocm_executor.cc

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ absl::Status GetGridLimits(int* x, int* y, int* z, hipDevice_t device) {
306306
return absl::OkStatus();
307307
}
308308

309+
absl::StatusOr<int64_t> GetMaxRegistersPerMultiprocessor(hipDevice_t device) {
310+
return GetSimpleAttribute<int64_t>(device, hipDeviceAttributeMaxRegistersPerMultiprocessor);
311+
}
312+
309313
// Returns the device associated with the given device_ordinal.
310314
absl::StatusOr<hipDevice_t> GetDevice(int device_ordinal) {
311315
hipDevice_t device;
@@ -394,6 +398,19 @@ std::string GetPCIBusID(hipDevice_t device) {
394398
return pci_bus_id;
395399
}
396400

401+
402+
bool IsEccEnabled(hipDevice_t device, bool* result) {
403+
int value = -1;
404+
auto status = ToStatus(wrap::hipDeviceGetAttribute(
405+
&value, hipDeviceAttributeEccEnabled, device));
406+
if (!status.ok()) {
407+
LOG(ERROR) << "failed to query ECC status: " << status;
408+
return false;
409+
}
410+
*result = value;
411+
return true;
412+
}
413+
397414
bool GetDeviceProperties(hipDeviceProp_t* device_properties,
398415
int device_ordinal) {
399416
hipError_t res =
@@ -1015,8 +1032,11 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) {
10151032
desc.set_l2_cache_size(prop.l2CacheSize);
10161033
}
10171034

1018-
// No way to query ECC status from the API.
1019-
desc.set_ecc_enabled(false);
1035+
{
1036+
bool ecc_enabled = false;
1037+
IsEccEnabled(device, &ecc_enabled);
1038+
desc.set_ecc_enabled(ecc_enabled);
1039+
}
10201040

10211041
uint64_t device_memory_size = -1;
10221042
(void)RocmContext::GetDeviceTotalMemory(device, &device_memory_size);
@@ -1054,7 +1074,7 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) {
10541074
GetMaxThreadsPerMultiprocessor(device).value());
10551075
desc.set_registers_per_block_limit(GetMaxRegistersPerBlock(device).value());
10561076
desc.set_threads_per_warp(GetThreadsPerWarp(device).value());
1057-
desc.set_registers_per_core_limit(64 * 1024);
1077+
desc.set_registers_per_core_limit(GetMaxRegistersPerMultiprocessor(device).value());
10581078
desc.set_compile_time_toolkit_version(
10591079
SemanticVersion{HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH});
10601080
int32_t runtime_version;

0 commit comments

Comments
 (0)