Skip to content

Commit f3e170a

Browse files
authored
fix hardcoded max registers (#345)
1 parent d0ac0e6 commit f3e170a

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

xla/stream_executor/rocm/rocm_executor.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,10 @@ absl::Status GetGridLimits(int* x, int* y, int* z, hipDevice_t device) {
311311
return absl::OkStatus();
312312
}
313313

314+
absl::StatusOr<int64_t> GetMaxRegistersPerMultiprocessor(hipDevice_t device) {
315+
return GetSimpleAttribute<int64_t>(device, hipDeviceAttributeMaxRegistersPerMultiprocessor);
316+
}
317+
314318
// Returns the device associated with the given device_ordinal.
315319
absl::StatusOr<hipDevice_t> GetDevice(int device_ordinal) {
316320
hipDevice_t device;
@@ -1137,7 +1141,7 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) {
11371141
GetMaxThreadsPerMultiprocessor(device).value());
11381142
desc.set_registers_per_block_limit(GetMaxRegistersPerBlock(device).value());
11391143
desc.set_threads_per_warp(GetThreadsPerWarp(device).value());
1140-
desc.set_registers_per_core_limit(64 * 1024);
1144+
desc.set_registers_per_core_limit(GetMaxRegistersPerMultiprocessor(device).value());
11411145
desc.set_compile_time_toolkit_version(
11421146
SemanticVersion{HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH});
11431147
int32_t runtime_version;

0 commit comments

Comments
 (0)