Skip to content

Commit 9c56a0c

Browse files
committed
Add bf16 starting from gfx11, bugfix & optimize RocmComputeCapability
The PR enables use of BF16 for dot algorithms on Navi3x and Navi4x, and bugfixes and optimizes class RocmComputeCapability. Syncronizes the files with ROCm/xla#303 into `rocm-jaxlib-v0.6.0` branch
1 parent 99bb1c2 commit 9c56a0c

File tree

3 files changed

+96
-51
lines changed

3 files changed

+96
-51
lines changed

third_party/xla/xla/service/algorithm_util.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ bool IsSupportedByElementalIrEmitter(PrecisionConfig::Algorithm algorithm) {
173173
// input/output storage types.
174174
bool IsSupportedDotAlgorithmOnGpu(
175175
PrecisionConfig::Algorithm algorithm,
176-
stream_executor::GpuComputeCapability gpu_compute_capability,
176+
const stream_executor::GpuComputeCapability& gpu_compute_capability,
177177
PrimitiveType input_storage_type, PrimitiveType output_storage_type) {
178178
// Note: We may want to add some complex types here if people request that.
179179
const bool is_cuda_ge_ampere =
@@ -194,6 +194,12 @@ bool IsSupportedDotAlgorithmOnGpu(
194194
std::get<se::RocmComputeCapability>(gpu_compute_capability)
195195
.gfx9_mi100_or_later();
196196

197+
const bool is_rocm_bf16 =
198+
std::holds_alternative<se::RocmComputeCapability>(
199+
gpu_compute_capability) &&
200+
std::get<se::RocmComputeCapability>(gpu_compute_capability)
201+
.has_bf16_dtype_support();
202+
197203
switch (algorithm) {
198204
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
199205
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
@@ -207,7 +213,7 @@ bool IsSupportedDotAlgorithmOnGpu(
207213
return input_storage_type == F16 &&
208214
(output_storage_type == F16 || output_storage_type == F32);
209215
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
210-
if (!is_cuda_ge_ampere && !is_rocm_mi100_and_above) return false;
216+
if (!is_cuda_ge_ampere && !is_rocm_bf16) return false;
211217
switch (input_storage_type) {
212218
case BF16:
213219
return output_storage_type == BF16 || output_storage_type == F32;
@@ -218,7 +224,7 @@ bool IsSupportedDotAlgorithmOnGpu(
218224
}
219225
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
220226
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
221-
return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
227+
return (is_cuda_ge_ampere || is_rocm_bf16) &&
222228
input_storage_type == F32 && output_storage_type == F32;
223229
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
224230
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:

third_party/xla/xla/service/algorithm_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ bool IsSupportedByElementalIrEmitter(PrecisionConfig::Algorithm algorithm);
7070
// input/output storage types.
7171
bool IsSupportedDotAlgorithmOnGpu(
7272
PrecisionConfig::Algorithm algorithm,
73-
stream_executor::GpuComputeCapability gpu_compute_capability,
73+
const stream_executor::GpuComputeCapability& gpu_compute_capability,
7474
PrimitiveType input_storage_type, PrimitiveType output_storage_type);
7575

7676
} // namespace algorithm_util

third_party/xla/xla/stream_executor/device_description.h

Lines changed: 86 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@ limitations under the License.
2020
#ifndef XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
2121
#define XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
2222

23+
#include <algorithm>
2324
#include <cassert>
2425
#include <cstdint>
26+
#include <cstring>
2527
#include <string>
2628
#include <type_traits>
2729
#include <utility>
2830
#include <variant>
2931
#include <vector>
3032

3133
#include "absl/algorithm/container.h"
34+
#include "absl/strings/match.h"
3235
#include "absl/strings/str_join.h"
3336
#include "absl/strings/str_split.h"
3437
#include "absl/strings/string_view.h"
@@ -54,13 +57,45 @@ class RocmComputeCapability {
5457

5558
std::string gcn_arch_name() const { return gcn_arch_name_; }
5659

60+
std::string ToString() const { return gcn_arch_name(); }
61+
62+
RocmComputeCapabilityProto ToProto() const {
63+
RocmComputeCapabilityProto proto;
64+
proto.set_gcn_arch_name(gcn_arch_name_);
65+
return proto;
66+
}
67+
68+
bool operator==(const RocmComputeCapability &other) const {
69+
return gcn_arch_name_ == other.gcn_arch_name_;
70+
}
71+
5772
std::string gfx_version() const {
58-
std::vector<std::string> tokens = absl::StrSplit(gcn_arch_name_, ':');
59-
return tokens[0];
73+
// std::strchr() is faster for the case than std::string::find()
74+
const char *const p_colon = std::strchr(gcn_arch_name_.c_str(), ':');
75+
if (nullptr == p_colon) {
76+
return gcn_arch_name_; // likely it's the default invalid value
77+
}
78+
return std::string(gcn_arch_name_.c_str(), p_colon);
6079
}
6180

81+
// note, while there's no particular reason to make the lists public, it won't
82+
// hurt since they are immutable, but keeping them close to methods simplifies
83+
// maintanance.
84+
static constexpr absl::string_view kSupportedGfxVersions[]{
85+
"gfx900", // MI25
86+
"gfx906", // MI50 / MI60
87+
"gfx908", // MI100
88+
"gfx90a", // MI200
89+
"gfx942", // MI300
90+
"gfx950", // MI350
91+
"gfx1030", // RX68xx / RX69xx
92+
"gfx1100", // RX7900
93+
"gfx1101", // RX7700 / RX7800
94+
"gfx1103", "gfx1150", "gfx1151", "gfx1200", "gfx1201",
95+
};
96+
6297
bool is_supported_gfx_version() const {
63-
return absl::c_count(kSupportedGfxVersions, gfx_version()) != 0;
98+
return IsThisGfxInAnyList(kSupportedGfxVersions);
6499
}
65100

66101
std::string supported_gfx_versions_str() const {
@@ -69,64 +104,73 @@ class RocmComputeCapability {
69104

70105
bool gfx9_mi100() const { return gfx_version() == "gfx908"; }
71106

107+
static constexpr absl::string_view kMI100Series[] = {"gfx908"};
108+
72109
bool gfx9_mi200() const { return gfx_version() == "gfx90a"; }
73110

111+
static constexpr absl::string_view kMI200Series[] = {"gfx90a"};
112+
74113
bool gfx9_mi300() const { return gfx_version() == "gfx942"; }
75114

76115
bool gfx9_mi350() const { return gfx_version() == "gfx950"; }
77116

117+
static constexpr absl::string_view kMI300Series[] = {"gfx942", "gfx950"};
118+
bool gfx9_mi300_series() const { return IsThisGfxInAnyList(kMI300Series); }
119+
78120
bool gfx9_mi100_or_later() const {
79-
static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx942", "gfx950"};
80-
return absl::c_count(kList, gfx_version()) != 0;
121+
return IsThisGfxInAnyList(kMI300Series, kMI200Series, kMI100Series);
81122
}
82123

83124
bool gfx9_mi200_or_later() const {
84-
static constexpr absl::string_view kList[] = {"gfx90a", "gfx942", "gfx950"};
85-
return absl::c_count(kList, gfx_version()) != 0;
125+
return IsThisGfxInAnyList(kMI300Series, kMI200Series);
86126
}
87127

88128
bool gfx10_rx68xx() const { return gfx_version() == "gfx1030"; }
89129

90130
bool gfx10_rx69xx() const { return gfx_version() == "gfx1030"; }
91131

92-
bool gfx11_rx7900() const { return (gfx_version() == "gfx1100" ||
93-
gfx_version() == "gfx1101" ||
94-
gfx_version() == "gfx1102"); }
132+
bool gfx11() const { return absl::StartsWith(gfx_version(), "gfx11"); }
95133

96-
bool gfx12_rx8900() const { return ((gfx_version() == "gfx1200") ||
97-
(gfx_version() == "gfx1201")); }
134+
static constexpr absl::string_view kGfx11Discrete[] = {"gfx1100", "gfx1101"};
135+
bool gfx11_discrete() const { return IsThisGfxInAnyList(kGfx11Discrete); }
98136

99-
bool gfx1200() const { return gfx_version() == "gfx1200"; }
137+
static constexpr absl::string_view kGfx11Apu[] = {"gfx1103", "gfx1150",
138+
"gfx1151"};
139+
bool gfx11_apu() const { return IsThisGfxInAnyList(kGfx11Apu); }
100140

101-
bool gfx1201() const { return gfx_version() == "gfx1201"; }
141+
bool gfx12() const { return absl::StartsWith(gfx_version(), "gfx12"); }
142+
143+
static constexpr absl::string_view kGfx12Discrete[] = {"gfx1200", "gfx1201"};
144+
bool gfx12_discrete() const { return IsThisGfxInAnyList(kGfx12Discrete); }
102145

103146
bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); }
104147

105-
bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); }
148+
bool has_bf16_dtype_support() const {
149+
return gfx9_mi100_or_later() || gfx12() || gfx11();
150+
}
106151

107152
bool has_fast_fp16_support() const {
108-
return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() ||
109-
gfx11_rx7900();
153+
return gfx9_mi100_or_later() || gfx11() || gfx10_rx68xx() || gfx10_rx69xx();
110154
}
111155

112156
bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); }
113157

114158
bool has_amd_matrix_core() const {
115-
return (gfx9_mi100_or_later() || gfx_version().find("gfx11") ||
116-
gfx_version().find("gfx12"));
159+
return gfx9_mi100_or_later() || gfx12() || gfx11();
117160
}
118161

119-
bool has_fp16_atomics_support() const {
120-
// TODO(rocm): Check. This should be the same as has_fast_fp16_support().
121-
return gfx9_mi200_or_later();
122-
}
162+
bool has_packed_fp16_atomics_support() const { return gfx9_mi100_or_later(); }
163+
164+
bool has_packed_bf16_atomics_support() const { return gfx9_mi300_series(); }
123165

124166
bool fence_before_barrier() const {
125-
return gfx_version() != "gfx900" && gfx_version() != "gfx906";
167+
static constexpr absl::string_view kList[] = {"gfx900", "gfx906"};
168+
return !IsThisGfxInAnyList(kList);
126169
}
127170

128171
bool has_hipblaslt() const {
129-
return gfx9_mi200_or_later() || gfx1200() || gfx1201();
172+
return IsThisGfxInAnyList(kMI300Series, kMI200Series, kGfx12Discrete,
173+
kGfx11Discrete, kGfx11Apu);
130174
}
131175

132176
bool has_hipblaslt_mx_support() const { return gfx9_mi350(); }
@@ -135,36 +179,31 @@ class RocmComputeCapability {
135179
return has_ocp_fp8_support() || has_nanoo_fp8_support();
136180
}
137181

138-
bool has_ocp_fp8_support() const { return gfx1200() || gfx1201() || gfx9_mi350(); }
182+
bool has_ocp_fp8_support() const { return gfx9_mi350() || gfx12_discrete(); }
139183

140184
bool has_nanoo_fp8_support() const { return gfx9_mi300(); }
141185

142-
std::string ToString() const { return gcn_arch_name(); }
143-
144-
RocmComputeCapabilityProto ToProto() const {
145-
RocmComputeCapabilityProto proto;
146-
proto.set_gcn_arch_name(gcn_arch_name_);
147-
return proto;
186+
private:
187+
/// \brief Takes one or more arrays of string-like objects and tests if the
188+
/// result of `gfx_version()` matches to any string in any of the arrays.
189+
template <typename... ArrayOfStrings>
190+
bool IsThisGfxInAnyList(ArrayOfStrings &&...arr) const {
191+
static_assert(sizeof...(arr) >= 1);
192+
const auto gfx = gfx_version();
193+
return (implIsThisGfxInAnyList(std::begin(arr), std::end(arr), gfx) || ...);
148194
}
149195

150-
bool operator==(const RocmComputeCapability &other) const {
151-
return gcn_arch_name_ == other.gcn_arch_name_;
196+
/// \brief Template-less implementation of IsThisGfxInAnyList().
197+
/// \warning Don't use directly!
198+
bool implIsThisGfxInAnyList(const absl::string_view *beg,
199+
const absl::string_view *end,
200+
const std::string &gfx) const {
201+
return std::any_of(beg, end, [&gfx = gfx](const absl::string_view &s) {
202+
return gfx == s;
203+
});
152204
}
153205

154-
private:
155206
std::string gcn_arch_name_ = "gfx000"; // default to invalid arch.
156-
157-
static constexpr absl::string_view kSupportedGfxVersions[]{
158-
"gfx900",
159-
"gfx906",
160-
"gfx908",
161-
"gfx90a",
162-
"gfx942",
163-
"gfx950",
164-
"gfx1030",
165-
"gfx1100", "gfx1101", "gfx1102",
166-
"gfx1200", "gfx1201",
167-
};
168207
};
169208

170209
using GpuComputeCapability =

0 commit comments

Comments
 (0)