Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion build_tools/rocm/run_xla.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,6 @@ bazel \
--action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \
--action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \
--run_under=//build_tools/ci:parallel_gpu_execute \
-- //xla/...
--test_env=MIOPEN_FIND_ENFORCE=5 \
--test_env=MIOPEN_FIND_MODE=1 \
-- //xla/...
13 changes: 13 additions & 0 deletions build_tools/rocm/run_xla_multi_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ TAGS_FILTER="-requires-gpu-nvidia,-oss_excluded,-oss_serial"
UNSUPPORTED_GPU_TAGS="$(echo -requires-gpu-sm{60,70,80,86,89,90}{,-only})"
TAGS_FILTER="${TAGS_FILTER},${UNSUPPORTED_GPU_TAGS// /,}"

# Detect the local GPU architecture (e.g. "gfx90a"). rocminfo prints a
# line like "  Name:  gfx90a"; split it into words and keep the second
# one. Use $(...) instead of deprecated backticks.
GPU_NAME=($(rocminfo | grep -m 1 gfx))
GPU_NAME=${GPU_NAME[1]}

BAZEL_DISK_CACHE_SIZE=100G
BAZEL_DISK_CACHE_DIR="/tf/disk_cache/rocm-jaxlib-v0.6.0"

# Tests excluded from the multi-GPU run. These are joined with ':' and
# passed to bazel as a negative --test_filter below.
EXCLUDED_TESTS=(
  CollectiveOpsTestE2E.MemcpyP2pLargeMessage
  RaggedAllToAllTest/RaggedAllToAllTest.RaggedAllToAll_8GPUs_2ReplicasPerGroups/sync_decomposer
  RaggedAllToAllTest/RaggedAllToAllTest.RaggedAllToAll_8GPUs_2ReplicasPerGroups/async_decomposer
)

bazel \
test \
--define xnn_enable_avxvnniint8=false \
Expand All @@ -90,6 +102,7 @@ bazel \
--action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \
--action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \
--action_env=NCCL_MAX_NCHANNELS=1 \
--test_filter=-$(IFS=: ; echo "${EXCLUDED_TESTS[*]}") \
-- //xla/tests:collective_ops_e2e_test \
//xla/tests:collective_ops_test \
//xla/tests:collective_pipeline_parallelism_test \
Expand Down
12 changes: 11 additions & 1 deletion xla/backends/gpu/codegen/emitters/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class ReductionFusion : public EmitterBase {
return IndexingMap::GetUndefined();
}

int64_t WarpSize() const {
virtual int64_t WarpSize() const {
  // Warp (wavefront) size of the target device, taken from the fusion
  // analysis' device description. Virtual so that specific reduction
  // emitters can override it.
  return ::xla::gpu::WarpSize(analysis_.device_info());
}

Expand Down Expand Up @@ -207,6 +207,11 @@ class ColumnReductionFusion : public ReductionFusion {
explicit ColumnReductionFusion(const HloFusionAnalysis& analysis,
SymbolicExprContext* symbolic_expr_context);

int64_t WarpSize() const override {
  // PAE HACK HACK: force a 32-lane warp for this emitter instead of the
  // device warp size the base class reports (presumably 64 on AMD gfx9
  // wavefronts — confirm). TODO(rocm): derive this from
  // analysis_.device_info() and remove the hard-coded value.
  return 32;
}

protected:
llvm::SmallVector<mlir::Value> EmitReduction(
int group_id, EmitterState& state) const override;
Expand All @@ -230,6 +235,11 @@ class SmallColumnReductionFusion : public ReductionFusion {
const HloFusionAnalysis& analysis,
SymbolicExprContext* symbolic_expr_context);

int64_t WarpSize() const override {
  // PAE HACK HACK: force a 32-lane warp for this emitter instead of the
  // device warp size the base class reports (presumably 64 on AMD gfx9
  // wavefronts — confirm). TODO(rocm): derive this from
  // analysis_.device_info() and remove the hard-coded value.
  return 32;
}

protected:
llvm::SmallVector<mlir::Value> EmitReduction(
int group_id, EmitterState& state) const override;
Expand Down
6 changes: 3 additions & 3 deletions xla/backends/gpu/codegen/emitters/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ namespace mt = ::mlir::tensor;
namespace mv = ::mlir::vector;

constexpr int kTileSize = 32;
constexpr int kNumRows = 4;
constexpr int kNumThreadsPerBlock = 128;
constexpr int kMaxVectorizedBytes = 4;
// Number of tile rows handled per thread block (raised from 4 to 8).
constexpr int kNumRows = 8;
// One thread per (row, tile column): kNumRows * kTileSize = 8 * 32 = 256.
constexpr int kNumThreadsPerBlock = kNumRows * kTileSize;
// NOTE(review): presumably the cap on bytes per vectorized access
// (raised from 4 to 16) — confirm against the vectorization logic.
constexpr int kMaxVectorizedBytes = 16;

// Reads the 2D vector tile <vector_size x vector_size> from the shared memory
// at the given indices.
Expand Down
17 changes: 10 additions & 7 deletions xla/backends/gpu/runtime/all_to_all_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/backends/gpu/runtime/all_to_all_thunk.h"

#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <iterator>
Expand Down Expand Up @@ -116,10 +117,11 @@ AllToAllStartThunk::AllToAllStartThunk(

absl::Status AllToAllStartThunk::Initialize(const InitializeParams& params) {
TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params));
device_count_ = params.local_device_count;
CHECK_GT(device_count_, 0);
VLOG(5) << "[" << params.executor->device_ordinal()
<< "] Local device count : " << device_count_;
device_count_.store(params.local_device_count, std::memory_order_relaxed);
CHECK_GT(params.local_device_count, 0);
VLOG(5) << "Local device count: " << params.local_device_count;

TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));

