diff --git a/paddle/fluid/operators/controlflow/feed_op.cc b/paddle/fluid/operators/controlflow/feed_op.cc index 7d0d899e8b6c3f..bacf4df7009157 100644 --- a/paddle/fluid/operators/controlflow/feed_op.cc +++ b/paddle/fluid/operators/controlflow/feed_op.cc @@ -153,6 +153,8 @@ class FeedOp : public framework::OperatorWithKernel { feed_sparse_tensor.coalesced()); out_var->GetMutable()->SetIndicesDict( feed_sparse_tensor.GetIndicesDict()); + out_var->GetMutable()->SetKmaps( + feed_sparse_tensor.GetKmaps()); } else { PADDLE_THROW( phi::errors::Unimplemented("Only support DenseTensor, Strings, and " diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu index 4d5917b451a818..1b9270b7835e7c 100644 --- a/paddle/fluid/operators/sync_batch_norm_op.cu +++ b/paddle/fluid/operators/sync_batch_norm_op.cu @@ -263,6 +263,7 @@ void SyncBatchNormCooKernel(const Context& dev_ctx, saved_variance, reserve_space); y->SetIndicesDict(x.GetIndicesDict()); + y->SetKmaps(x.GetKmaps()); } template diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index 56e952623a1500..c78c364e62632b 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -121,6 +121,15 @@ intermediate: rulebook, counter backward : conv3d_grad +- op : conv3d_implicit_gemm + args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key="") + output : Tensor(out) + infer_meta : + func : sparse::Conv3dImplicitGemmInferMeta + kernel : + func : conv3d_implicit_gemm{sparse_coo, dense -> sparse_coo} + layout : x + - op : divide args : (Tensor x, Tensor y) output : Tensor(out) diff --git a/paddle/phi/core/kmap_cache.h b/paddle/phi/core/kmap_cache.h new file mode 100644 index 00000000000000..186226edf19060 --- /dev/null +++ b/paddle/phi/core/kmap_cache.h @@ -0,0 +1,46 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +struct KmapCache { + DenseTensor* out_in_map = nullptr; + DenseTensor* coords = nullptr; + DenseTensor* hashmap_keys = nullptr; + DenseTensor* hashmap_values = nullptr; + // std::vector* spatial_range; + + // destructor + ~KmapCache() { + if (out_in_map) { + delete out_in_map; + } + if (coords) { + delete coords; + } + if (hashmap_keys) { + delete hashmap_keys; + } + if (hashmap_values) { + delete hashmap_values; + } + } +}; + +} // namespace phi diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index 61c8b0c3d2a5b0..c59d09f653513a 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kmap_cache.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" @@ -244,6 +245,43 @@ class SparseCooTensor : public TensorBase, indices_dict_ = indices_dict; } + /// \brief set kmaps_ pointer + KmapCache* SetKmapCache(const std::string& key, const KmapCache& kmap) { + if (kmaps_ == nullptr) { + kmaps_ = std::make_shared>(); + kmaps_->insert({key, kmap}); + } + return &kmaps_->at(key); + } + + void SetKmaps( + const std::shared_ptr>& kmaps) { + kmaps_ = kmaps; + } + + std::shared_ptr> GetKmaps() const { + return kmaps_; + } + + const KmapCache* GetKmapCache(const std::string& key) const { + if (kmaps_ == nullptr) { + return nullptr; + } + const auto& iter = kmaps_->find(key); + if (iter == kmaps_->end()) { + return nullptr; + } + return &iter->second; + } + + void ClearKmaps() { + if (kmaps_ != nullptr) { + // set shared_ptr to nullptr, + // if no other shared_ptr point to it, it will be released. + kmaps_ = nullptr; + } + } + private: friend class DenseTensorUtils; @@ -265,6 +303,9 @@ class SparseCooTensor : public TensorBase, std::shared_ptr>> indices_dict_ = nullptr; + // Sparse conv will generate a kmap, which can be reused. + std::shared_ptr> kmaps_ = nullptr; + /* --------------------------- */ /* example: non zero element is scalar */ /* --------------------------- */ diff --git a/paddle/phi/infermeta/sparse/binary.cc b/paddle/phi/infermeta/sparse/binary.cc index 2ed540c0e0c4db..930eefaff534db 100644 --- a/paddle/phi/infermeta/sparse/binary.cc +++ b/paddle/phi/infermeta/sparse/binary.cc @@ -121,6 +121,43 @@ void Conv3dInferMeta(const MetaTensor& x, counter->set_dims({1}); } +void Conv3dImplicitGemmInferMeta(const MetaTensor& x, + const MetaTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + const std::string& key, + MetaTensor* out) { + const auto& x_dims = x.dims(); + const bool is2D = x_dims.size() == 4 ? true : false; + const auto& kernel_dims = kernel.dims(); + + int rank = is2D ? 
4 : 5; + std::vector out_dims_vec(rank, 1); + DDim out_dims = common::make_ddim(out_dims_vec); + + std::vector kernel_sizes(kernel_dims.size()); + for (int i = 0; i < kernel_dims.size(); i++) { + kernel_sizes[i] = static_cast(kernel_dims[i]); + } + + std::vector subm_paddings(paddings), subm_strides(strides); + if (subm) { + // the out shape of subm_conv is same as input shape + // reset the padding=kernel_size/2 and strides=1 + ResetSubmKernelSizeAndStrides(kernel.dims(), &subm_paddings, &subm_strides); + } + + GetOutShape( + x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims); + + out->set_dtype(x.dtype()); + out->set_dims(out_dims); + out->set_layout(x.layout()); +} + inline const std::vector PoolResetKernel( const std::vector& kernel_sizes, const int in_channels, diff --git a/paddle/phi/infermeta/sparse/binary.h b/paddle/phi/infermeta/sparse/binary.h index a2c3e6fe5705c5..cc215b0d9dafd6 100644 --- a/paddle/phi/infermeta/sparse/binary.h +++ b/paddle/phi/infermeta/sparse/binary.h @@ -34,6 +34,16 @@ void Conv3dInferMeta(const MetaTensor& x, MetaTensor* rulebook, MetaTensor* counter); +void Conv3dImplicitGemmInferMeta(const MetaTensor& x, + const MetaTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + const std::string& key, + MetaTensor* out); + void Pool3dInferMeta(const MetaTensor& x, const std::vector& kernel_sizes, const std::vector& paddings, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 304fd3cef793a2..5b9ec9c129b4ba 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -42,6 +42,7 @@ file( if(APPLE OR WIN32) list(REMOVE_ITEM kernel_cu "fusion/gpu/fusion_group_kernel.cu") + list(REMOVE_ITEM kernel_cu "sparse/gpu/conv_kernel_igemm.cu") endif() if(NOT WITH_DGC) diff --git a/paddle/phi/kernels/funcs/sparse/convolution.h b/paddle/phi/kernels/funcs/sparse/convolution.h index e250973ba4543e..b4a831643b3f2c 100644 --- a/paddle/phi/kernels/funcs/sparse/convolution.h +++ b/paddle/phi/kernels/funcs/sparse/convolution.h @@ -15,7 +15,9 @@ limitations under the License. 
*/ #pragma once #include "paddle/common/ddim.h" +#include "paddle/phi/core/kmap_cache.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { diff --git a/paddle/phi/kernels/sparse/batch_norm_kernel.cc b/paddle/phi/kernels/sparse/batch_norm_kernel.cc index 04ab36892513cb..857d815c5c4815 100644 --- a/paddle/phi/kernels/sparse/batch_norm_kernel.cc +++ b/paddle/phi/kernels/sparse/batch_norm_kernel.cc @@ -59,6 +59,7 @@ void BatchNormCooKernel(const Context& dev_ctx, saved_variance, reserve_space); y->SetIndicesDict(x.GetIndicesDict()); + y->SetKmaps(x.GetKmaps()); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/elementwise_kernel.h b/paddle/phi/kernels/sparse/elementwise_kernel.h index fe2d22ed1072d4..4c5cf7ba8ba467 100644 --- a/paddle/phi/kernels/sparse/elementwise_kernel.h +++ b/paddle/phi/kernels/sparse/elementwise_kernel.h @@ -91,6 +91,7 @@ void ElementWiseAddDenseKernel(const Context& dev_ctx, EmptyLikeCooKernel(dev_ctx, x, out); phi::AddKernel(dev_ctx, x.values(), y, out->mutable_values()); out->SetIndicesDict(x.GetIndicesDict()); + out->SetKmaps(x.GetKmaps()); } else { PADDLE_THROW( errors::Unimplemented("Not support Sparse + Dense in GPU mode")); diff --git a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu index 67785d89505b4f..31d8780a750b0c 100644 --- a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu @@ -171,6 +171,7 @@ void CoalesceCooGPUKernel(const GPUContext& dev_ctx, out->SetMember(out_indices, out_values, x.dims(), true); out->SetIndicesDict(x.GetIndicesDict()); + out->SetKmaps(x.GetKmaps()); } template diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel_igemm.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel_igemm.cu new file mode 100644 index 00000000000000..1a3b867be48615 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/conv_kernel_igemm.cu @@ -0,0 +1,208 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
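As a minimal illustration (a sketch, not lines from this patch) of how the KmapCache struct added in paddle/phi/core/kmap_cache.h and the SetKmapCache/GetKmapCache/SetKmaps accessors added to SparseCooTensor are meant to be used together: look the cache up by `key` on the input tensor, attach a fresh entry to the output on a miss, and share the whole map on a hit. `ReuseOrBuildKmap` is a hypothetical helper name; in this PR the actual cache construction is performed by build_sparse_conv_kmap, which the kernel below calls.

#include <string>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kmap_cache.h"
#include "paddle/phi/core/sparse_coo_tensor.h"

// Sketch: reuse a cached out_in_map when an earlier submanifold conv with the
// same `key` already built one; otherwise create and attach a fresh entry.
void ReuseOrBuildKmap(const phi::SparseCooTensor& x,
                      const std::string& key,
                      phi::SparseCooTensor* out) {
  const phi::KmapCache* cache = x.GetKmapCache(key);
  if (cache == nullptr) {
    // First layer using `key`: insert an empty entry on the (freshly built)
    // output and let the kmap builder fill its members. The pointers are
    // owned by the cache and released in ~KmapCache.
    phi::KmapCache* fresh = out->SetKmapCache(key, phi::KmapCache());
    fresh->out_in_map = new phi::DenseTensor();  // hypothetical: would be built from x.indices()
  } else {
    // Cache hit: share the whole map so later layers keyed the same way
    // can find it on their input tensor.
    out->SetKmaps(x.GetKmaps());
  }
}

Keeping the cache behind a shared_ptr keyed by `key` is what lets the pass-through kernels touched earlier in this diff (batch_norm, sync_batch_norm, coalesce, elementwise add) simply forward it with out->SetKmaps(x.GetKmaps()).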
+ +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/sparse/convolution.h" +#include "paddle/phi/kernels/funcs/transpose_function.cu.h" +#include "paddle/phi/kernels/sparse/gpu/conv_kernel_impl.cuh" +#include "paddle/phi/kernels/sparse/gpu/sparse_conv_hashmap.cuh" + +#include "glog/logging.h" + +namespace phi { +namespace sparse { + +template +void Conv3dImplicitGemmGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + const std::string& key, + SparseCooTensor* out) { + // Currently, only support x.layout is NDHWC, subm = true, stride = 1, groups + // = 1, dilations = 1 + PADDLE_ENFORCE_EQ( + subm, + true, + phi::errors::InvalidArgument("The subm must be true, but received %s.", + subm ? "true" : "false")); + PADDLE_ENFORCE_EQ(groups, + 1, + phi::errors::InvalidArgument( + "The group must be 1, but received %d.", groups)); + + const auto& x_dims = x.dims(); + const auto& kernel_dims = kernel.dims(); + const bool is2D = x_dims.size() == 4 ? true : false; + + if (is2D) { + PADDLE_ENFORCE_EQ( + (kernel_dims.size() == 4), + true, + phi::errors::InvalidArgument( + "For 2D case, the size of kernel_dims must be 4, but received %d.", + kernel_dims.size())); + PADDLE_ENFORCE_EQ( + (strides.size() == 2 && strides[0] == 1 && strides[1] == 1), + true, + phi::errors::InvalidArgument( + "The strides must be 1, but received %d, %d.", + strides[0], + strides[1])); + PADDLE_ENFORCE_EQ( + (dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1), + true, + phi::errors::InvalidArgument( + "The dilations must be 1, but received %d, %d.", + dilations[0], + dilations[1])); + + } else { + PADDLE_ENFORCE_EQ( + (kernel_dims.size() == 5), + true, + phi::errors::InvalidArgument( + "For 3D case, the size of kernel_dims must be 5, but received %d.", + kernel_dims.size())); + PADDLE_ENFORCE_EQ((strides.size() == 3 && strides[0] == 1 && + strides[1] == 1 && strides[2] == 1), + true, + phi::errors::InvalidArgument( + "The strides must be 1, but received %d, %d, %d.", + strides[0], + strides[1], + strides[2])); + PADDLE_ENFORCE_EQ((dilations.size() == 3 && dilations[0] == 1 && + dilations[1] == 1 && dilations[2] == 1), + true, + phi::errors::InvalidArgument( + "The dilations must be 1, but received %d, %d, %d.", + dilations[0], + dilations[1], + dilations[2])); + } + + int kernel_volume = is2D ? kernel_dims[0] * kernel_dims[1] + : kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; + int in_channels = is2D ? kernel_dims[2] : kernel_dims[3]; + int out_channels = is2D ? kernel_dims[3] : kernel_dims[4]; + + int rank = is2D ? 
4 : 5; + std::vector out_dims_vec(rank, 1); + DDim out_dims = common::make_ddim(out_dims_vec); + + std::vector kernel_sizes(kernel_dims.size()); + for (int i = 0; i < kernel_dims.size(); i++) { + kernel_sizes[i] = kernel_dims[i]; + } + + std::vector subm_paddings(paddings), subm_strides(strides); + if (subm) { + // the out shape of subm_conv is same as input shape + // reset the padding=kernel_size/2 and strides=1 + phi::funcs::sparse::ResetSubmKernelSizeAndStrides( + kernel.dims(), &subm_paddings, &subm_strides); + } + + phi::funcs::sparse::GetOutShape( + x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims); + + // Set the output tensor + if (subm) { + DenseTensor out_indices = phi::EmptyLike(dev_ctx, x.indices()); + int tmpidx = is2D ? 3 : 4; + DenseTensor out_values = + phi::Empty(dev_ctx, {x.nnz(), kernel_sizes[tmpidx]}); + phi::Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, &out_indices); + out->SetMember(out_indices, out_values, out_dims, false); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "The subm must be true, but received %s.", subm ? "true" : "false")); + } + + build_sparse_conv_kmap( + dev_ctx, x, key, kernel_sizes, strides, kernel_volume, is2D, out); + + auto* out_kmap_cache_ptr = out->GetKmapCache(key); + + DenseTensor kernel_transpose = phi::EmptyLike(dev_ctx, kernel); + std::vector perm; + if (is2D) { + perm = {1, 0, 2, 3}; + } else { + perm = {2, 1, 0, 3, 4}; + } + phi::funcs::TransposeGPUKernelDriver( + dev_ctx, kernel, perm, &kernel_transpose); + + conv_forward_implicit_gemm_cuda(dev_ctx, + x.values(), + kernel_transpose, + *(out_kmap_cache_ptr->out_in_map), + out->nnz(), + out_channels, + *(out->mutable_values())); +} + +/** + * x: the input SparseCooTensor, shape is (N, D, H, W, C) + * kernel: the weight data, shape is (D, H, W, C, OC) + * out: the output SparseCooTensor, shape is (N, D, H, W, OC) + * rulebook: return rulebook if key is not vailed else return nullptr + * counter: return counter if key is not vailed else return nullptr + **/ +template +void Conv3dImplicitGemmKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + const std::string& key, + SparseCooTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES( + x.indices().dtype(), "Conv3dImplicitGemmGPUKernel", ([&] { + // Conv3dImplicitGemmGPUKernel(dev_ctx, + Conv3dImplicitGemmGPUKernel(dev_ctx, + x, + kernel, + paddings, + dilations, + strides, + groups, + subm, + key, + out); + })); +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(conv3d_implicit_gemm, + GPU, + ALL_LAYOUT, + phi::sparse::Conv3dImplicitGemmKernel, + float, + phi::dtype::float16) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel_impl.cuh b/paddle/phi/kernels/sparse/gpu/conv_kernel_impl.cuh new file mode 100644 index 00000000000000..33e5e3a54c1849 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/conv_kernel_impl.cuh @@ -0,0 +1,1273 @@ +#include +#include "paddle/phi/common/float16.h" +#include "paddle/phi/kernels/sparse/gpu/conv_memory_utils.cuh" + +// Pack two half values. 
+static inline __device__ __host__ unsigned +__pack_half2(const half x, const half y) +{ + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + + +// conv_forward_cuda_m128n16k16_m64n16k16_m16n16k16_f16f16f32 +template +__global__ void __launch_bounds__(64) conv_forward_cuda_setting1_mode0_f16f16f32(int M, int K_original, int N, int kernel_volume, half *__restrict__ A, half *__restrict__ B, int *__restrict__ out_in_map, half *__restrict__ C) +{ + // warning: kernel could not work with K_original < 32! + const int K_tile = 16; // min(16, K_original); + int K_tile_padded = K_tile * ((K_original + K_tile - 1) / K_tile); + int K_implicit = K_tile_padded * kernel_volume; + + float C_warp[32]; + __shared__ half A_shared[5120]; + __shared__ half B_shared[640]; + half A_shared_warp[32]; + half B_shared_warp[8]; + for (int i0_0_3_init = 0; i0_0_3_init < 4; ++i0_0_3_init) + { + for (int i = 0; i < 8; ++i) + { + C_warp[(i0_0_3_init * 8) + i] = 0.0; + }; + } + + int j_factors1 = (N + 15) / 16 / 1; + int *out_in_map_ptr = out_in_map + (blockIdx.x / j_factors1 * 128 + threadIdx.y * 16 + threadIdx.x / 2) * kernel_volume + ((threadIdx.y * 256) % 16) / K_tile_padded + ((threadIdx.x * 8) % 16) / K_tile_padded; + half *A_ptr = A + ((threadIdx.y * 256 % 16) % K_tile_padded) + ((threadIdx.x * 8 % 16) % K_tile_padded); + half *B_ptr = B + (blockIdx.x % j_factors1) * 16 + threadIdx.y * 256 / 16 * N + threadIdx.x * 8 / 16 * N + (threadIdx.x * 8) % 16; + int reorder_loc_offset = blockIdx.x / j_factors1 * 8 * 16 + (threadIdx.y % 2) * 4 * 16 + (threadIdx.x / 4); + half *C_ptr = C + + (blockIdx.x % j_factors1) * 16 + threadIdx.y / 2 * 16 + (threadIdx.x % 4) * 2; + + int A_ld_start, A_ld_amount, A_ld_bound, A_pred_guard; + int B_ld_start, B_ld_amount, B_ld_bound, B_pred_guard, B_ld_amount_N, B_ld_K_bound; + bool B_ld_K; + if constexpr (N_ld_check || K_ld_check) + { + B_ld_start = (blockIdx.x % j_factors1) * 16 + (threadIdx.x * 8) % 16; + B_ld_amount_N = max(0, min(B_ld_start + 8, N) - B_ld_start); + B_ld_K_bound = K_original; + } + else + B_pred_guard = 1; + + //+ (threadIdx.x / 4) * N; + for (int i2_0_0 = 0; i2_0_0 < K_implicit / K_tile; ++i2_0_0) + + { + + if constexpr (K_ld_check) + { + A_ld_start = (i2_0_0 * K_tile % K_tile_padded) + ((threadIdx.x * 8) % 16); + A_ld_amount = max(0, min(A_ld_start + 8, K_original) - A_ld_start); + A_ld_bound = A_ld_amount / (K_ld_factor / 2); + A_pred_guard = 0; + for (int i = 0; i < A_ld_bound; i++) + A_pred_guard |= (1 << i); + } + else + { + A_pred_guard = 1; + } + + if constexpr (K_ld_check || N_ld_check) + { + B_ld_K = ((i2_0_0 * K_tile % K_tile_padded) + threadIdx.x * 8 / 16) < B_ld_K_bound; + B_ld_amount = B_ld_amount_N * (int)B_ld_K; + B_ld_bound = B_ld_amount / (N_ld_factor / 2); + B_pred_guard = 0; + for (int i = 0; i < B_ld_bound; i++) + B_pred_guard |= (1 << i); + } + + int *out_in_map_ptr_local = out_in_map_ptr + i2_0_0 * K_tile / K_tile_padded; + half *A_ptr_local = A_ptr + (i2_0_0 * K_tile % K_tile_padded); + half *B_ptr_local; + if constexpr (K_ld_check) + B_ptr_local = B_ptr + (i2_0_0 * K_tile / K_tile_padded * K_original + i2_0_0 * K_tile % K_tile_padded) * N; + else + B_ptr_local = B_ptr + i2_0_0 * K_tile * N; + __syncthreads(); + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) + { + + int input_idx = out_in_map_ptr_local[ax0_ax1_fused_0 * 32 * kernel_volume + (ax0_ax1_fused_0 * 512 % 16) / K_tile_padded]; + + if (input_idx != -1) + { + uint4 A_loaded = make_uint4(0, 0, 0, 0); + 
global_load(A_loaded, A_ptr_local + input_idx * K_original + ((ax0_ax1_fused_0 * 512 % 16) % K_tile_padded), A_pred_guard); + *(uint4 *)(A_shared + ((((ax0_ax1_fused_0 * 1280) + (((int)threadIdx.y) * 640)) + ((((int)threadIdx.x) >> 1) * 40)) + ((((int)threadIdx.x) & 1) * 8))) = A_loaded; + } + else + { + *(uint4 *)(A_shared + ((((ax0_ax1_fused_0 * 1280) + (((int)threadIdx.y) * 640)) + ((((int)threadIdx.x) >> 1) * 40)) + ((((int)threadIdx.x) & 1) * 8))) = make_uint4(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f))); + } + } + + if (threadIdx.y == 0) + { + uint4 B_loaded = make_uint4(0, 0, 0, 0); + global_load(B_loaded, B_ptr_local, B_pred_guard); + *(uint4 *)(B_shared + (((((int)threadIdx.y) * 640) + ((((int)threadIdx.x) >> 1) * 40)) + ((((int)threadIdx.x) & 1) * 8))) = B_loaded; + } + + __syncthreads(); + __syncthreads(); + for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) + { + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" + : "=r"(addr) + : "l"((void *)((&(A_shared[((((int)threadIdx.y) * 2560) + (ax0_0 * 640))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))); +#if __CUDA_ARCH__ >= 750 + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3]) + : "r"(addr)); +#else + #pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + } + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" + : "=r"(addr) + : "l"((void *)((&(B_shared[0])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))); +#if __CUDA_ARCH__ >= 750 + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned *)(B_shared_warp + 0))[0]), "=r"(((unsigned *)(B_shared_warp + 0))[1]), "=r"(((unsigned *)(B_shared_warp + 0))[2]), "=r"(((unsigned *)(B_shared_warp + 0))[3]) + : "r"(addr)); +#else + #pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + for (int i0_0_3 = 0; i0_0_3 < 4; ++i0_0_3) + { +#if __CUDA_ARCH__ >= 800 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + 0))[0]), "r"(((unsigned *)(B_shared_warp + 0))[1]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, 
%6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + 4))[0]), "r"(((unsigned *)(B_shared_warp + 4))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3])); + } +#elif __CUDA_ARCH__ >= 750 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + 0))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + 4))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[0]), "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[1]), "r"(((unsigned *)(B_shared_warp + 2))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[0]), "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[1]), "r"(((unsigned *)(B_shared_warp + 6))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3])); + } +#else + 
#pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + } + for (int ax0_0_1 = 0; ax0_0_1 < 4; ++ax0_0_1) + { + + int reorder_loc_offset_local = reorder_loc_offset + ax0_0_1 * 16; + for (int local_id = 0; local_id < 8; ++local_id) + { + + int reorder_location_cur = reorder_loc_offset_local + (((local_id / 2) % 2) * 8); + if constexpr (N_ld_check) + { + bool C_wb_enable = ((blockIdx.x % j_factors1) * 16 + threadIdx.y / 2 * 16 + (threadIdx.x % 4) * 2 + (local_id % 2) + (local_id / 4) * 8) < N; + if (C_wb_enable && reorder_location_cur < M) + C_ptr[reorder_location_cur * N + + (local_id % 2) + (local_id / 4) * 8] = __float2half(C_warp[(ax0_0_1 * 8) + local_id]); + } + else + { + if (reorder_location_cur < M) + C_ptr[reorder_location_cur * N + + (local_id % 2) + (local_id / 4) * 8] = __float2half(C_warp[(ax0_0_1 * 8) + local_id]); + } + }; + } +} + +// conv_forward_cuda_m128n16k32_m64n16k32_m16n16k16_f16f16f32 +__global__ void __launch_bounds__(64) conv_forward_cuda_setting2_mode0_f16f16f32(int M, int K_original, int N, int kernel_volume, half *__restrict__ A, half *__restrict__ B, int *__restrict__ out_in_map, half *__restrict__ C) +{ + // warning: kernel could not work with K_original < 32! + int K_implicit = K_original * kernel_volume; + float C_warp[32]; + __shared__ half A_shared[5120]; + __shared__ half B_shared[1280]; + half A_shared_warp[32]; + half B_shared_warp[8]; + for (int i0_0_3_init = 0; i0_0_3_init < 4; ++i0_0_3_init) + { + for (int i = 0; i < 8; ++i) + { + C_warp[(i0_0_3_init * 8) + i] = 0.0; + }; + } + + // hoisting shared pointer offsets + int j_factors1 = N / 16 / 1; + int *out_in_map_ptr = out_in_map + (blockIdx.x / j_factors1 * 128 + threadIdx.y * 8 + threadIdx.x / 4) * kernel_volume + ((threadIdx.y * 256) % 32) / K_original + ((threadIdx.x * 8) % 32) / K_original; + half *A_ptr = A + ((threadIdx.y * 256 % 32) % K_original) + ((threadIdx.x * 8 % 32) % K_original); + half *B_ptr = B + (blockIdx.x % j_factors1) * 16 + threadIdx.y * 256 / 16 * N + threadIdx.x * 8 / 16 * N + (threadIdx.x * 8) % 16; + int reorder_loc_offset = blockIdx.x / j_factors1 * 8 * 16 + (threadIdx.y % 2) * 4 * 16 + (threadIdx.x / 4); + half *C_ptr = C + + (blockIdx.x % j_factors1) * 16 + threadIdx.y / 2 * 16 + (threadIdx.x % 4) * 2; + for (int i2_0_0 = 0; i2_0_0 < K_implicit / 32; ++i2_0_0) + + { + + int *out_in_map_ptr_local = out_in_map_ptr + i2_0_0 * 32 / K_original; + half *A_ptr_local = A_ptr + (i2_0_0 * 32 % K_original); + half *B_ptr_local = B_ptr + i2_0_0 * 32 * N; + __syncthreads(); + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) + { + + int input_idx = out_in_map_ptr_local[ax0_ax1_fused_0 * 16 * kernel_volume + (ax0_ax1_fused_0 * 512 % 32) / K_original]; + + if (input_idx != -1) + { + *(uint4 *)(A_shared + ((((ax0_ax1_fused_0 * 640) + (((int)threadIdx.y) * 320)) + ((((int)threadIdx.x) >> 2) * 40)) + ((((int)threadIdx.x) & 3) * 8))) = + *(uint4 *)(A_ptr_local + input_idx * K_original + ((ax0_ax1_fused_0 * 512 % 32) % K_original)); + } + else + { + *(uint4 *)(A_shared + ((((ax0_ax1_fused_0 * 640) + (((int)threadIdx.y) * 320)) + ((((int)threadIdx.x) >> 2) * 40)) + ((((int)threadIdx.x) & 3) * 8))) = make_uint4(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f))); + } + } + + *(uint4 *)(B_shared + 
(((((int)threadIdx.y) * 640) + ((((int)threadIdx.x) >> 1) * 40)) + ((((int)threadIdx.x) & 1) * 8))) = + *(uint4 *)(B_ptr_local); + + __syncthreads(); + for (int i2_0_1 = 0; i2_0_1 < 2; ++i2_0_1) + { + for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) + { + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" + : "=r"(addr) + : "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (i2_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))); +#if __CUDA_ARCH__ >= 750 + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3]) + : "r"(addr)); +#else + #pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + } + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" + : "=r"(addr) + : "l"((void *)((&(B_shared[(i2_0_1 * 640)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))); +#if __CUDA_ARCH__ >= 750 + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned *)(B_shared_warp + 0))[0]), "=r"(((unsigned *)(B_shared_warp + 0))[1]), "=r"(((unsigned *)(B_shared_warp + 0))[2]), "=r"(((unsigned *)(B_shared_warp + 0))[3]) + : "r"(addr)); +#else + #pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + for (int i0_0_3 = 0; i0_0_3 < 4; ++i0_0_3) + { + +#if __CUDA_ARCH__ >= 800 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + 0))[0]), "r"(((unsigned *)(B_shared_warp + 0))[1]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + 4))[0]), "r"(((unsigned *)(B_shared_warp + 4))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3])); + } +#elif __CUDA_ARCH__ >= 750 + { + __asm__ __volatile__( + 
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + 0))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[3])); + } + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + 4))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3])); + } + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "=f"(((float *)(C_warp + (i0_0_3 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[0]), "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[1]), "r"(((unsigned *)(B_shared_warp + 2))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[0]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[1]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[2]), "f"(((float *)(C_warp + (i0_0_3 * 8)))[3])); + } + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[0]), "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[1]), "r"(((unsigned *)(B_shared_warp + 6))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 8) + 4)))[3])); + } +#else + #pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + } + } + for (int ax0_0_1 = 0; ax0_0_1 < 4; ++ax0_0_1) + { + + int reorder_loc_offset_local = reorder_loc_offset + ax0_0_1 * 16; + for (int local_id = 0; local_id < 8; ++local_id) + { + + int reorder_location_cur = reorder_loc_offset_local + (((local_id / 2) % 2) * 8); + if (reorder_location_cur < M) + C_ptr[reorder_location_cur * N + + (local_id % 2) + (local_id / 4) * 8] = __float2half(C_warp[(ax0_0_1 * 8) + local_id]); + }; + } +} + +// conv_forward_cuda_m128n64k32_m64n32k32_m16n16k16_f16f16f32 +__global__ void __launch_bounds__(128) conv_forward_cuda_setting3_mode0_f16f16f32(int M, int K_original, int N, int kernel_volume, half *__restrict__ A, half *__restrict__ B, int *__restrict__ out_in_map, half *__restrict__ C) +{ + int K_implicit 
= K_original * kernel_volume; + float C_warp[64]; + __shared__ half A_shared[5120]; + __shared__ half B_shared[2304]; + half A_shared_warp[32]; + half B_shared_warp[16]; + for (int i0_0_3_init = 0; i0_0_3_init < 4; ++i0_0_3_init) + { + for (int i1_0_4_init = 0; i1_0_4_init < 2; ++i1_0_4_init) + { + for (int i = 0; i < 8; ++i) + { + C_warp[((i0_0_3_init * 16) + (i1_0_4_init * 8)) + i] = 0.0; + }; + } + } + + // hoisting shared pointer offsets + int j_factors1 = N / 16 / 4; + int *out_in_map_ptr = out_in_map + (blockIdx.x / j_factors1 * 128 + threadIdx.y * 8 + threadIdx.x / 4) * kernel_volume + ((threadIdx.y * 256) % 32) / K_original + ((threadIdx.x * 8) % 32) / K_original; + half *A_ptr = A + ((threadIdx.y * 256 % 32) % K_original) + ((threadIdx.x * 8 % 32) % K_original); + half *B_ptr = B + (blockIdx.x % j_factors1) * 64 + threadIdx.y * 256 / 64 * N + threadIdx.x * 8 / 64 * N + (threadIdx.x * 8) % 64; + int reorder_loc_offset = blockIdx.x / j_factors1 * 8 * 16 + (threadIdx.y % 2) * 4 * 16 + (threadIdx.x / 4); + half *C_ptr = C + + (blockIdx.x % j_factors1) * 64 + threadIdx.y / 2 * 32 + (threadIdx.x % 4) * 2; + + int B_kernel_offset = threadIdx.y * 256 / 64 + threadIdx.x * 8 / 64; + + for (int i2_0_0 = 0; i2_0_0 < K_implicit / 32; ++i2_0_0) + + { + + int *out_in_map_ptr_local = out_in_map_ptr + i2_0_0 * 32 / K_original; + half *A_ptr_local = A_ptr + (i2_0_0 * 32 % K_original); + half *B_ptr_local = B_ptr + i2_0_0 * 32 * N; + + __syncthreads(); + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) + { + + int input_idx = out_in_map_ptr_local[ax0_ax1_fused_0 * 32 * kernel_volume + (ax0_ax1_fused_0 * 1024 % 32) / K_original]; + + if (input_idx != -1) + { + *(uint4 *)(A_shared + ((((ax0_ax1_fused_0 * 1280) + (((int)threadIdx.y) * 320)) + ((((int)threadIdx.x) >> 2) * 40)) + ((((int)threadIdx.x) & 3) * 8))) = + *(uint4 *)(A_ptr_local + input_idx * K_original + ((ax0_ax1_fused_0 * 1024 % 32) % K_original)); + } + else + { + *(uint4 *)(A_shared + ((((ax0_ax1_fused_0 * 1280) + (((int)threadIdx.y) * 320)) + ((((int)threadIdx.x) >> 2) * 40)) + ((((int)threadIdx.x) & 3) * 8))) = make_uint4(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f))); + } + } + for (int ax0_ax1_fused_0_1 = 0; ax0_ax1_fused_0_1 < 2; ++ax0_ax1_fused_0_1) + { + // Shang: skip loading B + int B_kernel_offset_local = (B_kernel_offset + i2_0_0 * 32 + ax0_ax1_fused_0_1 * 1024 / 64) / K_original; + *(uint4 *)(B_shared + ((((ax0_ax1_fused_0_1 * 1152) + (((int)threadIdx.y) * 288)) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8))) = + *(uint4 *)(B_ptr_local + ax0_ax1_fused_0_1 * 1024 * N / 64); + } + __syncthreads(); + + for (int i2_0_1 = 0; i2_0_1 < 2; ++i2_0_1) + { + for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) + { + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" + : "=r"(addr) + : "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (i2_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))); +#if __CUDA_ARCH__ >= 750 + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 
8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3]) + : "r"(addr)); +#else + #pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + } + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) + { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((i2_0_1 * 1152) + ((((int)threadIdx.y) >> 1) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))); +#if __CUDA_ARCH__ >= 750 + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); +#else + #pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + } + for (int i0_0_3 = 0; i0_0_3 < 4; ++i0_0_3) + { + for (int i1_0_4 = 0; i1_0_4 < 2; ++i1_0_4) + { +#if __CUDA_ARCH__ >= 800 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (i1_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (i1_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + ((i1_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((i1_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[3])); + } +#elif __CUDA_ARCH__ >= 750 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 
8))))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (i1_0_4 * 8)))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[3])); + } + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i0_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + ((i1_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[3])); + } + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[3]) + : "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[0]), "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[1]), "r"(((unsigned *)(B_shared_warp + ((i1_0_4 * 8) + 2)))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i0_0_3 * 16) + (i1_0_4 * 8))))[3])); + } + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" + : "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[0]), "r"(((unsigned *)(A_shared_warp + ((i0_0_3 * 8) + 4)))[1]), "r"(((unsigned *)(B_shared_warp + ((i1_0_4 * 8) + 6)))[0]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i0_0_3 * 16) + (i1_0_4 * 8)) + 4)))[3])); + } +#else + #pragma message("FP16 kernels will not be compiled for SM75-.") +#endif + } + } + } + } + for (int ax0_0_1 = 0; ax0_0_1 < 4; ++ax0_0_1) + { + + int reorder_loc_offset_local = reorder_loc_offset + ax0_0_1 * 16; + for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) + { + for (int local_id = 0; local_id < 8; ++local_id) + { + + int reorder_location_cur = reorder_loc_offset_local + (((local_id / 2) % 2) * 8); + if (reorder_location_cur < M) + C_ptr[reorder_location_cur * N + //+ 
ax0_0_1 * N / 16 * 256 + + ax1_0_1 * 16 + //+ (((local_id / 2) % 2) * 8) * N + + (local_id % 2) + (local_id / 4) * 8] = __float2half(C_warp[((ax0_0_1 * 16) + (ax1_0_1 * 8)) + local_id]); + }; + } + } +} + +// conv_forward_cuda_m128n16k16_f32f32f32 +template +__global__ void __launch_bounds__(64) conv_forward_cuda_setting1_mode0_f32f32f32(int M, int K_original, int N, int kernel_volume, float* __restrict__ A, float* __restrict__ B, int* __restrict__ out_in_map, float* __restrict__ C) +{ + + const int K_tile = 16; + int K_tile_padded = K_tile * ((K_original + K_tile - 1) / K_tile); + int K_implicit = K_tile_padded * kernel_volume; + + float C_local[32]; + __shared__ float A_shared[2048]; + __shared__ float B_shared[256]; + + #pragma unroll + for (int i = 0; i < 32; ++i) + { + C_local[i] = 0.0; + } + + int K_loops = K_implicit / 16; + int block_num_n = (N - 1) / 16 + 1; + int blockIdx_m = (int)blockIdx.x / block_num_n; + int blockIdx_n = (int)blockIdx.x % block_num_n; + int threadIdx_x = (int)threadIdx.x; + + // hoisting shared pointer offsets + int * out_in_map_ptr = out_in_map + + (blockIdx_m * 128 + (threadIdx_x / (16/4)))* kernel_volume; + + float * B_ptr = B + + (threadIdx_x / (16/4)) * N + + (blockIdx_n * 16) + ((threadIdx_x * 4) % 16); + + float * A_shared_ptr = A_shared + (threadIdx_x * 4); + float * A_shared_reduce_ptr = A_shared + ((threadIdx_x / 4) * 16); + float * B_shared_ptr = B_shared + (threadIdx_x * 4); + float * B_shared_reduce_ptr = B_shared + (threadIdx_x % 4); + + int location_offset = blockIdx_m * 128 + (threadIdx_x / 4); // C_m_offset + int C_n_offset = blockIdx_n * 16 + (threadIdx_x % 4); + + int channel_offset_A = ((threadIdx_x * 4) % 16); + + int A_ld_start, A_ld_amount, A_ld_bound, A_pred_guard; + int B_ld_start, B_ld_amount, B_ld_bound, B_pred_guard, B_ld_amount_N, B_ld_K_bound; + bool B_ld_K; + if constexpr (N_ld_check || K_ld_check) + { + B_ld_start = (blockIdx_n * 16) + ((threadIdx_x * 4) % 16); + B_ld_amount_N = max(0, min(B_ld_start + 4, N) - B_ld_start); + B_ld_K_bound = K_original; + } + else + B_pred_guard = 1; + + #pragma unroll + for (int k_0 = 0; k_0 < K_loops; ++k_0) { + + { + if constexpr (K_ld_check) + { + A_ld_start = (k_0 * K_tile % K_tile_padded) + ((threadIdx.x * 4) % 16); // Channel_offset + A_ld_amount = max(0, min(A_ld_start + 4, K_original) - A_ld_start); + A_ld_bound = A_ld_amount / (K_ld_factor / 4); + A_pred_guard = 0; + for (int i = 0; i < A_ld_bound; i++) + A_pred_guard |= (1 << i); + } + else + { + A_pred_guard = 1; + } + + if constexpr (K_ld_check || N_ld_check) + { + B_ld_K = ((k_0 * K_tile % K_tile_padded) + threadIdx.x * 4 / 16) < B_ld_K_bound; + B_ld_amount = B_ld_amount_N * (int)B_ld_K; + B_ld_bound = B_ld_amount / (N_ld_factor / 4); + B_pred_guard = 0; + for (int i = 0; i < B_ld_bound; i++) + B_pred_guard |= (1 << i); + } + + int* out_in_map_ptr_local = out_in_map_ptr + k_0 * 16 / K_tile_padded; + float* A_ptr_local = A + (k_0 * 16 % K_tile_padded) + channel_offset_A; + + float* B_ptr_local; + if constexpr (K_ld_check) + B_ptr_local = B_ptr + (k_0 * K_tile / K_tile_padded * K_original + k_0 * K_tile % K_tile_padded) * N; + else + B_ptr_local = B_ptr + k_0 * K_tile * N; + + __syncthreads(); + #pragma unroll + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) + { + + int input_idx = *(out_in_map_ptr_local + (ax0_ax1_fused_0 *16) * kernel_volume); + if (input_idx != -1) + { + uint4 A_loaded = make_uint4(0, 0, 0, 0); + global_load(A_loaded, A_ptr_local + (input_idx * K_original) , A_pred_guard); + *(uint4 
*)(A_shared_ptr + (ax0_ax1_fused_0 * 256)) = A_loaded; + } + else + { + *(uint4*)(A_shared_ptr + (ax0_ax1_fused_0 * 256)) = make_uint4(0, 0, 0, 0); + } + } + + #pragma unroll + for (int ax0_ax1_fused_0_1 = 0; ax0_ax1_fused_0_1 < 1; ++ax0_ax1_fused_0_1) + { + uint4 B_loaded = make_uint4(0, 0, 0, 0); + global_load(B_loaded, B_ptr_local + (ax0_ax1_fused_0_1 * 16) * N, B_pred_guard); + *(uint4 *)(B_shared_ptr + (ax0_ax1_fused_0_1 * 256)) = B_loaded; + } + + __syncthreads(); + #pragma unroll + for (int k_1 = 0; k_1 < ( 16 / 4); ++k_1) + { + #pragma unroll + for (int k_2 = 0; k_2 < 4; ++k_2) + { + int vk_in_block = (k_1 << 2) + k_2; + #pragma unroll + for (int i = 0; i < 32; ++i) + { + C_local[i] = C_local[i] + + A_shared_reduce_ptr[((i / 4) * 16) * 16 + vk_in_block] + * B_shared_reduce_ptr[(vk_in_block * 16) + ((i % 4) * 4)]; + + } + } + } + } + } + + #pragma unroll + for (int i = 0; i < 32; ++i) + { + int location_cur = location_offset + ((i / 4) * 16); + int vn = C_n_offset + ((i % 4) * 4); + + if constexpr (N_ld_check) + { + if (vn < N && location_cur < M) + C[location_cur * N + vn] = C_local[i]; + } + else + { + if (location_cur < M) + C[location_cur * N + vn] = C_local[i]; + } + } +} + +// conv_forward_cuda_m128n16k32_f32f32f32 +__global__ void __launch_bounds__(64) conv_forward_cuda_setting2_mode0_f32f32f32(int M, int K_original, int N, int kernel_volume, float* __restrict__ A, float* __restrict__ B, int* __restrict__ out_in_map, float* __restrict__ C) +{ + float C_local[32]; + __shared__ float A_shared[4096]; + __shared__ float B_shared[512]; + + #pragma unroll + for (int i = 0; i < 32; ++i) + { + C_local[i] = 0.0; + } + + int K_loops = (K_original * kernel_volume - 1) / 32 + 1; + int block_num_n = (N - 1) / 16 + 1; + int blockIdx_m = (int)blockIdx.x / block_num_n; + int blockIdx_n = (int)blockIdx.x % block_num_n; + int threadIdx_x = (int)threadIdx.x; + + // hoisting shared pointer offsets + int * out_in_map_ptr = out_in_map + + (blockIdx_m * 128 + (threadIdx_x / (32/4)))* kernel_volume; + + float * B_ptr = B + + (threadIdx_x / (16/4)) * N + + (blockIdx_n * 16) + ((threadIdx_x * 4) % 16); + + float * A_shared_ptr = A_shared + (threadIdx_x * 4); + float * A_shared_reduce_ptr = A_shared + ((threadIdx_x / 4) * 32); + float * B_shared_ptr = B_shared + (threadIdx_x * 4); + float * B_shared_reduce_ptr = B_shared + (threadIdx_x % 4); + + int location_offset = blockIdx_m * 128 + (threadIdx_x / 4); // C_m_offset + int C_n_offset = blockIdx_n * 16 + (threadIdx_x % 4); + + int channel_offset_A = ((threadIdx_x * 4) % 32); // mod K_tile=32 + + #pragma unroll + for (int k_0 = 0; k_0 < K_loops; ++k_0) { + + int channel_offset = k_0 % (K_original / 32) * 32 + channel_offset_A; + int kernel_offset = k_0 / (K_original / 32); + int *out_in_map_ptr_k = out_in_map_ptr + kernel_offset; + + { + __syncthreads(); + #pragma unroll + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 16; ++ax0_ax1_fused_0) + { + + int input_idx = *(out_in_map_ptr_k + (ax0_ax1_fused_0 *8) * kernel_volume); + if (input_idx != -1) + { + + *(float4*)(A_shared_ptr + (ax0_ax1_fused_0 * 256)) = // ax0_ax1_fused_0 * elements loaded in each loop + *(float4*)(A + (input_idx * K_original) + channel_offset); + + } + else { + + *(float4*)(A_shared_ptr + (ax0_ax1_fused_0 * 256)) = make_float4(0.0, 0.0, 0.0, 0.0); + + } + } + + #pragma unroll + for (int ax0_ax1_fused_0_1 = 0; ax0_ax1_fused_0_1 < 2; ++ax0_ax1_fused_0_1) + { + + *(float4*)(B_shared_ptr + (ax0_ax1_fused_0_1 * 256)) = // ax0_ax1_fused_0_1 * elements loaded in each loop + 
*(float4*)(B_ptr + ((k_0 * 32) + (ax0_ax1_fused_0_1 * 16)) * N); + + } + + __syncthreads(); + #pragma unroll + for (int k_1 = 0; k_1 < ( 32 / 4); ++k_1) + { + #pragma unroll + for (int k_2 = 0; k_2 < 4; ++k_2) + { + int vk_in_block = (k_1 << 2) + k_2; + #pragma unroll + for (int i = 0; i < 32; ++i) + { + C_local[i] = C_local[i] + + A_shared_reduce_ptr[((i / 4) * 16) * 32 + vk_in_block] + * B_shared_reduce_ptr[(vk_in_block * 16) + ((i % 4) * 4)]; + + } + } + } + } + } + + #pragma unroll + for (int i = 0; i < 32; ++i) + { + int location_cur = location_offset + ((i / 4) * 16); + int vn = C_n_offset + ((i % 4) * 4); + if (location_cur < M) + C[location_cur * N + vn] = C_local[i]; + } +} + +// conv_forward_cuda_m128n64k32_f32f32f32 +__global__ void __launch_bounds__(128) conv_forward_cuda_setting3_mode0_f32f32f32(int M, int K_original, int N, int kernel_volume, float* __restrict__ A, float* __restrict__ B, int* __restrict__ out_in_map, float* __restrict__ C) +{ + float C_local[64]; + __shared__ float A_shared[4096]; + __shared__ float B_shared[2048]; + + #pragma unroll + for (int i = 0; i < 64; ++i) + { + C_local[i] = 0.0; + } + + int K_loops = (K_original * kernel_volume - 1) / 32 + 1; + int block_num_n = (N - 1) / 64 + 1; + int blockIdx_m = (int)blockIdx.x / block_num_n; + int blockIdx_n = (int)blockIdx.x % block_num_n; + int threadIdx_x = (int)threadIdx.x; + + // hoisting shared pointer offsets + int * out_in_map_ptr = out_in_map + + (blockIdx_m * 128 + (threadIdx_x / (32/4)))* kernel_volume; + + float * B_ptr = B + + (threadIdx_x / (64/4)) * N + + (blockIdx_n * 64) + ((threadIdx_x * 4) % 64); + + float * A_shared_ptr = A_shared + (threadIdx_x * 4); + float * A_shared_reduce_ptr = A_shared + ((threadIdx_x / 16) * 32); + float * B_shared_ptr = B_shared + (threadIdx_x * 4); + float * B_shared_reduce_ptr = B_shared + (threadIdx_x % 16); + + int location_offset = blockIdx_m * 128 + (threadIdx_x / 16); // C_m_offset + int C_n_offset = blockIdx_n * 64 + (threadIdx_x % 16); + + int channel_offset_A = ((threadIdx_x * 4) % 32); // mod K_tile=32 + + #pragma unroll + for (int k_0 = 0; k_0 < K_loops; ++k_0) { + + int channel_offset = k_0 % (K_original / 32) * 32 + channel_offset_A; + int kernel_offset = k_0 / (K_original / 32); + int *out_in_map_ptr_k = out_in_map_ptr + kernel_offset; + + { + __syncthreads(); + #pragma unroll + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) + { + + int input_idx = *(out_in_map_ptr_k + (ax0_ax1_fused_0 *16) * kernel_volume); + if (input_idx != -1) + { + + *(float4*)(A_shared_ptr + (ax0_ax1_fused_0 * 512)) = // ax0_ax1_fused_0 * elements loaded in each loop + *(float4*)(A + (input_idx * K_original) + channel_offset); + + } + else { + + *(float4*)(A_shared_ptr + (ax0_ax1_fused_0 * 512)) = make_float4(0.0, 0.0, 0.0, 0.0); + + } + } + + #pragma unroll + for (int ax0_ax1_fused_0_1 = 0; ax0_ax1_fused_0_1 < 4; ++ax0_ax1_fused_0_1) + { + + *(float4*)(B_shared_ptr + (ax0_ax1_fused_0_1 * 512)) = // ax0_ax1_fused_0_1 * elements loaded in each loop + *(float4*)(B_ptr + ((k_0 * 32) + (ax0_ax1_fused_0_1 * 8)) * N); + + } + + __syncthreads(); + #pragma unroll + for (int k_1 = 0; k_1 < ( 32 / 4); ++k_1) + { + #pragma unroll + for (int k_2 = 0; k_2 < 4; ++k_2) + { + int vk_in_block = (k_1 << 2) + k_2; + #pragma unroll + for (int i = 0; i < 64; ++i) + { + C_local[i] = C_local[i] + + A_shared_reduce_ptr[((i / 4) * 8) * 32 + vk_in_block] + * B_shared_reduce_ptr[(vk_in_block * 64) + ((i % 4) * 16)]; + + } + } + } + } + } + + #pragma unroll + for (int i = 0; i < 64; 
++i) + { + int location_cur = location_offset + ((i / 4) * 8); + int vn = C_n_offset + ((i % 4) * 16); + if (location_cur < M) + C[location_cur * N + vn] = C_local[i]; + } +} + + +void conv_forward_implicit_gemm_cuda( + const phi::GPUContext& dev_ctx, + const phi::DenseTensor& _in_feats, + const phi::DenseTensor& _kernel, + const phi::DenseTensor& _out_in_map, + int num_out_feats, int num_out_channels, + phi::DenseTensor& _out_feats) +{ + auto compute_capability = dev_ctx.GetComputeCapability(); + bool allow_fp16 = compute_capability >= 75; + bool is_half = _in_feats.dtype() == phi::DataType::FLOAT16; + + int num_in_feats = _in_feats.dims()[0]; + int num_in_channels = _in_feats.dims()[1]; + + int kernel_volume = _out_in_map.dims()[1]; + auto out_in_map = const_cast(_out_in_map.data()); + + if (is_half) + { + if (!allow_fp16) + { + throw std::runtime_error("FP16 kernels are not supported for implicit GEMM now for SM75-."); + } + auto in_feats = reinterpret_cast(const_cast(_in_feats.data())); + auto kernel = reinterpret_cast(const_cast(_kernel.data())); + auto out_feats = reinterpret_cast(_out_feats.data()); + + if (num_out_channels % 64 == 0 && num_in_channels % 32 == 0) + { + int j_factors1 = num_out_channels / 16 / 4; + dim3 num_blocks((num_out_feats + 127) / 128 * j_factors1); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 4); + conv_forward_cuda_setting3_mode0_f16f16f32<<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_in_channels % 32 == 0 && num_out_channels % 16 == 0) + { + int j_factors1 = num_out_channels / 16 / 1; + dim3 num_blocks((num_out_feats + 127) / 128 * j_factors1); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + conv_forward_cuda_setting2_mode0_f16f16f32<<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + // throw std::invalid_argument("IC is too small for this kernel"); + int j_factors1 = (num_out_channels + 15) / 16 / 1; + dim3 num_blocks((num_out_feats + 127) / 128 * j_factors1); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + if (num_in_channels % 16 == 0) + { + if (num_out_channels % 16 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 16, false, false><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 8 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 16, false, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 8, false, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 4, false, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 2, false, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + else if (num_in_channels % 8 == 0) + { + if (num_out_channels % 16 == 
0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 16, true, false><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 8 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 16, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 8, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 4, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f16f16f32<16, 2, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + else if (num_in_channels % 4 == 0) + { + if (num_out_channels % 16 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<8, 16, true, false><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 8 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<8, 16, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<8, 8, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<8, 4, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f16f16f32<8, 2, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + else if (num_in_channels % 2 == 0) + { + if (num_out_channels % 16 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<4, 16, true, false><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 8 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<4, 16, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<4, 8, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<4, 4, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f16f16f32<4, 2, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + else + { + if (num_out_channels % 16 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<2, 16, true, false><<>>( + 
_out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 8 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<2, 16, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<2, 8, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f16f16f32<2, 4, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f16f16f32<2, 2, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + } + } + else // fp32fp32fp32 + { + auto in_feats = const_cast(_in_feats.data()); + auto kernel = const_cast(_kernel.data()); + auto out_feats = _out_feats.data(); + + if (num_out_channels % 64 == 0 && num_in_channels % 32 == 0) + { + int block_num_M = (num_out_feats + 127) / 128; + int block_num_N = num_out_channels / 64; //j_factors1 + dim3 num_blocks(block_num_M * block_num_N); + dim3 threads_per_block(128); + conv_forward_cuda_setting3_mode0_f32f32f32<<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_in_channels % 32 == 0 && num_out_channels % 16 == 0) + { + int block_num_M = (num_out_feats + 127) / 128; + int block_num_N = num_out_channels / 16; //j_factors1 + dim3 num_blocks(block_num_M * block_num_N); + dim3 threads_per_block(64); + conv_forward_cuda_setting2_mode0_f32f32f32<<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + int block_num_M = (num_out_feats + 127) / 128; + int block_num_N = (num_out_channels + 15) / 16; //j_factors1 + dim3 num_blocks(block_num_M * block_num_N); + dim3 threads_per_block(64); + + if (num_in_channels % 16 == 0) + { + if (num_out_channels % 16 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<16, 16, false, false><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<16, 16, false, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<16, 8, false, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f32f32f32<16, 4, false, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + else if (num_in_channels % 4 == 0) + { + if (num_out_channels % 16 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<16, 16, true, false><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<16, 16, true, true><<>>( + 
_out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<16, 8, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f32f32f32<16, 4, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + else if (num_in_channels % 2 == 0) + { + if (num_out_channels % 16 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<8, 16, true, false><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<8, 16, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<8, 8, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f32f32f32<8, 4, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + else + { + if (num_out_channels % 16 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<4, 16, true, false><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 4 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<4, 16, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else if (num_out_channels % 2 == 0) + { + conv_forward_cuda_setting1_mode0_f32f32f32<4, 8, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + else + { + conv_forward_cuda_setting1_mode0_f32f32f32<4, 4, true, true><<>>( + _out_feats.dims()[0], num_in_channels, num_out_channels, kernel_volume, in_feats, kernel, out_in_map, out_feats); + } + } + } + } +} diff --git a/paddle/phi/kernels/sparse/gpu/conv_memory_utils.cuh b/paddle/phi/kernels/sparse/gpu/conv_memory_utils.cuh new file mode 100644 index 00000000000000..c9024e06b2b9af --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/conv_memory_utils.cuh @@ -0,0 +1,95 @@ +#pragma once + +template +struct global_load; + +template <> +struct global_load<16> +{ + __device__ __inline__ global_load(uint4 &D, void const *ptr, int pred_guard) + { + uint4 &data = *reinterpret_cast(&D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " mov.b32 %0, %6;\n" + " mov.b32 %1, %7;\n" + " mov.b32 %2, %8;\n" + " mov.b32 %3, %9;\n" + " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) + : "l"(ptr), "r"((int)(pred_guard & 1)), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); + } +}; + +template <> +struct global_load<8> +{ + __device__ __inline__ global_load(uint4 &D, void const *ptr, int pred_guard) + { + uint2 const *ptr_ldg = reinterpret_cast(ptr); +#pragma unroll + for (int ldg_idx = 0; ldg_idx < 2; ldg_idx++) + { + uint2 &data = *(reinterpret_cast(&D) + ldg_idx); 
+ asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + " mov.b32 %0, %4;\n" + " mov.b32 %1, %5;\n" + " @p ld.global.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data.x), "=r"(data.y) + : "l"(ptr_ldg + ldg_idx), "r"((int)(pred_guard & (1 << ldg_idx))), "r"(data.x), "r"(data.y)); + } + } +}; + +template <> +struct global_load<4> +{ + __device__ __inline__ global_load(uint4 &D, void const *ptr, int pred_guard) + { + unsigned const *ptr_ldg = reinterpret_cast(ptr); +#pragma unroll + for (int ldg_idx = 0; ldg_idx < 4; ldg_idx++) + { + unsigned &data = *(reinterpret_cast(&D) + ldg_idx); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b32 %0, %3;\n" + " @p ld.global.u32 %0, [%1];\n" + "}\n" + : "=r"(data) + : "l"(ptr_ldg + ldg_idx), "r"((int)(pred_guard & (1 << ldg_idx))), "r"(data)); + } + } +}; + +template <> +struct global_load<2> +{ + __device__ __inline__ global_load(uint4 &D, void const *ptr, int pred_guard) + { + uint16_t const *ptr_ldg = reinterpret_cast(ptr); +#pragma unroll + for (int ldg_idx = 0; ldg_idx < 8; ldg_idx++) + { + uint16_t &data = *(reinterpret_cast(&D) + ldg_idx); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b16 %0, %3;\n" + " @p ld.global.u16 %0, [%1];\n" + "}\n" + : "=h"(data) + : "l"(ptr_ldg + ldg_idx), "r"((int)(pred_guard & (1 << ldg_idx))), "h"(data)); + } + } +}; diff --git a/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu b/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu index 47daa1eae19eda..4b9337d5d6deb9 100644 --- a/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu @@ -59,6 +59,7 @@ void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx, phi::AddKernel( dev_ctx, x.values(), y.values(), out->mutable_values()); out->SetIndicesDict(x.GetIndicesDict()); + out->SetKmaps(x.GetKmaps()); } template diff --git a/paddle/phi/kernels/sparse/gpu/sparse_conv_hashmap.cuh b/paddle/phi/kernels/sparse/gpu/sparse_conv_hashmap.cuh new file mode 100644 index 00000000000000..73ad53de502da5 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/sparse_conv_hashmap.cuh @@ -0,0 +1,294 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/kernels/funcs/transpose_function.cu.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/kernels/funcs/math_function_impl.h" + +/** Reserved value for indicating "empty". */ +#define EMPTY_CELL (0) +/** CUDA naive thread block size. */ +#define BLOCK_SIZE (256) + +__inline__ __device__ int8_t atomicCAS(int8_t* address, int8_t compare, int8_t val) { + int32_t* base_address = (int32_t*)((char*)address - ((size_t)address & 3)); + int32_t int_val = (int32_t)val << (((size_t)address & 3) * 8); + int32_t int_comp = (int32_t)compare << (((size_t)address & 3) * 8); + return (int8_t)atomicCAS(base_address, int_comp, int_val); +} + +// TODO: can we do this more efficiently? 
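+// The 16-bit overload below uses the same trick as the 8-bit one above: CUDA
+// exposes no sub-word atomicCAS, so `compare` and `val` are shifted into their
+// position inside the enclosing aligned 32-bit word and a 32-bit atomicCAS is
+// issued on that word. The exchange therefore only succeeds while the
+// neighbouring bytes of that word are still zero, i.e. still hold EMPTY_CELL.
+// The 64-bit overload simply forwards to the `unsigned long long` atomicCAS.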
+__inline__ __device__ int16_t atomicCAS(int16_t* address, int16_t compare, int16_t val) { + int32_t* base_address = (int32_t*)((char*)address - ((size_t)address & 2)); + int32_t int_val = (int32_t)val << (((size_t)address & 2) * 8); + int32_t int_comp = (int32_t)compare << (((size_t)address & 2) * 8); + return (int16_t)atomicCAS(base_address, int_comp, int_val); +} + +__inline__ __device__ int64_t atomicCAS(int64_t* address, int64_t compare, int64_t val) { + return (int64_t)atomicCAS((unsigned long long*)address, (unsigned long long)compare, + (unsigned long long)val); +} + +template +__device__ uint64_t hash_func_64b(dtype* data, int n=4){ + uint64_t hash = 14695981039346656037UL; + for (int j = 0; j < n; j++) { + hash ^= (unsigned int)data[j]; + hash *= 1099511628211UL; + } + // hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); + return hash; +} + +template +__device__ int hash(key_type key, int _capacity){ + return (uint64_t)key % _capacity; +} + +template +class GPUHashTable { + private: + //public: + bool free_pointers; + const int _capacity; + const int _divisor; + const int _width; + key_type* table_keys; + val_type* table_vals; + void insert_many_coords(const phi::GPUContext& dev_ctx, const int *coords, const int n); + void lookup_many_coords(const phi::GPUContext& dev_ctx, const int *coords, val_type *results, + const int* kernel_sizes, const int* tensor_strides, + const int n, const int kernel_volume); + public: + GPUHashTable(phi::DenseTensor* table_keys, phi::DenseTensor* table_vals, const int divisor, const int width) + : _capacity(table_keys->dims()[0]), free_pointers(false), table_keys(table_keys->data()), + table_vals(table_vals->data()), _divisor(divisor), _width(width){}; + ~GPUHashTable() { + if(free_pointers){ + cudaFree(table_keys); + cudaFree(table_vals); + } + }; + void insert_coords(const phi::GPUContext& dev_ctx, const phi::DenseTensor& coords); + void lookup_coords(const phi::GPUContext& dev_ctx, const phi::DenseTensor& coords, const int* kernel_sizes, const int* tensor_strides, int kernel_volume, phi::DenseTensor* results); + int get_divisor(){return _divisor;} + int get_capacity(){return _capacity;} +}; + +using hashtable = GPUHashTable; +using hashtable32 = GPUHashTable; + +template +__global__ void insert_coords_kernel(key_type* table_keys, val_type* table_vals, const int* coords, int n, int _capacity, int _width) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) + { + key_type key = (key_type)(hash_func_64b(coords + idx*_width, _width)); + int value = idx + 1; + int slot = hash(key, _capacity); + while (true) + { + key_type prev = atomicCAS(&table_keys[slot], EMPTY_CELL, key); + if (prev == EMPTY_CELL || prev == key) + { + table_vals[slot] = value; + return; + } + slot = (slot + 1) % _capacity; + } + } +} + + +template +__global__ void lookup_coords_kernel( + key_type* table_keys, val_type* table_vals, const int* coords, val_type* vals, + const int* kernel_sizes, const int* strides, + int n, int _capacity, int kernel_volume, int _width) +{ + int tidx = blockIdx.x * blockDim.x + threadIdx.x; + int idx = tidx / kernel_volume; + int _kernel_idx = tidx % kernel_volume; + int kernel_idx = _kernel_idx; + const int* in_coords = coords + _width * idx; + int coords_out[4]; + //coords_out[2] = in_coords[2]; + //coords_out[3] = in_coords[3]; + coords_out[0] = in_coords[0]; + + if constexpr (odd) + { + #pragma unroll + for(int i = 0; i <= _width-2; i++){ + int cur_offset = _kernel_idx % kernel_sizes[i]; + cur_offset -= (kernel_sizes[i] - 1) / 2; + 
coords_out[i+1] = in_coords[i+1] * strides[i] + cur_offset; + _kernel_idx /= kernel_sizes[i]; + } + } + else + { + #pragma unroll + for(int i = _width-2; i >= 0; i--){ + int cur_offset = _kernel_idx % kernel_sizes[i]; + cur_offset -= (kernel_sizes[i] - 1) / 2; + coords_out[i+1] = in_coords[i+1] * strides[i] + cur_offset; + _kernel_idx /= kernel_sizes[i]; + } + } + + if (idx < n) + { + key_type key = (key_type)(hash_func_64b(coords_out, _width)); + int slot = hash(key, _capacity); + + while (true) + { + key_type cur_key = table_keys[slot]; + if (key == cur_key) + { + vals[idx * kernel_volume + kernel_idx] = table_vals[slot] - 1; // need to subtract 1 to avoid extra operations in python + } + if (table_keys[slot] == EMPTY_CELL) + { + return; + } + slot = (slot + 1) % _capacity; + } + } +} + +template +void GPUHashTable::insert_many_coords(const phi::GPUContext& dev_ctx, const int *coords, const int n){ + insert_coords_kernel<<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, dev_ctx.stream()>>>(table_keys, table_vals, coords, n, _capacity, _width); +} + +template +void GPUHashTable::insert_coords(const phi::GPUContext& dev_ctx, const phi::DenseTensor& coords){ + insert_many_coords(dev_ctx, coords.data(), coords.dims()[0]); +} + +template +void GPUHashTable::lookup_many_coords( + const phi::GPUContext& dev_ctx, + const int* coords, val_type* results, + const int* kernel_sizes, const int* strides, + const int n, const int kernel_volume){ + if (kernel_volume % 2) + lookup_coords_kernel<<<(n * kernel_volume + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, dev_ctx.stream()>>>( + table_keys, table_vals, coords, results, kernel_sizes, strides, + n, _capacity, kernel_volume, _width); + else + lookup_coords_kernel<<<(n * kernel_volume + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, dev_ctx.stream()>>>( + table_keys, table_vals, coords, results, kernel_sizes, strides, + n, _capacity, kernel_volume, _width); +} + +template +void GPUHashTable::lookup_coords( + const phi::GPUContext& dev_ctx, + const phi::DenseTensor& coords, + const int* kernel_sizes, + const int* strides, + const int kernel_volume, + phi::DenseTensor* results){ + int32_t* results_data = results->data(); + lookup_many_coords(dev_ctx, coords.data(), results_data, kernel_sizes, strides, coords.dims()[0], kernel_volume); +} + +template +void build_sparse_conv_kmap( + const phi::GPUContext& dev_ctx, + const phi::SparseCooTensor& x, + const std::string& key, + const std::vector& kernel_sizes, + const std::vector& strides, + const int kernel_volume, + const bool is2D, + phi::SparseCooTensor* out) +{ + int nnz = x.nnz(); + const phi::KmapCache* in_kmap_cache_ptr = x.GetKmapCache(key); + out->ClearKmaps(); + phi::KmapCache* out_kmap_cache_ptr = nullptr; + bool to_insert = false; + if (in_kmap_cache_ptr == nullptr) + { + phi::KmapCache kmap_cache; + out_kmap_cache_ptr = out->SetKmapCache(key, kmap_cache); + if (out_kmap_cache_ptr->hashmap_keys == nullptr) { + phi::DenseTensor* tmp_hashmap_keys = new phi::DenseTensor(); + tmp_hashmap_keys->Resize({2 * x.nnz()}); + dev_ctx.template Alloc(tmp_hashmap_keys); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, tmp_hashmap_keys, static_cast(0)); + out_kmap_cache_ptr->hashmap_keys = tmp_hashmap_keys; + to_insert = true; + } + if (out_kmap_cache_ptr->hashmap_values == nullptr) { + phi::DenseTensor* tmp_hashmap_values = new phi::DenseTensor(); + tmp_hashmap_values->Resize({2 * x.nnz()}); + dev_ctx.template Alloc(tmp_hashmap_values); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, 
tmp_hashmap_values, static_cast(0)); + out_kmap_cache_ptr->hashmap_values = tmp_hashmap_values; + } + + if (out_kmap_cache_ptr->coords == nullptr) { + phi::DenseTensor* tmp_indices = new phi::DenseTensor(); + tmp_indices->Resize({x.indices().dims()[1], x.indices().dims()[0]}); + dev_ctx.template Alloc(tmp_indices); + // transpose indices + std::vector perm = {1, 0}; + phi::funcs::TransposeGPUKernelDriver(dev_ctx, x.indices(), perm, tmp_indices); + out_kmap_cache_ptr->coords = tmp_indices; + } + + const int divisor = 128; + const int width = is2D ? 3 : 4; + auto hashmap = GPUHashTable(out_kmap_cache_ptr->hashmap_keys, out_kmap_cache_ptr->hashmap_values, divisor, width); + if (to_insert) { + hashmap.insert_coords(dev_ctx, *(out_kmap_cache_ptr->coords)); + } + + phi::DenseTensor* tmp_out_in_map = new phi::DenseTensor(); + tmp_out_in_map->Resize({(x.nnz() + divisor - 1) / divisor * divisor, kernel_volume}); + dev_ctx.template Alloc(tmp_out_in_map); + out_kmap_cache_ptr->out_in_map = tmp_out_in_map; + phi::funcs::SetConstant set_neg_one; + set_neg_one(dev_ctx, out_kmap_cache_ptr->out_in_map, static_cast(-1)); + + + // need to put kernel_sizes and strides to GPU + auto kernel_sizes_tensor = phi::Empty(dev_ctx, {3}); + phi::TensorFromVector(kernel_sizes, dev_ctx, &kernel_sizes_tensor); + auto strides_tensor = phi::Empty(dev_ctx, {3}); + phi::TensorFromVector(strides, dev_ctx, &strides_tensor); + + hashmap.lookup_coords( + dev_ctx, *(out_kmap_cache_ptr->coords), kernel_sizes_tensor.data(), strides_tensor.data(), kernel_volume, out_kmap_cache_ptr->out_in_map); + + } else { + // out tensor takes the kmaps from x + out->SetKmaps(x.GetKmaps()); + // force clear the kmaps of x + const_cast(x).ClearKmaps(); + } + const phi::KmapCache* new_out_kmap_cache_ptr = out->GetKmapCache(key); + assert(new_out_kmap_cache_ptr != nullptr); + assert(new_out_kmap_cache_ptr->hashmap_keys != nullptr); + assert(new_out_kmap_cache_ptr->hashmap_values != nullptr); + assert(new_out_kmap_cache_ptr->coords != nullptr); + assert(new_out_kmap_cache_ptr->out_in_map != nullptr); + return; +} diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h index 32fe4ae07ab67e..84cd885f862f0d 100644 --- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h +++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h @@ -38,6 +38,7 @@ namespace sparse { phi::prefix##Kernel( \ dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements()); \ out->SetIndicesDict(x.GetIndicesDict()); \ + out->SetKmaps(x.GetKmaps()); \ } \ \ template \ @@ -107,6 +108,7 @@ void ScaleCooKernel(const Context& dev_ctx, bias_after_scale, out->mutable_non_zero_elements()); out->SetIndicesDict(x.GetIndicesDict()); + out->SetKmaps(x.GetKmaps()); } template @@ -157,6 +159,7 @@ void CastCooKernel(const Context& dev_ctx, phi::CastKernel(dev_ctx, x_values, value_dtype, out_values); } out->SetIndicesDict(x.GetIndicesDict()); + out->SetKmaps(x.GetKmaps()); } template @@ -218,6 +221,7 @@ void IsnanCooKernel(const Context& dev_ctx, phi::IsnanKernel( dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements()); out->SetIndicesDict(x.GetIndicesDict()); + out->SetKmaps(x.GetKmaps()); } template diff --git a/python/paddle/sparse/nn/functional/__init__.py b/python/paddle/sparse/nn/functional/__init__.py index 5fc68de914bd50..93511f0972e9fc 100644 --- a/python/paddle/sparse/nn/functional/__init__.py +++ b/python/paddle/sparse/nn/functional/__init__.py @@ -13,7 +13,14 @@ # limitations under the License. 
from .activation import leaky_relu, relu, relu6, softmax -from .conv import conv2d, conv3d, subm_conv2d, subm_conv3d +from .conv import ( + conv2d, + conv3d, + subm_conv2d, + subm_conv2d_igemm, + subm_conv3d, + subm_conv3d_igemm, +) from .pooling import max_pool3d from .transformer import attention @@ -21,7 +28,9 @@ 'conv2d', 'conv3d', 'subm_conv2d', + 'subm_conv2d_igemm', 'subm_conv3d', + 'subm_conv3d_igemm', 'max_pool3d', 'relu', 'relu6', diff --git a/python/paddle/sparse/nn/functional/conv.py b/python/paddle/sparse/nn/functional/conv.py index b26faa9431d0e3..da961a1417ab29 100644 --- a/python/paddle/sparse/nn/functional/conv.py +++ b/python/paddle/sparse/nn/functional/conv.py @@ -192,6 +192,174 @@ def _conv2d( return pre_bias +def _conv3d_igemm( + x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + subm=False, + key=None, + data_format="NDHWC", + name=None, +): + assert groups == 1, "Currently, only support groups=1" + assert subm is True, "Currently, only support subm=True for implicit gemm" + + dims = 3 + + # Currently, only support 'NDHWC' + if data_format not in ["NDHWC"]: + raise ValueError( + "Attr(data_format) should be 'NDHWC'. Received " + f"Attr(data_format): {data_format}." + ) + if len(x.shape) != 5: + raise ValueError( + f"Input x should be 5D tensor, but received x with the shape of {x.shape}" + ) + + channel_last = data_format == "NDHWC" + channel_dim = -1 if channel_last else 1 + if len(x.shape) != 5: + raise ValueError( + f"Input x should be 5D tensor, but received x with the shape of {x.shape}" + ) + num_channels = x.shape[channel_dim] + if num_channels < 0: + raise ValueError( + f"The channel dimension of the input({x.shape}) should be defined. " + f"Received: {num_channels}." + ) + + padding, padding_algorithm = _update_padding_nd(padding, channel_last, dims) + stride = convert_to_list(stride, dims, 'stride') + dilation = convert_to_list(dilation, dims, 'dilation') + + if in_dynamic_mode(): + pre_bias = _C_ops.sparse_conv3d_implicit_gemm( + x, + weight, + padding, + dilation, + stride, + groups, + subm, + key if key is not None else "", + ) + if bias is not None: + return add(pre_bias, bias) + else: + return pre_bias + else: + inputs = {'x': x, 'kernel': weight} + attrs = { + 'paddings': padding, + 'dilations': dilation, + 'strides': stride, + 'groups': groups, + 'subm': subm, + 'key': key, + } + op_type = 'sparse_conv3d_implicit_gemm' + helper = LayerHelper(op_type, **locals()) + pre_bias = helper.create_sparse_variable_for_type_inference(x.dtype) + outputs = {"out": pre_bias} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs + ) + if bias is not None: + return add(pre_bias, bias) + else: + return pre_bias + + +def _conv2d_igemm( + x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + subm=False, + key=None, + data_format="NHWC", + name=None, +): + assert groups == 1, "Currently, only support groups=1" + assert subm is True, "Currently, only support subm=True for implicit gemm" + + dims = 2 + + # Currently, only support 'NDHWC' + if data_format not in ["NHWC"]: + raise ValueError( + "Attr(data_format) should be 'NHWC'. Received " + f"Attr(data_format): {data_format}." 
+ ) + if len(x.shape) != 4: + raise ValueError( + f"Input x should be 5D tensor, but received x with the shape of {x.shape}" + ) + + channel_last = data_format == "NHWC" + channel_dim = -1 if channel_last else 1 + if len(x.shape) != 4: + raise ValueError( + f"Input x should be 4D tensor, but received x with the shape of {x.shape}" + ) + num_channels = x.shape[channel_dim] + if num_channels < 0: + raise ValueError( + f"The channel dimension of the input({x.shape}) should be defined. " + f"Received: {num_channels}." + ) + + padding, padding_algorithm = _update_padding_nd(padding, channel_last, dims) + stride = convert_to_list(stride, dims, 'stride') + dilation = convert_to_list(dilation, dims, 'dilation') + + if in_dynamic_mode(): + pre_bias = _C_ops.sparse_conv3d_implicit_gemm( + x, + weight, + padding, + dilation, + stride, + groups, + subm, + key if key is not None else "", + ) + if bias is not None: + return add(pre_bias, bias) + else: + return pre_bias + else: + inputs = {'x': x, 'kernel': weight} + attrs = { + 'paddings': padding, + 'dilations': dilation, + 'strides': stride, + 'groups': groups, + 'subm': subm, + 'key': key, + } + op_type = 'sparse_conv3d_implicit_gemm' + helper = LayerHelper(op_type, **locals()) + pre_bias = helper.create_sparse_variable_for_type_inference(x.dtype) + outputs = {"out": pre_bias} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs + ) + if bias is not None: + return add(pre_bias, bias) + else: + return pre_bias + + def conv3d( x, weight, @@ -410,6 +578,118 @@ def subm_conv3d( ) +def subm_conv3d_igemm( + x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + data_format="NDHWC", + key=None, + name=None, +): + r""" + + The sparse submanifold convolution3d functional calculates the output based on the input, filter + and strides, paddings, dilations, groups parameters. Input(Input) and + Output(Output) are multidimensional SparseCooTensors with a shape of + :math:`[N, D, H, W, C]` . Where N is batch size, C is the number of + channels, D is the depth of the feature, H is the height of the feature, + and W is the width of the feature. If bias attribution is provided, + bias is added to the output of the convolution. + + For each input :math:`X`, the equation is: + + .. math:: + + Out = W \ast X + b + + In the above equation: + + * :math:`X`: Input value, a tensor with NCDHW or NDHWC format. + * :math:`W`: Filter value, a tensor with DHWCM format. + * :math:`\\ast`: Submanifold Convolution operation, refer to the paper: https://arxiv.org/abs/1706.01307. + * :math:`b`: Bias value, a 1-D tensor with shape [M]. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + + Args: + x (Tensor): The input is 5-D SparseCooTensor with shape [N, D, H, W, C], the data + type of input is float16 or float32 or float64. + weight (Tensor): The convolution kernel, a Tensor with shape [kD, kH, kW, C/g, M], + where M is the number of filters(output channels), g is the number of groups, + kD, kH, kW are the filter's depth, height and width respectively. + bias (Tensor, optional): The bias, a Tensor of shape [M]. + stride (int|list|tuple, optional): The stride size. It means the stride in convolution. If stride is a + list/tuple, it must contain three integers, (stride_depth, stride_height, stride_width). + Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1. + padding (string|int|list|tuple): The padding size. 
It means the number of zero-paddings + on both sides for each dimension. If `padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If padding size is a tuple or list, + it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `"NCDHW"`, `padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NHWC"`, `padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Default: padding = 0. + dilation (int|list|tuple, optional): The dilation size. It means the spacing between the kernel points. + If dilation is a list/tuple, it must contain three integers, (dilation_depth, dilation_height, + dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation. + Default: dilation = 1. + groups (int, optional): The groups number of the Conv3D Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Currently, only support groups=1. + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`. + The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of: + `[batch_size, input_depth, input_height, input_width, input_channels]`. + key(str, optional): the key is used to save or use the same rulebook, + the definition and role of rulebook refers to + https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf. The + default value is None. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + A SparseCooTensor representing the conv3d, whose data type is + the same with input. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + >>> values = [[1], [2], [3], [4]] + >>> indices = paddle.to_tensor(indices, dtype='int32') + >>> values = paddle.to_tensor(values, dtype='float32') + >>> dense_shape = [1, 1, 3, 4, 1] + >>> sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True) + >>> weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32') + >>> y = paddle.sparse.nn.functional.subm_conv3d(sparse_x, weight) + >>> print(y.shape) + [1, 1, 3, 4, 1] + """ + return _conv3d_igemm( + x, + weight, + bias, + stride, + padding, + dilation, + groups, + True, + key, + data_format, + name, + ) + + def conv2d( x, weight, @@ -621,3 +901,112 @@ def subm_conv2d( data_format, name, ) + + +def subm_conv2d_igemm( + x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + data_format="NHWC", + key=None, + name=None, +): + r""" + + The sparse submanifold convolution2d functional calculates the output based on the input, filter + and strides, paddings, dilations, groups parameters. 
Input(Input) and + Output(Output) are multidimensional SparseCooTensors with a shape of + :math:`[N, H, W, C]` . Where N is batch size, C is the number of + channels, H is the height of the feature, + and W is the width of the feature. If bias attribution is provided, + bias is added to the output of the convolution. + + For each input :math:`X`, the equation is: + + .. math:: + + Out = \sigma (W \ast X + b) + + In the above equation: + + * :math:`X`: Input value, a tensor with NHWC format. + * :math:`W`: Filter value, a tensor with HWCM format. + * :math:`\\ast`: Submanifold Convolution operation, refer to the paper: https://arxiv.org/abs/1706.01307. + * :math:`b`: Bias value, a 1-D tensor with shape [M]. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + + Args: + x (Tensor): The input is 4-D SparseCooTensor with shape [N, H, W, C], the data + type of input is float16 or float32 or float64. + weight (Tensor): The convolution kernel, a Tensor with shape [kH, kW, C/g, M], + where M is the number of filters(output channels), g is the number of groups, + kD, kH, kW are the filter's height and width respectively. + bias (Tensor, optional): The bias, a Tensor of shape [M]. + stride (int|list|tuple, optional): The stride size. It means the stride in convolution. If stride is a + list/tuple, it must contain two integers, (stride_height, stride_width). + Otherwise, stride_height = stride_width = stride. Default: stride = 1. + padding (string|int|list|tuple, optional): The padding size. It means the number of zero-paddings + on both sides for each dimension. If `padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If padding size is a tuple or list, + it could be in three forms: `[pad_height, pad_width]` or + `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + when `data_format` is `"NHWC"`, `padding` can be in the form + `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Default: padding = 0. + dilation (int|list|tuple, optional): The dilation size. It means the spacing between the kernel points. + If dilation is a list/tuple, it must contain two integers, (dilation_height, + dilation_width). Otherwise, dilation_height = dilation_width = dilation. + Default: dilation = 1. + groups (int, optional): The groups number of the Conv2D Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Default: groups=1. Currently, only support groups=1. + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: `"NHWC"`. + The default is `"NHWC"`. When it is `"NHWC"`, the data is stored in the order of: + `[batch_size, input_height, input_width, input_channels]`. + key(str, optional): the key is used to save or use the same rulebook, + the definition and role of rulebook refers to + https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf. The + default value is None. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + A SparseCooTensor representing the conv2d, whose data type is the same with input. + + Examples: + .. 
code-block:: python + + >>> import paddle + + >>> indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + >>> values = [[1], [2], [3], [4]] + >>> indices = paddle.to_tensor(indices, dtype='int32') + >>> values = paddle.to_tensor(values, dtype='float32') + >>> dense_shape = [1, 3, 4, 1] + >>> sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True) + >>> weight = paddle.randn((3, 3, 1, 1), dtype='float32') + >>> y = paddle.sparse.nn.functional.subm_conv2d(sparse_x, weight) + >>> print(y.shape) + [1, 3, 4, 1] + """ + return _conv2d_igemm( + x, + weight, + bias, + stride, + padding, + dilation, + groups, + True, + key, + data_format, + name, + ) diff --git a/python/paddle/sparse/nn/layer/conv.py b/python/paddle/sparse/nn/layer/conv.py index 62cf355de2e3dc..f38b3c64593f22 100644 --- a/python/paddle/sparse/nn/layer/conv.py +++ b/python/paddle/sparse/nn/layer/conv.py @@ -40,6 +40,7 @@ def __init__( weight_attr=None, bias_attr=None, data_format="NDHWC", + backend=None, ): super().__init__() assert ( @@ -53,11 +54,16 @@ def __init__( self._data_format = data_format self._subm = subm self._key = key + self._backend = backend assert ( padding_mode == 'zeros' ), "Currently, only support padding_mode='zeros'" assert groups == 1, "Currently, only support groups=1" + assert backend in [ + None, + 'igemm', + ], "The value of 'backend' in Conv3D should be None or 'igemm'." valid_format = {'NDHWC'} if data_format not in valid_format: @@ -98,18 +104,36 @@ def _get_default_param_initializer(): ) def forward(self, x): - out = F.conv._conv3d( - x, - self.weight, - bias=self.bias, - stride=self._stride, - padding=self._updated_padding, - dilation=self._dilation, - groups=self._groups, - subm=self._subm, - key=self._key, - data_format=self._data_format, - ) + if self._backend is None: + out = F.conv._conv3d( + x, + self.weight, + bias=self.bias, + stride=self._stride, + padding=self._updated_padding, + dilation=self._dilation, + groups=self._groups, + subm=self._subm, + key=self._key, + data_format=self._data_format, + ) + elif self._backend == 'igemm': + out = F.conv._conv3d_igemm( + x, + self.weight, + bias=self.bias, + stride=self._stride, + padding=self._updated_padding, + dilation=self._dilation, + groups=self._groups, + subm=self._subm, + key=self._key, + data_format=self._data_format, + ) + else: + raise ValueError( + f"The value of 'backend' in Conv3D should be None or 'igemm', but got {self._backend}." + ) return out def extra_repr(self): @@ -144,6 +168,7 @@ def __init__( weight_attr=None, bias_attr=None, data_format="NHWC", + backend=None, ): super().__init__() assert ( @@ -157,11 +182,16 @@ def __init__( self._data_format = data_format self._subm = subm self._key = key + self._backend = backend assert ( padding_mode == 'zeros' ), "Currently, only support padding_mode='zeros'" assert groups == 1, "Currently, only support groups=1" + assert backend in [ + None, + 'igemm', + ], "The value of 'backend' in Conv3D should be None or 'igemm'." 
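+        # backend='igemm' routes forward() through _conv2d_igemm, which calls the
+        # new sparse_conv3d_implicit_gemm op; backend=None keeps the default
+        # rulebook-based sparse convolution path.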
valid_format = {'NHWC'} if data_format not in valid_format: @@ -202,18 +232,36 @@ def _get_default_param_initializer(): ) def forward(self, x): - out = F.conv._conv2d( - x, - self.weight, - bias=self.bias, - stride=self._stride, - padding=self._updated_padding, - dilation=self._dilation, - groups=self._groups, - subm=self._subm, - key=self._key, - data_format=self._data_format, - ) + if self._backend is None: + out = F.conv._conv2d( + x, + self.weight, + bias=self.bias, + stride=self._stride, + padding=self._updated_padding, + dilation=self._dilation, + groups=self._groups, + subm=self._subm, + key=self._key, + data_format=self._data_format, + ) + elif self._backend == 'igemm': + out = F.conv._conv2d_igemm( + x, + self.weight, + bias=self.bias, + stride=self._stride, + padding=self._updated_padding, + dilation=self._dilation, + groups=self._groups, + subm=self._subm, + key=self._key, + data_format=self._data_format, + ) + else: + raise ValueError( + f"The value of 'backend' in Conv2D should be None or 'igemm', but got {self._backend}." + ) return out def extra_repr(self): @@ -624,6 +672,7 @@ def __init__( weight_attr=None, bias_attr=None, data_format="NDHWC", + backend=None, ): super().__init__( in_channels, @@ -639,6 +688,7 @@ def __init__( weight_attr=weight_attr, bias_attr=bias_attr, data_format=data_format, + backend=backend, ) @@ -764,6 +814,7 @@ def __init__( weight_attr=None, bias_attr=None, data_format="NHWC", + backend=None, ): super().__init__( in_channels, @@ -779,4 +830,5 @@ def __init__( weight_attr=weight_attr, bias_attr=bias_attr, data_format=data_format, + backend=backend, ) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 63d84ece4aa988..079232a4f30211 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -506,6 +506,7 @@ if(NOT WITH_GPU OR WIN32 OR APPLE) list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass) + list(REMOVE_ITEM TEST_OPS test_sparse_conv_igemm_op) endif() if(NOT WITH_CUDNN_FRONTEND) diff --git a/test/legacy_test/test_sparse_conv_igemm_op.py b/test/legacy_test/test_sparse_conv_igemm_op.py new file mode 100644 index 00000000000000..797f2d6ff84479 --- /dev/null +++ b/test/legacy_test/test_sparse_conv_igemm_op.py @@ -0,0 +1,348 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
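+
+# Forward-only tests for the implicit-GEMM submanifold convolution path added in
+# this change: the SubmConv2D / SubmConv3D layers with backend='igemm', the
+# subm_conv2d_igemm / subm_conv3d_igemm functional APIs, reuse of one layer (and
+# its cached kmap key) across inputs with different indices, and the equivalent
+# static-graph programs. Expected values are hand-computed for all-ones kernels.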
+ +import logging +import unittest + +import numpy as np + +import paddle +from paddle import sparse +from paddle.base import core + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO +) +logger = logging.getLogger(__name__) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "only test when CUDA is available", +) +class TestSparseConvImplicitGemm(unittest.TestCase): + def test_SubmConv2D_igemm_forward(self): + indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 3, 4, 1] + correct_out_values = [[4], [5], [10], [7]] + sparse_input = paddle.sparse.sparse_coo_tensor( + indices, values, dense_shape, False + ) + + subm_conv2d = paddle.sparse.nn.SubmConv2D( + 1, + 1, + 3, + padding=1, + stride=1, + data_format='NHWC', + key='subm_conv_2d', + backend='igemm', + ) + # set weight to all ones + subm_conv2d.weight = paddle.create_parameter( + (3, 3, 1, 1), + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(value=1.0), + ) + + sparse_out = subm_conv2d(sparse_input) + # the output shape of subm_conv is same as input shape + np.testing.assert_array_equal(indices, sparse_out.indices().numpy()) + np.testing.assert_array_equal( + correct_out_values, sparse_out.values().numpy() + ) + + def test_SubmConv3D_igemm_forward(self): + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + correct_out_values = [[4], [5], [10], [7]] + sparse_input = paddle.sparse.sparse_coo_tensor( + indices, values, dense_shape, False + ) + + subm_conv3d = paddle.sparse.nn.SubmConv3D( + 1, + 1, + (1, 3, 3), + padding=1, + stride=1, + data_format='NDHWC', + key='subm_conv', + backend='igemm', + ) + # set weight to all ones + subm_conv3d.weight = paddle.create_parameter( + (1, 3, 3, 1, 1), + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(value=1.0), + ) + + sparse_out = subm_conv3d(sparse_input) + # the output shape of subm_conv is same as input shape + np.testing.assert_array_equal(indices, sparse_out.indices().numpy()) + np.testing.assert_array_equal( + correct_out_values, sparse_out.values().numpy() + ) + + def test_submconv2d_igemm_forward(self): + indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 3, 4, 1] + correct_out_values = [[5], [6], [11], [8]] + sparse_input = paddle.sparse.sparse_coo_tensor( + indices, values, dense_shape, False + ) + + weight = paddle.ones((3, 3, 1, 1), dtype='float32') + bias = paddle.ones((1), dtype='float32') + sparse_out = paddle.sparse.nn.functional.subm_conv2d_igemm( + sparse_input, + weight, + bias, + stride=1, + padding=1, + dilation=1, + groups=1, + data_format="NHWC", + key='subm_conv_2d', + ) + + # the output shape of subm_conv is same as input shape + np.testing.assert_array_equal(indices, sparse_out.indices().numpy()) + np.testing.assert_array_equal( + correct_out_values, sparse_out.values().numpy() + ) + + def test_submconv3d_igemm_forward(self): + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + 
values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + correct_out_values = [[5], [6], [11], [8]] + sparse_input = paddle.sparse.sparse_coo_tensor( + indices, values, dense_shape, False + ) + + weight = paddle.ones((1, 3, 3, 1, 1), dtype='float32') + bias = paddle.ones((1), dtype='float32') + sparse_out = paddle.sparse.nn.functional.subm_conv3d_igemm( + sparse_input, + weight, + bias, + stride=1, + padding=1, + dilation=1, + groups=1, + data_format="NDHWC", + key='subm_conv_3d', + ) + + # the output shape of subm_conv is same as input shape + np.testing.assert_array_equal(indices, sparse_out.indices().numpy()) + np.testing.assert_array_equal( + correct_out_values, sparse_out.values().numpy() + ) + + def test_multi_input(self): + indices_1 = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + indices_2 = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [0, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices_1 = paddle.to_tensor(indices_1, dtype='int32') + indices_2 = paddle.to_tensor(indices_2, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + correct_out_values_1 = [[4], [5], [10], [7]] + correct_out_values_2 = [[1], [5], [9], [7]] + sparse_input_1 = paddle.sparse.sparse_coo_tensor( + indices_1, values, dense_shape, False + ) + sparse_input_2 = paddle.sparse.sparse_coo_tensor( + indices_2, values, dense_shape, False + ) + + subm_conv3d = paddle.sparse.nn.SubmConv3D( + 1, + 1, + (1, 3, 3), + padding=1, + stride=1, + data_format='NDHWC', + key='subm_conv', + backend='igemm', + ) + # set weight to all ones + subm_conv3d.weight = paddle.create_parameter( + (1, 3, 3, 1, 1), + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(value=1.0), + ) + + sparse_out = subm_conv3d(sparse_input_1) + np.testing.assert_array_equal(indices_1, sparse_out.indices().numpy()) + np.testing.assert_array_equal( + correct_out_values_1, sparse_out.values().numpy() + ) + + sparse_out = subm_conv3d(sparse_input_2) + + # the output shape of subm_conv is same as input shape + np.testing.assert_array_equal(indices_2, sparse_out.indices().numpy()) + np.testing.assert_array_equal( + correct_out_values_2, sparse_out.values().numpy() + ) + + +class TestStatic(unittest.TestCase): + def test3d(self): + paddle.enable_static() + main = paddle.static.Program() + with paddle.static.program_guard(main): + indices = paddle.static.data( + name='indices', shape=[4, 4], dtype='int32' + ) + values = paddle.static.data( + name='values', shape=[4, 1], dtype='float32' + ) + dense_shape = [1, 1, 3, 4, 1] + sp_x = sparse.sparse_coo_tensor(indices, values, dense_shape) + + weight_shape = [1, 3, 3, 1, 1] + weight = paddle.static.data( + name='weight', shape=weight_shape, dtype='float32' + ) + bias_shape = [1] + bias = paddle.static.data( + name='bias', shape=bias_shape, dtype='float32' + ) + out = sparse.nn.functional.subm_conv3d_igemm( + sp_x, + weight, + bias, + stride=1, + padding=1, + dilation=1, + groups=1, + data_format="NDHWC", + ) + sp_out = sparse.nn.functional.relu(out) + out_indices = sp_out.indices() + out_values = sp_out.values() + out = sp_out.to_dense() + + exe = paddle.static.Executor() + + indices_data = [ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 1, 2], + [1, 3, 2, 3], + ] + values_data = [[1.0], [2.0], [3.0], [4.0]] + weight_data = np.array( + [[[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]]] + ).astype('float32') + weight_data = weight_data.reshape(weight_shape) + bias_data = np.array([1]).astype('float32') + + fetch = 
exe.run( + feed={ + 'indices': indices_data, + 'values': values_data, + 'weight': weight_data, + 'bias': bias_data, + }, + fetch_list=[out, out_indices, out_values], + return_numpy=True, + ) + correct_out_values = [[5.0], [6.0], [11.0], [8.0]] + np.testing.assert_array_equal(correct_out_values, fetch[2]) + paddle.disable_static() + + def test2d(self): + paddle.enable_static() + main = paddle.static.Program() + with paddle.static.program_guard(main): + indices = paddle.static.data( + name='indices', shape=[3, 4], dtype='int32' + ) + values = paddle.static.data( + name='values', shape=[4, 1], dtype='float32' + ) + dense_shape = [1, 3, 4, 1] + sp_x = sparse.sparse_coo_tensor(indices, values, dense_shape) + + weight_shape = [3, 3, 1, 1] + weight = paddle.static.data( + name='weight', shape=weight_shape, dtype='float32' + ) + bias_shape = [1] + bias = paddle.static.data( + name='bias', shape=bias_shape, dtype='float32' + ) + out = sparse.nn.functional.subm_conv2d_igemm( + sp_x, + weight, + bias, + stride=1, + padding=1, + dilation=1, + groups=1, + data_format="NHWC", + ) + sp_out = sparse.nn.functional.relu(out) + out_indices = sp_out.indices() + out_values = sp_out.values() + out = sp_out.to_dense() + + exe = paddle.static.Executor() + + indices_data = [ + [0, 0, 0, 0], + [0, 0, 1, 2], + [1, 3, 2, 3], + ] + values_data = [[1.0], [2.0], [3.0], [4.0]] + weight_data = np.array( + [[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]] + ).astype('float32') + weight_data = weight_data.reshape(weight_shape) + bias_data = np.array([1]).astype('float32') + + fetch = exe.run( + feed={ + 'indices': indices_data, + 'values': values_data, + 'weight': weight_data, + 'bias': bias_data, + }, + fetch_list=[out, out_indices, out_values], + return_numpy=True, + ) + correct_out_values = [[5.0], [6.0], [11.0], [8.0]] + np.testing.assert_array_equal(correct_out_values, fetch[2]) + paddle.disable_static() + + +if __name__ == "__main__": + unittest.main()
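Example usage of the API added above, as a minimal sketch that mirrors the unit tests: the tensor values, variable names, and the `key` strings are illustrative, and a CUDA-enabled build is required because only GPU kernels are registered for `sparse_conv3d_implicit_gemm`.

    import paddle

    # A tiny sparse COO input in NDHWC layout: shape [1, 1, 3, 4, 1] with 4 non-zeros.
    indices = paddle.to_tensor(
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]], dtype='int32'
    )
    values = paddle.to_tensor([[1.0], [2.0], [3.0], [4.0]], dtype='float32')
    x = paddle.sparse.sparse_coo_tensor(indices, values, [1, 1, 3, 4, 1])

    # Layer API: backend='igemm' selects the implicit-GEMM submanifold conv kernels.
    conv = paddle.sparse.nn.SubmConv3D(
        1, 1, (1, 3, 3),
        padding=1,
        data_format='NDHWC',
        key='subm_conv_example',  # hypothetical key; reuses the cached kmap across calls
        backend='igemm',
    )
    y = conv(x)

    # Functional API: subm_conv3d_igemm with an explicit weight tensor (DHWCM layout).
    weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32')
    y_fn = paddle.sparse.nn.functional.subm_conv3d_igemm(
        x, weight, key='subm_conv_example_fn'
    )

    # Submanifold convolution preserves the input's indices and shape.
    print(y.shape)     # [1, 1, 3, 4, 1]
    print(y_fn.shape)  # [1, 1, 3, 4, 1]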