@@ -20,15 +20,18 @@ limitations under the License.
2020#ifndef XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
2121#define XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
2222
23+ #include < algorithm>
2324#include < cassert>
2425#include < cstdint>
26+ #include < cstring>
2527#include < string>
2628#include < type_traits>
2729#include < utility>
2830#include < variant>
2931#include < vector>
3032
3133#include " absl/algorithm/container.h"
34+ #include " absl/strings/match.h"
3235#include " absl/strings/str_join.h"
3336#include " absl/strings/str_split.h"
3437#include " absl/strings/string_view.h"
@@ -54,13 +57,45 @@ class RocmComputeCapability {
5457
5558 std::string gcn_arch_name () const { return gcn_arch_name_; }
5659
60+ std::string ToString () const { return gcn_arch_name (); }
61+
62+ RocmComputeCapabilityProto ToProto () const {
63+ RocmComputeCapabilityProto proto;
64+ proto.set_gcn_arch_name (gcn_arch_name_);
65+ return proto;
66+ }
67+
68+ bool operator ==(const RocmComputeCapability &other) const {
69+ return gcn_arch_name_ == other.gcn_arch_name_ ;
70+ }
71+
5772 std::string gfx_version () const {
58- std::vector<std::string> tokens = absl::StrSplit (gcn_arch_name_, ' :' );
59- return tokens[0 ];
73+ // std::strchr() is faster for the case than std::string::find()
74+ const char *const p_colon = std::strchr (gcn_arch_name_.c_str (), ' :' );
75+ if (nullptr == p_colon) {
76+ return gcn_arch_name_; // likely it's the default invalid value
77+ }
78+ return std::string (gcn_arch_name_.c_str (), p_colon);
6079 }
6180
81+ // note, while there's no particular reason to make the lists public, it won't
82+ // hurt since they are immutable, but keeping them close to methods simplifies
83+ // maintanance.
84+ static constexpr absl::string_view kSupportedGfxVersions []{
85+ " gfx900" , // MI25
86+ " gfx906" , // MI50 / MI60
87+ " gfx908" , // MI100
88+ " gfx90a" , // MI200
89+ " gfx942" , // MI300
90+ " gfx950" , // MI350
91+ " gfx1030" , // RX68xx / RX69xx
92+ " gfx1100" , // RX7900
93+ " gfx1101" , // RX7700 / RX7800
94+ " gfx1103" , " gfx1150" , " gfx1151" , " gfx1200" , " gfx1201" ,
95+ };
96+
6297 bool is_supported_gfx_version () const {
63- return absl::c_count (kSupportedGfxVersions , gfx_version ()) != 0 ;
98+ return IsThisGfxInAnyList (kSupportedGfxVersions ) ;
6499 }
65100
66101 std::string supported_gfx_versions_str () const {
@@ -69,64 +104,73 @@ class RocmComputeCapability {
69104
70105 bool gfx9_mi100 () const { return gfx_version () == " gfx908" ; }
71106
107+ static constexpr absl::string_view kMI100Series [] = {" gfx908" };
108+
72109 bool gfx9_mi200 () const { return gfx_version () == " gfx90a" ; }
73110
111+ static constexpr absl::string_view kMI200Series [] = {" gfx90a" };
112+
74113 bool gfx9_mi300 () const { return gfx_version () == " gfx942" ; }
75114
76115 bool gfx9_mi350 () const { return gfx_version () == " gfx950" ; }
77116
117+ static constexpr absl::string_view kMI300Series [] = {" gfx942" , " gfx950" };
118+ bool gfx9_mi300_series () const { return IsThisGfxInAnyList (kMI300Series ); }
119+
78120 bool gfx9_mi100_or_later () const {
79- static constexpr absl::string_view kList [] = {" gfx908" , " gfx90a" , " gfx942" , " gfx950" };
80- return absl::c_count (kList , gfx_version ()) != 0 ;
121+ return IsThisGfxInAnyList (kMI300Series , kMI200Series , kMI100Series );
81122 }
82123
83124 bool gfx9_mi200_or_later () const {
84- static constexpr absl::string_view kList [] = {" gfx90a" , " gfx942" , " gfx950" };
85- return absl::c_count (kList , gfx_version ()) != 0 ;
125+ return IsThisGfxInAnyList (kMI300Series , kMI200Series );
86126 }
87127
88128 bool gfx10_rx68xx () const { return gfx_version () == " gfx1030" ; }
89129
90130 bool gfx10_rx69xx () const { return gfx_version () == " gfx1030" ; }
91131
92- bool gfx11_rx7900 () const { return (gfx_version () == " gfx1100" ||
93- gfx_version () == " gfx1101" ||
94- gfx_version () == " gfx1102" ); }
132+ bool gfx11 () const { return absl::StartsWith (gfx_version (), " gfx11" ); }
95133
96- bool gfx12_rx8900 () const { return (( gfx_version () == " gfx1200 " ) ||
97- ( gfx_version () == " gfx1201 " ) ); }
134+ static constexpr absl::string_view kGfx11Discrete [] = { " gfx1100 " , " gfx1101 " };
135+ bool gfx11_discrete () const { return IsThisGfxInAnyList ( kGfx11Discrete ); }
98136
99- bool gfx1200 () const { return gfx_version () == " gfx1200" ; }
137+ static constexpr absl::string_view kGfx11Apu [] = {" gfx1103" , " gfx1150" ,
138+ " gfx1151" };
139+ bool gfx11_apu () const { return IsThisGfxInAnyList (kGfx11Apu ); }
100140
101- bool gfx1201 () const { return gfx_version () == " gfx1201" ; }
141+ bool gfx12 () const { return absl::StartsWith (gfx_version (), " gfx12" ); }
142+
143+ static constexpr absl::string_view kGfx12Discrete [] = {" gfx1200" , " gfx1201" };
144+ bool gfx12_discrete () const { return IsThisGfxInAnyList (kGfx12Discrete ); }
102145
103146 bool has_nhwc_layout_support () const { return gfx9_mi100_or_later (); }
104147
105- bool has_bf16_dtype_support () const { return gfx9_mi100_or_later (); }
148+ bool has_bf16_dtype_support () const {
149+ return gfx9_mi100_or_later () || gfx12 () || gfx11 ();
150+ }
106151
107152 bool has_fast_fp16_support () const {
108- return gfx9_mi100_or_later () || gfx10_rx68xx () || gfx10_rx69xx () ||
109- gfx11_rx7900 ();
153+ return gfx9_mi100_or_later () || gfx11 () || gfx10_rx68xx () || gfx10_rx69xx ();
110154 }
111155
112156 bool has_mfma_instr_support () const { return gfx9_mi100_or_later (); }
113157
114158 bool has_amd_matrix_core () const {
115- return (gfx9_mi100_or_later () || gfx_version ().find (" gfx11" ) ||
116- gfx_version ().find (" gfx12" ));
159+ return gfx9_mi100_or_later () || gfx12 () || gfx11 ();
117160 }
118161
119- bool has_fp16_atomics_support () const {
120- // TODO(rocm): Check. This should be the same as has_fast_fp16_support().
121- return gfx9_mi200_or_later ();
122- }
162+ bool has_packed_fp16_atomics_support () const { return gfx9_mi100_or_later (); }
163+
164+ bool has_packed_bf16_atomics_support () const { return gfx9_mi300_series (); }
123165
124166 bool fence_before_barrier () const {
125- return gfx_version () != " gfx900" && gfx_version () != " gfx906" ;
167+ static constexpr absl::string_view kList [] = {" gfx900" , " gfx906" };
168+ return !IsThisGfxInAnyList (kList );
126169 }
127170
128171 bool has_hipblaslt () const {
129- return gfx9_mi200_or_later () || gfx1200 () || gfx1201 ();
172+ return IsThisGfxInAnyList (kMI300Series , kMI200Series , kGfx12Discrete ,
173+ kGfx11Discrete , kGfx11Apu );
130174 }
131175
132176 bool has_hipblaslt_mx_support () const { return gfx9_mi350 (); }
@@ -135,36 +179,31 @@ class RocmComputeCapability {
135179 return has_ocp_fp8_support () || has_nanoo_fp8_support ();
136180 }
137181
138- bool has_ocp_fp8_support () const { return gfx1200 () || gfx1201 () || gfx9_mi350 (); }
182+ bool has_ocp_fp8_support () const { return gfx9_mi350 () || gfx12_discrete (); }
139183
140184 bool has_nanoo_fp8_support () const { return gfx9_mi300 (); }
141185
142- std::string ToString () const { return gcn_arch_name (); }
143-
144- RocmComputeCapabilityProto ToProto () const {
145- RocmComputeCapabilityProto proto;
146- proto.set_gcn_arch_name (gcn_arch_name_);
147- return proto;
186+ private:
187+ // / \brief Takes one or more arrays of string-like objects and tests if the
188+ // / result of `gfx_version()` matches to any string in any of the arrays.
189+ template <typename ... ArrayOfStrings>
190+ bool IsThisGfxInAnyList (ArrayOfStrings &&...arr) const {
191+ static_assert (sizeof ...(arr) >= 1 );
192+ const auto gfx = gfx_version ();
193+ return (implIsThisGfxInAnyList (std::begin (arr), std::end (arr), gfx) || ...);
148194 }
149195
150- bool operator ==(const RocmComputeCapability &other) const {
151- return gcn_arch_name_ == other.gcn_arch_name_ ;
196+ // / \brief Template-less implementation of IsThisGfxInAnyList().
197+ // / \warning Don't use directly!
198+ bool implIsThisGfxInAnyList (const absl::string_view *beg,
199+ const absl::string_view *end,
200+ const std::string &gfx) const {
201+ return std::any_of (beg, end, [&gfx = gfx](const absl::string_view &s) {
202+ return gfx == s;
203+ });
152204 }
153205
154- private:
155206 std::string gcn_arch_name_ = " gfx000" ; // default to invalid arch.
156-
157- static constexpr absl::string_view kSupportedGfxVersions []{
158- " gfx900" ,
159- " gfx906" ,
160- " gfx908" ,
161- " gfx90a" ,
162- " gfx942" ,
163- " gfx950" ,
164- " gfx1030" ,
165- " gfx1100" , " gfx1101" , " gfx1102" ,
166- " gfx1200" , " gfx1201" ,
167- };
168207};
169208
170209using GpuComputeCapability =
0 commit comments