if (is_local() && p2p_memcpy_enabled_) {
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
Expand Down Expand Up @@ -254,11 +256,12 @@ AsyncStreamKind AllToAllStartThunk::GetAsyncStreamKind() const {
}

bool AllToAllStartThunk::is_local() const {
const auto device_count = device_count_.load(std::memory_order_relaxed);
for (const auto& replica_group : config_.config.replica_groups) {
const int64_t node_id = replica_group.replica_ids().at(0) / device_count_;
const int64_t node_id = replica_group.replica_ids().at(0) / device_count;
if (!absl::c_all_of(replica_group.replica_ids(),
[this, node_id](const int64_t rank) {
return rank / device_count_ == node_id;
[node_id, device_count](const int64_t rank) {
return rank / device_count == node_id;
})) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/runtime/all_to_all_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class AllToAllStartThunk : public CollectiveThunk {
private:
const AllToAllConfig config_;
const std::vector<Buffer> buffers_;
int64_t device_count_ = 1;
std::atomic<int64_t> device_count_ = 1;
bool p2p_memcpy_enabled_ = false;

absl::Mutex pointer_maps_mutex_;
Expand Down
10 changes: 6 additions & 4 deletions xla/backends/gpu/runtime/collective_permute_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/collective_permute_thunk.h"

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <iterator>
Expand Down Expand Up @@ -182,9 +183,9 @@ CollectivePermuteStartThunk::CollectivePermuteStartThunk(
absl::Status CollectivePermuteStartThunk::Initialize(
const InitializeParams& params) {
TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params));
device_count_ = params.local_device_count;
CHECK_GT(device_count_, 0);
VLOG(5) << "Local device count: " << device_count_;
device_count_.store(params.local_device_count, std::memory_order_relaxed);
CHECK_GT(params.local_device_count, 0);
VLOG(5) << "Local device count: " << params.local_device_count;

if (p2p_memcpy_enabled_) {
TF_ASSIGN_OR_RETURN(const int64_t current_id,
Expand Down Expand Up @@ -259,8 +260,9 @@ absl::StatusOr<bool> CollectivePermuteStartThunk::RunCollective(

const P2PConfig::SourceTargetMapEntry source_target =
P2PConfig::GetSourceTarget(config_.id_to_source_target, current_id);
const auto device_count = device_count_.load(std::memory_order_relaxed);
bool is_local_peer =
IsLocalPeerTransfer(source_target, current_id, device_count_);
IsLocalPeerTransfer(source_target, current_id, device_count);
VLOG(5) << "Is local peer : " << (is_local_peer ? "true" : "false");

bool use_memcpy = is_local_peer && recv_ptr_map_.IsInitialized(current_id) &&
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/runtime/collective_permute_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class CollectivePermuteStartThunk : public CollectiveThunk {
sender_barrier_events_;

bool p2p_memcpy_enabled_ = false;
int64_t device_count_;
std::atomic<int64_t> device_count_;
};

absl::Status RunCollectivePermute(
Expand Down
11 changes: 6 additions & 5 deletions xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ RaggedAllToAllStartThunk::RaggedAllToAllStartThunk(
absl::Status RaggedAllToAllStartThunk::Initialize(
const InitializeParams& params) {
TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params));
device_count_ = params.local_device_count;
device_count_.store(params.local_device_count, std::memory_order_relaxed);

se::StreamExecutor* executor = params.executor;

Expand Down Expand Up @@ -511,12 +511,13 @@ absl::Status RaggedAllToAllStartThunk::Initialize(
}

bool RaggedAllToAllStartThunk::is_local() const {
CHECK_NE(device_count_, -1);
const auto device_count = device_count_.load(std::memory_order_relaxed);
CHECK_NE(device_count, -1);
for (const auto& replica_group : config_.config.replica_groups) {
const int64_t node_id = replica_group.replica_ids().at(0) / device_count_;
const int64_t node_id = replica_group.replica_ids().at(0) / device_count;
if (!absl::c_all_of(replica_group.replica_ids(),
[this, node_id](const int64_t rank) {
return rank / device_count_ == node_id;
[node_id, device_count](const int64_t rank) {
return rank / device_count == node_id;
})) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/runtime/ragged_all_to_all_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class RaggedAllToAllStartThunk : public CollectiveThunk {

const RaggedAllToAllConfig config_;
const std::vector<Buffer> buffers_;
int64_t device_count_ = -1;
std::atomic<int64_t> device_count_ = -1;
const bool p2p_memcpy_enabled_;
const bool one_shot_kernel_enabled_;

Expand Down
8 changes: 8 additions & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,11 @@ cc_library(
],
)

# Proto describing extra Triton-call arguments parsed from the
# `serialized_metadata` backend-config attribute (see triton_call.cc).
tf_proto_library(
name = "triton_call_args_proto",
srcs = ["triton_call_args.proto"],
)

cc_library(
name = "kernel_call",
srcs = ["kernel_call.cc"],
Expand Down Expand Up @@ -655,11 +660,14 @@ cc_library(
srcs = ["triton_call.cc"],
hdrs = ["triton_call.h"],
deps = [
":triton_call_args_proto_cc",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:protobuf",
],
)

Expand Down
5 changes: 3 additions & 2 deletions xla/service/gpu/amdgpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
stream_executor::RocmSolverContext::Create);
pipeline.AddPass<ConvRewriter>(gpu_version);
pipeline.AddPass<ConvPaddingLegalization>();
auto rcc = std::get<se::RocmComputeCapability>(gpu_version);
pipeline.AddPass<CudnnFusedConvRewriter>(rcc, dnn_version, toolkit_version);
//TODO(rocm): Until #12613 is fixed.
// auto rcc = std::get<se::RocmComputeCapability>(gpu_version);
// pipeline.AddPass<CudnnFusedConvRewriter>(rcc, dnn_version, toolkit_version);

// The conv padding/vectorization passes which we need to get rid of. They
// also leave behind unnecessary tuple/get-tuple-element pairs that
Expand Down
7 changes: 7 additions & 0 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,13 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
sanitized_kernel_name, kernel_arguments,
launch_dimensions, &builder));

// If a value for waves_per_eu was given, attach the corresponding ROCm function attribute.
if (call.waves_per_eu != 0) {
// (0 is the default and means "not set", so no attribute is added for it.)
kernel->addFnAttr("amdgpu-waves-per-eu",
std::to_string(call.waves_per_eu));
}

// Move function body into kernel prototype.
llvm::Function* prototype_func = builder.GetInsertBlock()->getParent();
prototype_func->splice(prototype_func->begin(), impl_fn);
Expand Down
8 changes: 5 additions & 3 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -688,16 +688,18 @@ cc_library(
xla_test(
name = "cudnn_fused_conv_rewriter_test",
srcs = ["cudnn_fused_conv_rewriter_test.cc"],
tags = [
"cuda-only", # TODO(rocm): Until #12613 is fixed.
],
backend_tags = {
"gpu_a100": [
"noasan",
"nomsan",
],
},
backends = [
"a100",
"amdgpu_any",
] + if_oss(["nvgpu_any"]),
"gpu_a100",
],
shard_count = 10,
deps = [
":conv_rewriter",
Expand Down
18 changes: 17 additions & 1 deletion xla/service/gpu/triton_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/LLVM.h"
#include "xla/service/gpu/triton_call_args.pb.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/logging.h"

namespace xla::gpu {

Expand All @@ -44,7 +47,20 @@ TritonCall TritonCall::Parse(absl::string_view backend_config,
attrs.getAs<mlir::IntegerAttr>("num_stages").getValue().getSExtValue();
auto num_warps =
attrs.getAs<mlir::IntegerAttr>("num_warps").getValue().getSExtValue();
return TritonCall{std::move(name), std::move(ir), num_stages, num_warps,
auto attr_smd = attrs.getAs<mlir::StringAttr>("serialized_metadata");
int64_t waves_per_eu = 0;
if (attr_smd) {
TritonCallArgs triton_call_args_proto;
auto sermetadata = attr_smd.getValue().str();
if (tsl::protobuf::TextFormat::ParseFromString(
sermetadata, &triton_call_args_proto)) {
waves_per_eu = triton_call_args_proto.waves_per_eu();
} else {
// Parsing error: set default value
waves_per_eu = 0;
}
}
return TritonCall{std::move(name), std::move(ir), num_stages, num_warps, waves_per_eu,
grid_x, grid_y, grid_z};
}

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/triton_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct TritonCall {
std::string ir;
int64_t num_stages;
int64_t num_warps;
int64_t waves_per_eu;
int32_t grid_x;
int32_t grid_y;
int32_t grid_z;
Expand Down
9 changes: 9 additions & 0 deletions xla/service/gpu/triton_call_args.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
syntax = "proto3";

package xla.gpu;

// Arguments for triton calls for XLA:GPU, parsed from the
// `serialized_metadata` backend-config attribute in triton_call.cc.

message TritonCallArgs {
// Value for the AMDGPU `amdgpu-waves-per-eu` function attribute.
// Unset or 0 means "no value given": no attribute is emitted.
optional int32 waves_per_eu = 1;
}
3 changes: 2 additions & 1 deletion xla/stream_executor/rocm/rocm_driver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace wrap {
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char *kName = TO_STR(hipSymbolName); \
void *f; \
auto s = tsl::Env::Default()->GetSymbolFromLibrary( \
auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \
tsl::internal::CachedDsoLoader::GetHipDsoHandle().value(), kName, \
&f); \
CHECK(s.ok()) << "could not find " << kName \
Expand Down Expand Up @@ -100,6 +100,7 @@ namespace wrap {
__macro(hipGetDeviceCount) \
__macro(hipGetDeviceProperties) \
__macro(hipGetErrorString) \
__macro(hipGetLastError) \
__macro(hipGraphAddKernelNode) \
__macro(hipGraphAddChildGraphNode) \
__macro(hipGraphAddEmptyNode) \
Expand Down
9 changes: 7 additions & 2 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -380,11 +380,16 @@ absl::Status EnablePeerAccess(Context* from, Context* to) {
hipError_t result =
wrap::hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */);

if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) {
if (result == hipErrorPeerAccessAlreadyEnabled) {
// Call hipGetLastError() to reset the per-thread error state: since
// ROCm 7 it returns the most recent error code even when the last HIP
// call succeeded, so hipErrorPeerAccessAlreadyEnabled must be cleared
// explicitly here.
(void)wrap::hipGetLastError();
} else if (result != hipSuccess) {
return absl::InternalError(
absl::StrFormat("failed to enable peer access from %d to %d: %s",
from->device_ordinal(), to->device_ordinal(),
ToString(result).c_str()));
wrap::hipGetErrorString(result)));
}

return absl::OkStatus();
Expand Down
Loading