diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 03bde86421..a9ae0b2a6a 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -105,7 +105,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() -list(APPEND gpu_list_tf32 gfx942) +list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 4f174bfcbb..791d81e264 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -21,7 +21,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() -list(APPEND gpu_list_tf32 gfx942) +list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index d82b56ec00..2a972c13eb 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -77,7 +77,7 @@ inline __host__ __device__ constexpr double get_atol() { if constexpr(std::is_same_v && std::is_same_v) { - return 1e-2; + return 1e-3; } else if constexpr(std::is_same_v) { diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 20cbc5fdca..20d9bab7e1 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -33,3 +33,13 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() + +list(APPEND gpu_list_tf32 gfx942 gfx950) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_grouped_gemm_xdl_fp32_tf32 grouped_gemm_xdl_fp32_tf32.cpp) 
+ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32_tf32) + set(target 1) + endif() +endforeach() diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp new file mode 100644 index 0000000000..c9a3ede151 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ComputeDataType = ck::tf32_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = 
ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc 
b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 87ccebc3c4..62f0f3673d 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -3,6 +3,11 @@ #pragma once +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + struct ProblemSize final { std::vector Ms; @@ -231,7 +236,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co AccDataType, AElementOp, BElementOp, - CDEElementOp>; + CDEElementOp, + ComputeDataType>; for(std::size_t i = 0; i < gemm_descs.size(); i++) { @@ -253,7 +259,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]); #else - pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); #endif } } diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 0c4f056a46..53f4c27399 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -129,7 +129,12 @@ inline bool is_wmma_supported() return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported(); } -inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? 
true : false; } +// Returns true when ck::get_device_name() reports "gfx942" or "gfx950". +inline bool is_tf32_supported() +{ + const auto device_name = ck::get_device_name(); // query the device name only once + return device_name == "gfx942" || device_name == "gfx950"; +} } // namespace ck #endif diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 3637053e14..fccd5c8e75 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -168,8 +168,8 @@ typename std::enable_if< check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-5, - double atol = 3e-5) + double rtol = 5e-4, + double atol = 5e-4) { if(out.size() != ref.size()) { diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp index 52632785bd..cf9992942e 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -94,7 +94,8 @@ template + typename CElementwiseOperation, + typename ComputeDataType = ADataType> struct DeviceGroupedGemm : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 7a1944cc68..0ae1aa321a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -134,7 +134,8 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename ComputeDataType = ADataType> struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + CDEElementwiseOperation, + ComputeDataType> { using DeviceOp = DeviceGroupedGemm_Xdl; GET_NXDL_PER_WAVE_IMPL @@ -233,8 +235,6 @@ struct 
DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); - using ComputeDataType = 
ADataType; - // GridwiseGemm template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index ce2d9299f9..0817cf9856 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -80,8 +80,10 @@ enum struct MfmaInstr mfma_f32_16x16x128f8f6f4, mfma_scale_f32_32x32x64f8f6f4, mfma_scale_f32_16x16x128f8f6f4, - mfma_f32_16x16x8xf32, // tf32 - mfma_f32_32x32x4xf32, + mfma_f32_16x16x8xf32, // tf32 on gfx942 + mfma_f32_32x32x4xf32, // tf32 on gfx942 + mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950 + mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950 // gfx11 wmma_f32_16x16x16_f16, wmma_f32_16x16x16_bf16, @@ -1015,6 +1017,51 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + // gfx950 specific: use bf16x3 simulate tf32 + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16xf32::Run(a, b, reg_c); + } +}; +template <> +struct mfma_type +{ + // gfx950 specific: use bf16x3 simulate tf32 + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t 
num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32xf32::Run(a, b, reg_c); + } +}; + // gfx11 struct mfma_type_gfx11_base { @@ -1275,12 +1322,14 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16xf32; #elif defined(__gfx942__) return MfmaInstr::mfma_f32_32x32x4xf32; #else @@ -1289,12 +1338,14 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32xf32; #elif defined(__gfx942__) return MfmaInstr::mfma_f32_16x16x8xf32; #else @@ -2185,6 +2236,10 @@ struct XdlopsGemm (is_same::value && KPack <= 8) || ((is_same::value || is_same::value) && KPack < 32) || is_same::value) +#if defined(__gfx950__) + // tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16. + || (is_same::value && KPack <= 4) +#endif ? 
true : false; static constexpr auto mfma = MfmaSelector +__device__ __forceinline__ void +convert_float_to_bf16_pairs(const vector_type& reg_f32, + vector_type& reg_bf16_big, + vector_type& reg_bf16_small) +{ + static_for<0, VecSize, 1>{}([&](auto k) { + using IK = Number; + reg_bf16_big.template AsType()(k) = + type_convert(reg_f32.template AsType()[IK{}]); + reg_bf16_small.template AsType()(k) = type_convert( + reg_f32.template AsType()[IK{}] - + type_convert(reg_bf16_big.template AsType()[IK{}])); + }); +} +/* */ + // fp32 template struct intrin_mfma_f32_32x32x1f32; @@ -1636,7 +1655,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> } }; -/******************* tf32 *************************************/ +/******************* tf32 on gfx942 *************************************/ template struct intrin_mfma_f32_16x16x8xf32; @@ -1646,7 +1665,7 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16> template __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { -#if defined(__gfx94__) +#if defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -1666,7 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> template __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { -#if defined(__gfx94__) +#if defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -1677,4 +1696,102 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> } }; +/******************* tf32/xf32 on gfx950 ********************************/ +/* bf16x3 simulate tf32/xf32: input/output/accumulator are all float; */ +/* step: */ +/* 1. separate one input to 2 bf16 registers: */ +/* in_bf16_big = f32_to_bf16(in_f32) */ +/* in_bf16_small = in_f32 - in_bf16_big */ +/* 2. run 3 xdlops gemm: the accumulator of each gemm is the same. 
*/ +/* out_f32 = A_bf16_big * B_bf16_big */ +/* out_f32 += A_bf16_small * B_bf16_big */ +/* out_f32 += A_bf16_big * B_bf16_small */ +/************************************************************************/ +template +struct intrin_mfma_f32_16x16x32xf32; + +template <> +struct intrin_mfma_f32_16x16x32xf32<16, 16> +{ + template + __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + using I0 = Number<0>; + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + vector_type v_reg_a_bf16_big; + vector_type v_reg_a_bf16_small; + vector_type v_reg_b_bf16_big; + vector_type v_reg_b_bf16_small; + + convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small); + convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small); + + // Run 3 times on the same reg_c, in call order: small*big, big*small, big*big + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( + v_reg_a_bf16_small.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_32x32x16xf32; + +template <> +struct intrin_mfma_f32_32x32x16xf32<32, 32> +{ + template + __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + using I0 = Number<0>; + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + vector_type v_reg_a_bf16_big; + vector_type v_reg_a_bf16_small; + vector_type v_reg_b_bf16_big; + vector_type v_reg_b_bf16_small; + + convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small); + 
convert_float_to_bf16_pairs(reg_b_v, 
v_reg_b_bf16_big, v_reg_b_bf16_small); + + // Run 3 times on the same reg_c, in call order: small*big, big*small, big*big + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( + v_reg_a_bf16_small.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +/******************* tf32/xf32 on gfx950 end ************************************/ } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 573571bc07..f47ce05cac 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -14,6 +14,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/host_utility/device_prop.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/fill.hpp" @@ -92,7 +94,8 @@ struct ReferenceConvFwd : public device::BaseOperator in_right_pads_{input_right_pads}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, - out_element_op_{out_element_op} + out_element_op_{out_element_op}, + device_name_{ck::get_device_name()} { } @@ -112,6 +115,7 @@ struct ReferenceConvFwd : public device::BaseOperator InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; + ::std::string device_name_; // the device which this conv is 
compared with }; struct Invoker : 
public device::BaseInvoker @@ -251,10 +255,39 @@ struct ReferenceConvFwd : public device::BaseOperator x); if constexpr(is_same_v) { - v_acc += ck::type_convert( - ck::type_convert(v_in)) * - ck::type_convert( - ck::type_convert(v_wei)); + if(arg.device_name_ == "gfx942") + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else if(arg.device_name_ == "gfx950") + { + ck::bhalf_t v_in_bf16_big = + ck::type_convert(v_in); + ck::bhalf_t v_in_bf16_small = + ck::type_convert( + v_in - type_convert(v_in_bf16_big)); + ck::bhalf_t v_wei_bf16_big = + ck::type_convert(v_wei); + ck::bhalf_t v_wei_bf16_small = + ck::type_convert( + v_wei - type_convert(v_wei_bf16_big)); + + v_acc += ck::type_convert(v_in_bf16_big) * + ck::type_convert(v_wei_bf16_small) + + ck::type_convert(v_in_bf16_small) * + ck::type_convert(v_wei_bf16_big) + + ck::type_convert(v_in_bf16_big) * + ck::type_convert(v_wei_bf16_big); + } + else + { + throw std::runtime_error( + "Unsupported device: " + arg.device_name_ + + " for tf32 computation"); + } } else { @@ -350,10 +383,41 @@ struct ReferenceConvFwd : public device::BaseOperator x); if constexpr(is_same_v) { - v_acc += ck::type_convert( - ck::type_convert(v_in)) * - ck::type_convert( - ck::type_convert(v_wei)); + if(arg.device_name_ == "gfx942") + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else if(arg.device_name_ == "gfx950") + { + ck::bhalf_t v_in_bf16_big = + ck::type_convert(v_in); + ck::bhalf_t v_in_bf16_small = + ck::type_convert( + v_in - type_convert(v_in_bf16_big)); + ck::bhalf_t v_wei_bf16_big = + ck::type_convert(v_wei); + ck::bhalf_t v_wei_bf16_small = + ck::type_convert( + v_wei - + type_convert(v_wei_bf16_big)); + + v_acc += + ck::type_convert(v_in_bf16_big) * + ck::type_convert(v_wei_bf16_small) + + ck::type_convert(v_in_bf16_small) * + ck::type_convert(v_wei_bf16_big) + + ck::type_convert(v_in_bf16_big) * + 
ck::type_convert(v_wei_bf16_big); + } + else + { + throw std::runtime_error( + "Unsupported device: " + arg.device_name_ + + " for tf32 computation"); + } } else { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 8b9b973b2d..c5afebf75d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -6,6 +6,7 @@ #include #include +#include "ck/host_utility/device_prop.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -45,7 +46,8 @@ struct ReferenceGemm : public device::BaseOperator c_m_n_{c_m_n}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - c_element_op_{c_element_op} + c_element_op_{c_element_op}, + device_name_{ck::get_device_name()} { } @@ -56,6 +58,7 @@ struct ReferenceGemm : public device::BaseOperator AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; + ::std::string device_name_; // the device which this gemm is compared with }; // Invoker @@ -142,12 +145,37 @@ struct ReferenceGemm : public device::BaseOperator arg.b_element_op_(v_b, arg.b_k_n_(k, n)); } - if constexpr(is_same_v && - is_same_v) - { // only for tf32 now - v_acc += - ck::type_convert(ck::type_convert(v_a)) * - ck::type_convert(ck::type_convert(v_b)); + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && + is_same_v) + { + if(arg.device_name_ == "gfx942") + { + v_acc += + ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); + } + else if(arg.device_name_ == "gfx950") + { + ck::bhalf_t v_a_bf16_big = ck::type_convert(v_a); + ck::bhalf_t v_a_bf16_small = ck::type_convert( + v_a - 
type_convert(v_a_bf16_big)); + ck::bhalf_t v_b_bf16_big = ck::type_convert(v_b); + ck::bhalf_t v_b_bf16_small = ck::type_convert( + v_b - type_convert(v_b_bf16_big)); + + v_acc += ck::type_convert(v_a_bf16_big) * + ck::type_convert(v_b_bf16_small) + + ck::type_convert(v_a_bf16_small) * + ck::type_convert(v_b_bf16_big) + + ck::type_convert(v_a_bf16_big) * + ck::type_convert(v_b_bf16_big); + } + else + { + throw std::runtime_error("Unsupported device: " + arg.device_name_); + } } else { diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index cf30bc7dda..1e024818c6 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -82,9 +82,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) // multiply and accumulate if constexpr(is_same_v && is_same_v) - { // only for tf32 now - v_acc += ck::type_convert(ck::type_convert(v_a)) * - ck::type_convert(ck::type_convert(v_b)); + { +#if defined(__gfx942__) + v_acc += ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); +#elif defined(__gfx950__) + ck::bhalf_t v_a_bf16_big = ck::type_convert(v_a); + ck::bhalf_t v_a_bf16_small = + ck::type_convert(v_a - type_convert(v_a_bf16_big)); + ck::bhalf_t v_b_bf16_big = ck::type_convert(v_b); + ck::bhalf_t v_b_bf16_small = + ck::type_convert(v_b - type_convert(v_b_bf16_big)); + + v_acc += ck::type_convert(v_a_bf16_big) * + ck::type_convert(v_b_bf16_small) + + ck::type_convert(v_a_bf16_small) * + ck::type_convert(v_b_bf16_big) + + ck::type_convert(v_a_bf16_big) * + ck::type_convert(v_b_bf16_big); +#else + v_acc += type_convert(v_a) * type_convert(v_b); +#endif } else { diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 13f5cd1cda..8400b020f7 100644 --- 
a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -105,7 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using INT8 = int8_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) using TF32 = ck::tf32_t; #endif @@ -228,7 +228,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -253,7 +253,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -280,7 +280,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -306,7 +306,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -331,7 +331,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -352,7 +352,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || 
defined(__gfx950__) return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -373,7 +373,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -416,7 +416,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -439,7 +439,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif }