@@ -78,20 +78,36 @@ struct FloatVec<bf16_8_t> {
 
 // Utility functions for type conversions.
 inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat1622float2(val);
+#endif
 }
 
 inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat162bfloat162(val);
+#endif
 }
 
 // Vector addition.
 inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return a + b;
+#endif
 }
 
 inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hadd2(a, b);
+#endif
 }
 
 inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
@@ -134,12 +150,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
 // Vector multiplication.
 template<>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul(a, b);
+#endif
 }
 
 template<>
 inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul2(a, b);
+#endif
 }
 
 template<>
@@ -244,11 +268,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
 
 // Vector fused multiply-add.
 inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(a, b, c);
+#endif
 }
 
 inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(bf162bf162(a), b, c);
+#endif
 }
 
 inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
@@ -361,19 +393,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
 }
 
 inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst = __float22bfloat162_rn(src);
+#endif
 }
 
 inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
+#endif
 }
 
 inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
   dst.z = __float22bfloat162_rn(src.z);
   dst.w = __float22bfloat162_rn(src.w);
+#endif
 }
 
 } // namespace cacheflow
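
Note on the recurring guard: __CUDA_ARCH__ is defined only during NVCC's device-compilation pass, so `defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800` selects the assert(false) fallback exactly when code is being generated for a pre-Ampere target (compute capability below 8.0), where these __nv_bfloat16 arithmetic and conversion intrinsics are unavailable; the host pass, where the macro is undefined, compiles the intrinsic branch harmlessly. A minimal self-contained sketch of the pattern, using a hypothetical guarded_hadd2 helper that is not part of cacheflow:

#include <cassert>
#include <cuda_bf16.h>

// Hypothetical stand-alone illustration of the guard used throughout the
// diff above; guarded_hadd2 is our name, not cacheflow's.
inline __device__ __nv_bfloat162 guarded_hadd2(__nv_bfloat162 a,
                                               __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // Only reached in device code compiled for targets below sm_80: trap at
  // runtime instead of referencing an intrinsic that does not exist there.
  assert(false);
  return __nv_bfloat162();  // unreachable; satisfies the return type
#else
  // sm_80+ device code (and the host pass, where __CUDA_ARCH__ is
  // undefined) gets the real bf16 intrinsic.
  return __hadd2(a, b);
#endif
}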
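
And a hedged sketch of how the patched helpers compose; the function name bf16_axpy is ours, and it assumes compilation inside the cacheflow namespace on an sm_80+ target so every guard takes its intrinsic branch:

// Hypothetical example, not from the PR: widen two float2 values to
// packed bf16 and accumulate their product into a bf16 running sum.
inline __device__ __nv_bfloat162 bf16_axpy(float2 a, float2 x,
                                           __nv_bfloat162 acc) {
  __nv_bfloat162 a_bf16, x_bf16;
  from_float(a_bf16, a);            // __float22bfloat162_rn under the guard
  from_float(x_bf16, x);
  return fma(a_bf16, x_bf16, acc);  // __hfma2 under the guard
}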