diff --git a/build_tools/rocm/run_xla.sh b/build_tools/rocm/run_xla.sh index 1fc270db05f8a..7d94be10eefa3 100755 --- a/build_tools/rocm/run_xla.sh +++ b/build_tools/rocm/run_xla.sh @@ -79,13 +79,6 @@ if [[ $1 == "asan" ]]; then elif [[ $1 == "tsan" ]]; then SANITIZER_ARGS+=("--test_env=TSAN_OPTIONS=suppressions=$(realpath $(dirname $0))/tsan_ignore_list.txt::history_size=7:ignore_noninstrumented_modules=1") SANITIZER_ARGS+=("--config=tsan") - EXCLUDED_TESTS+=( - HloTest* - FunctionalHloRunnerTest* - TopkTest* - SimpleOptimizationTest.OptimizeModule - OutfeedInNestedComputationTest.OutfeedInConditional - ) fi bazel \ diff --git a/xla/stream_executor/rocm/rocm_stream.cc b/xla/stream_executor/rocm/rocm_stream.cc index 0ddfdccfe0eaf..2fe0f1e85348c 100644 --- a/xla/stream_executor/rocm/rocm_stream.cc +++ b/xla/stream_executor/rocm/rocm_stream.cc @@ -312,17 +312,40 @@ void InternalHostCallback(void* data) { absl::Status RocmStream::DoHostCallbackWithStatus( absl::AnyInvocable callback) { - auto callback_ptr = - new absl::AnyInvocable([cb = std::move(callback)]() mutable { - absl::Status s = std::move(cb)(); + auto callback_ptr = new absl::AnyInvocable( + [cb = std::move(callback), this]() mutable { + absl::Status s = (std::move(cb))(); + if (!s.ok()) { LOG(WARNING) << "Host callback failed: " << s; } + + // clang-format off + int num_pending_host_callbacks = num_pending_host_callbacks_.fetch_sub(1, std::memory_order_acq_rel) - 1; + // clang-format on + + // num_pending_host_callbacks_ can theoretically reach -1 if this + // callback gets executed before we increase the counter on the main + // thread. + if (num_pending_host_callbacks == 0) { + absl::MutexLock lock(&mutex_); + no_pending_host_callbacks_ = num_pending_host_callbacks_ <= 0; + } }); - return ToStatus( - wrap::hipLaunchHostFunc(stream_handle_, (hipHostFn_t)InternalHostCallback, - callback_ptr), - "unable to add host callback"); + + TF_RETURN_IF_ERROR(ToStatus(wrap::hipLaunchHostFunc( + stream_handle_, InternalHostCallback, callback_ptr))); + + int num_pending_host_callbacks = + num_pending_host_callbacks_.fetch_add(1, std::memory_order_acq_rel) + 1; + + if (num_pending_host_callbacks == 1) { + // num_pending_host_callbacks == 1 means we had no pending host callbacks + // before this one. + absl::MutexLock lock(&mutex_); + no_pending_host_callbacks_ = num_pending_host_callbacks_ <= 0; + } + return absl::OkStatus(); } namespace { @@ -356,11 +379,11 @@ absl::Status LaunchRocmKernel( function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem_bytes, stream, kernel_params, extra); } - TF_RETURN_IF_ERROR( - ToStatus(res, absl::StrCat("Failed to launch ROCm kernel: ", kernel_name, - "; grid: ", grid_dim_x, "x", grid_dim_y, "x", grid_dim_z, - "; block: ", block_dim_x, "x", block_dim_y, "x", block_dim_z, - "; shared_mem: ", shared_mem_bytes))); + TF_RETURN_IF_ERROR(ToStatus( + res, absl::StrCat("Failed to launch ROCm kernel: ", kernel_name, + "; grid: ", grid_dim_x, "x", grid_dim_y, "x", + grid_dim_z, "; block: ", block_dim_x, "x", block_dim_y, + "x", block_dim_z, "; shared_mem: ", shared_mem_bytes))); VLOG(2) << "successfully launched kernel"; return absl::OkStatus(); @@ -385,7 +408,10 @@ absl::Status LaunchRocmKernel( } // namespace absl::Status RocmStream::BlockHostUntilDone() { - return SynchronizeStream(executor_, stream_handle_); + TF_RETURN_IF_ERROR(SynchronizeStream(executor_, stream_handle_)); + absl::MutexLock lock(&mutex_); + mutex_.Await(absl::Condition(&no_pending_host_callbacks_)); + return absl::OkStatus(); } absl::Status RocmStream::LaunchKernel( diff --git a/xla/stream_executor/rocm/rocm_stream.h b/xla/stream_executor/rocm/rocm_stream.h index 977d27f3b7e13..433b51bc5ea71 100644 --- a/xla/stream_executor/rocm/rocm_stream.h +++ b/xla/stream_executor/rocm/rocm_stream.h @@ -93,6 +93,9 @@ class RocmStream : public StreamCommon { StreamExecutor* executor_; RocmEvent completed_event_; hipStream_t stream_handle_; + absl::Mutex mutex_; + bool no_pending_host_callbacks_ ABSL_GUARDED_BY(mutex_) = true; + std::atomic num_pending_host_callbacks_ = 0; }; } // namespace gpu