diff --git a/CMakeLists.txt b/CMakeLists.txt
index 70eb5f11ea168a..a724a60b3ed448 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -57,6 +57,7 @@ option(WITH_IPU "Compile PaddlePaddle with Graphcore IPU" OFF)
 option(WITH_ASCEND_CL "Compile PaddlePaddle with ASCEND CL" ${WITH_ASCEND})
 option(WITH_ASCEND_CXX11 "Compile PaddlePaddle with ASCEND and CXX11 ABI" OFF)
 option(WITH_ONNXRUNTIME "Compile PaddlePaddle with ONNXRUNTIME" OFF)
+option(WITH_CUSPARSELT "Compile PaddlePaddle with CUSPARSELT" OFF)
 # Note(zhouwei): It use option above, so put here
 include(init)
 include(generic) # simplify cmake module
diff --git a/cmake/external/cusparselt.cmake b/cmake/external/cusparselt.cmake
new file mode 100644
index 00000000000000..606ead5e7569f7
--- /dev/null
+++ b/cmake/external/cusparselt.cmake
@@ -0,0 +1,51 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if (NOT (WITH_CUSPARSELT AND WITH_TENSORRT))
+  return()
+endif ()
+
+if (WITH_ARM OR WIN32)
+  message(SEND_ERROR "cuSPARSELt is currently supported on Linux only")
+  return()
+endif ()
+
+INCLUDE(ExternalProject)
+
+SET(CUSPARSELT_PROJECT "extern_cusparselt")
+SET(CUSPARSELT_URL "https://developer.download.nvidia.com/compute/libcusparse-lt/0.2.0/local_installers/libcusparse_lt-linux-x86_64-0.2.0.1.tar.gz" CACHE STRING "" FORCE)
+SET(CUSPARSELT_PREFIX_DIR ${THIRD_PARTY_PATH}/cusparselt)
+SET(CUSPARSELT_INSTALL_DIR ${THIRD_PARTY_PATH}/install/cusparselt)
+SET(CUSPARSELT_INC_DIR "${CUSPARSELT_INSTALL_DIR}/include" CACHE PATH "sparselt include directory." FORCE)
+SET(CUSPARSELT_LIB_DIR "${CUSPARSELT_INSTALL_DIR}/lib64" CACHE PATH "sparselt lib directory." FORCE)
+set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1)
+include_directories(${CUSPARSELT_INC_DIR})
+
+ExternalProject_Add(
+  ${CUSPARSELT_PROJECT}
+  ${EXTERNAL_PROJECT_LOG_ARGS}
+  URL                  ${CUSPARSELT_URL}
+  PREFIX               ${CUSPARSELT_PREFIX_DIR}
+  DOWNLOAD_NO_PROGRESS 1
+  CONFIGURE_COMMAND    ""
+  BUILD_COMMAND        ""
+  INSTALL_COMMAND      ${CMAKE_COMMAND} -E copy_directory ${CUSPARSELT_PREFIX_DIR}/src/extern_cusparselt/lib64 ${CUSPARSELT_LIB_DIR} &&
+                       ${CMAKE_COMMAND} -E copy_directory ${CUSPARSELT_PREFIX_DIR}/src/extern_cusparselt/include ${CUSPARSELT_INC_DIR}
+  UPDATE_COMMAND       ""
+)
+
+add_library(cusparselt INTERFACE)
+add_dependencies(cusparselt ${CUSPARSELT_PROJECT})
+set(CUSPARSELT_FOUND ON)
+add_definitions(-DPADDLE_WITH_CUSPARSELT)
diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index a52047e16167d0..61cd58376c20e8 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -132,6 +132,13 @@ function(copy_part_of_thrid_party TARGET DST)
     endif()
   endif()
 
+  if (WITH_CUSPARSELT)
+    set(dst_dir "${DST}/third_party/install/cusparselt")
+    copy(${TARGET}
+      SRCS ${CUSPARSELT_INC_DIR} ${CUSPARSELT_LIB_DIR}
+      DSTS ${dst_dir} ${dst_dir})
+  endif()
+
   set(dst_dir "${DST}/third_party/install/gflags")
   copy(${TARGET}
     SRCS ${GFLAGS_INCLUDE_DIR} ${GFLAGS_LIBRARIES}
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index eb6fa4ee13c81e..d551cbe26209af 100755
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -424,4 +424,9 @@ if (WITH_IPU)
     list(APPEND third_party_deps extern_poplar)
 endif()
 
+if(WITH_CUSPARSELT)
+    include(external/cusparselt)  # download, build, install cusparselt
+    list(APPEND third_party_deps extern_cusparselt)
+endif()
+
 add_custom_target(third_party ALL DEPENDS ${third_party_deps})
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 3fc938f76410ce..e60fe3d67b5836 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -102,6 +102,7 @@ pass_library(add_support_int8_pass inference)
 pass_library(matmul_scale_fuse_pass inference)
 pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
 pass_library(mixed_precision_configure_pass inference)
+pass_library(desne_to_sparse_pass inference)
 pass_library(generate_pass DEPS pass_desc_proto)
 target_link_libraries(generate_pass pass_desc_proto)
diff --git a/paddle/fluid/framework/ir/desne_to_sparse_pass.cc b/paddle/fluid/framework/ir/desne_to_sparse_pass.cc
new file mode 100644
index 00000000000000..d2e0823c3323ae
--- /dev/null
+++ b/paddle/fluid/framework/ir/desne_to_sparse_pass.cc
@@ -0,0 +1,119 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/desne_to_sparse_pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +ReplaceDenseWithSparsePass::ReplaceDenseWithSparsePass() { + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + +void ReplaceDenseWithSparsePass::ApplyImpl(Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + + std::string name_scope = "desne_to_sparse_pass"; + FusePassBase::Init(name_scope, graph); + GraphPatternDetector gpd; + + patterns::DenseFC dense_fc_pattern(gpd.mutable_pattern(), + "dense_replace_pass"); + dense_fc_pattern(); + int found_dense_fc_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "Replace dense fc with sparse_fc."; + + /* if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + }*/ + + GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, dense_fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc, fc, dense_fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, dense_fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_weights, fc_weights, dense_fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_bias, fc_bias, dense_fc_pattern); + + auto *fc_op = fc->Op(); + auto w_name = fc_op->Input("W")[0]; + // recognize sparse op by name + if (w_name.find("sparse_2_4") != w_name.npos) { + // fake op + OpDesc desc(fc_op->Block()); + desc.SetType("sparse_fc"); + desc.SetInput("Input", {fc_input->Name()}); + desc.SetInput("W", {fc_weights->Name()}); + desc.SetInput("Bias", {fc_bias->Name()}); + desc.SetOutput("Out", {fc_out->Name()}); + + // copy all attr + if (fc_op->HasAttr("x_num_col_dims")) { + desc.SetAttr("x_num_col_dims", fc_op->GetAttr("x_num_col_dims")); + } + if (fc_op->HasAttr("in_num_col_dims")) { + desc.SetAttr("in_num_col_dims", fc_op->GetAttr("in_num_col_dims")); + } + desc.SetAttr("activation_type", fc_op->GetAttr("activation_type")); + if (fc_op->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", fc_op->GetAttr("enable_int8")); + } + if (fc_op->HasAttr("Input_scale")) { + desc.SetAttr("Input_scale", fc_op->GetAttr("Input_scale")); + } + if (fc_op->HasAttr("support_int8")) { + desc.SetAttr("support_int8", fc_op->GetAttr("support_int8")); + } + if (fc_op->HasAttr("out_threshold")) { + desc.SetAttr("out_threshold", fc_op->GetAttr("out_threshold")); + } + desc.Flush(); + GraphSafeRemoveNodes(g, {fc}); + auto sparse_fc_node = g->CreateOpNode(&desc); + + IR_NODE_LINK_TO(fc_input, sparse_fc_node); + IR_NODE_LINK_TO(fc_weights, sparse_fc_node); + IR_NODE_LINK_TO(fc_bias, sparse_fc_node); + IR_NODE_LINK_TO(sparse_fc_node, fc_out); + found_dense_fc_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_dense_fc_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(desne_to_sparse_pass, + paddle::framework::ir::ReplaceDenseWithSparsePass); diff --git a/paddle/fluid/framework/ir/desne_to_sparse_pass.h b/paddle/fluid/framework/ir/desne_to_sparse_pass.h new file mode 100644 index 00000000000000..33e278d778fbbb --- /dev/null +++ b/paddle/fluid/framework/ir/desne_to_sparse_pass.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/inference/api/paddle_analysis_config.h" + +namespace paddle { +namespace framework { +namespace ir { + +/** + * Replace dense op with sparse op + */ +class Graph; + +class ReplaceDenseWithSparsePass : public FusePassBase { + public: + ReplaceDenseWithSparsePass(); + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + const std::string name_scope_{"desne_to_sparse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index f7c1a68c826f09..410c899835a7ed 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3398,6 +3398,31 @@ PDNode *patterns::AddSupportInt8::operator()() { return quant_out; } +PDNode *patterns::DenseFC::operator()() { + auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc"); + // Input + auto *fc_input = pattern->NewNode(fc_input_repr()) + ->AsInput() + ->assert_is_op_input("fc", "Input"); + // Filter + auto *fc_weights = pattern->NewNode(fc_weights_repr()) + ->AsInput() + ->assert_is_op_input("fc", "W"); + // Bias + auto *fc_bias = pattern->NewNode(fc_bias_repr()) + ->AsInput() + ->assert_is_op_input("fc", "Bias"); + // Output + auto *fc_out = pattern->NewNode(fc_out_repr()) + ->AsOutput() + ->assert_is_op_output("fc", "Out") + ->assert_is_only_output_of_op("fc"); + + fc->LinksFrom({fc_input, fc_weights, fc_bias}).LinksTo({fc_out}); + + return fc_out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 3c6b6ce94e23f3..22a8c07c136942 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1954,6 +1954,23 @@ struct AddSupportInt8 : public PatternBase { PATTERN_DECL_NODE(quant_out); }; +// +// \brief Pattern looking for dense fc. +// +struct DenseFC : public PatternBase { + DenseFC(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dense_fc") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(fc); + PATTERN_DECL_NODE(fc_out); + PATTERN_DECL_NODE(fc_input); + PATTERN_DECL_NODE(fc_weights); + PATTERN_DECL_NODE(fc_bias); +}; + } // namespace patterns // Link two ir::Nodes from each other. 
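Note on how the new pass is driven: ReplaceDenseWithSparsePass only rewrites `fc` ops whose weight variable name contains "sparse_2_4" into `sparse_fc` ops, and it is registered under the name `desne_to_sparse_pass`. In this PR it runs as part of the TensorRT subgraph pass list (see the paddle_pass_builder.cc hunk further below); the sketch here applies it to a graph in isolation, assuming the usual paddle::framework::ir::PassRegistry / Pass::Apply API. Illustration only, not part of the diff.

```cpp
// Sketch (assumed API, not part of this PR): apply the registered
// desne_to_sparse_pass to a program's graph. In the real inference flow the
// entry added to kTRTSubgraphPasses drives this instead.
#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

void RunDenseToSparse(const paddle::framework::ProgramDesc& program) {
  // Build an IR graph from the program.
  auto graph = std::make_unique<paddle::framework::ir::Graph>(program);
  // Look up the pass by its registered name.
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "desne_to_sparse_pass");
  // Rewrites eligible dense "fc" ops (weight name contains "sparse_2_4")
  // into "sparse_fc" ops that the new TensorRT converter picks up.
  pass->Apply(graph.get());
}
```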
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index b40377855bd3ea..eb02666a5b969e 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1808,6 +1808,9 @@ USE_TRT_CONVERTER(strided_slice) USE_TRT_CONVERTER(transformer_input_convert) USE_TRT_CONVERTER(recover_padding) USE_TRT_CONVERTER(remove_padding) +#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) +USE_TRT_CONVERTER(sparse_fc) +#endif #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 3b1c84db4c5346..6589ad86f02b88 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -114,6 +114,7 @@ const std::vector kTRTSubgraphPasses({ "remove_padding_recover_padding_pass", // "delete_remove_padding_recover_padding_pass", // // "yolo_box_fuse_pass", // + "desne_to_sparse_pass", "tensorrt_subgraph_pass", // "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 05ab3fb53e5333..3f3184eb95d18f 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,64 +1,21 @@ # Add TRT tests +list(APPEND CONVERT_FILES matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc +batch_norm_op.cc activation_op.cc unary_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc +pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc +shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc +emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc +gather_op.cc anchor_generator_op.cc yolo_box_op.cc yolo_box_head_op.cc arg_max_op.cc roi_align_op.cc affine_channel_op.cc multiclass_nms_op.cc +multiclass_nms3_op.cc nearest_interp_op.cc reshape_op.cc reduce_op.cc gather_nd_op.cc tile_op.cc +conv3d_op.cc mish_op.cc nearest_interp_v2_op.cc pool3d_op.cc deformable_conv_op.cc preln_emb_eltwise_layernorm.cc +preln_skip_layernorm.cc strided_slice_op.cc roll_op.cc transformer_input_convert_op.cc remove_padding_op.cc +recover_padding_op.cc) + +if (CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) + list(APPEND CONVERT_FILES sparse_fc_op.cc) +endif() + nv_library(tensorrt_converter - SRCS matmul_op.cc - conv2d_op.cc - fc_op.cc - pool2d_op.cc - elementwise_op.cc - batch_norm_op.cc - activation_op.cc - unary_op.cc - softmax_op.cc - concat_op.cc - dropout_op.cc - group_norm_op.cc - pad_op.cc - split_op.cc - prelu_op.cc - leaky_relu_op.cc - gelu_op.cc - layer_norm_op.cc - multihead_matmul_op.cc - shuffle_channel_op.cc - swish_op.cc - instance_norm_op.cc - stack_op.cc - transpose_op.cc - flatten_op.cc - flatten_contiguous_range_op.cc - emb_eltwise_layernorm.cc - skip_layernorm.cc - scale_op.cc - slice_op.cc - hard_sigmoid_op.cc - hard_swish_op.cc - clip_op.cc - gather_op.cc - anchor_generator_op.cc - yolo_box_op.cc - yolo_box_head_op.cc - arg_max_op.cc - roi_align_op.cc - affine_channel_op.cc - multiclass_nms_op.cc - multiclass_nms3_op.cc - nearest_interp_op.cc - reshape_op.cc - reduce_op.cc - gather_nd_op.cc - tile_op.cc - conv3d_op.cc - 
mish_op.cc - nearest_interp_v2_op.cc - pool3d_op.cc - deformable_conv_op.cc - preln_emb_eltwise_layernorm.cc - strided_slice_op.cc - preln_skip_layernorm.cc - roll_op.cc - transformer_input_convert_op.cc - remove_padding_op.cc - recover_padding_op.cc + SRCS ${CONVERT_FILES} DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc new file mode 100644 index 00000000000000..01c3fad7a3cbec --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc @@ -0,0 +1,319 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * FC converter convert a sparse_fc op in Fluid to a sparse_fc layer in TRT. + */ +class SparseFcOpConverter : public OpConverter { + public: + nvinfer1::ILayer* reshape_before_fc(nvinfer1::ITensor* before_fc, + nvinfer1::Dims x_dim, int x_num_col_dims, + std::string output_name) { + // add shuffle before fc + nvinfer1::Dims reshape_before_fc_dim; + reshape_before_fc_dim.nbDims = x_num_col_dims + 3; + // padding shape "* x q x 1 x 1" + for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { + reshape_before_fc_dim.d[i] = 1; + } + for (int i = 0; i < x_dim.nbDims; i++) { + if (i < x_num_col_dims) { + reshape_before_fc_dim.d[i] = 0; + } else { + if (x_dim.d[i] < 0) { + reshape_before_fc_dim.d[x_num_col_dims] = -1; + break; + } + reshape_before_fc_dim.d[x_num_col_dims] *= x_dim.d[i]; + } + } + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *before_fc); + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setName( + ("sparse_fc_op_reshape_before_fc: Shuffle (Output: " + output_name + + ")") + .c_str()); + return reshape_before_fc_layer; + } + + nvinfer1::ILayer* reshape_after_fc(nvinfer1::ITensor* after_fc, + nvinfer1::Dims x_dim, int x_num_col_dims) { + // add shuffle after fc + nvinfer1::Dims reshape_after_fc_dim; + reshape_after_fc_dim.nbDims = x_num_col_dims + 1; + for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { + reshape_after_fc_dim.d[i] = 0; + } + auto* reshape_after_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *after_fc); + reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim); + return reshape_after_fc_layer; + } + + plugin::SpmmPluginDynamic* new_spmm_plugin(TensorRTEngine::Weight* weight, + TensorRTEngine::Weight* bias, + const std::string& activation_type, + nvinfer1::DataType type, + int outdim) { + plugin::SpmmPluginDynamic::Activation act = + 
plugin::SpmmPluginDynamic::Activation::kNone; + if (activation_type == "relu") { + act = plugin::SpmmPluginDynamic::Activation::kRelu; + } else if (activation_type == "gelu") { + act = plugin::SpmmPluginDynamic::Activation::kGelu; + } else if (activation_type != "") { + PADDLE_THROW(paddle::platform::errors::Fatal("unknown activation_type %s", + activation_type.c_str())); + } + return new plugin::SpmmPluginDynamic("CustomSpmmPluginDynamic", type, + outdim, weight->get(), bias->get(), + act); + } + + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a sparse_fc op to tensorrt sparse_fc layer"; + framework::OpDesc op_desc(op, nullptr); + auto output_name = op_desc.Output("Out").front(); + auto input_names = op_desc.InputNames(); + bool with_bias = input_names.size() >= 3; + std::string w_name = "Y"; + std::string i_name = "X"; + if (with_bias) { + w_name = "W"; + i_name = "Input"; + } + // Declare inputs + auto* X = engine_->GetITensor(op_desc.Input(i_name).front()); + auto x_dim = X->getDimensions(); + // Declare weights + auto* Y_v = scope.FindVar(op_desc.Input(w_name).front()); + PADDLE_ENFORCE_NOT_NULL( + Y_v, + platform::errors::NotFound( + "Can not find %s presistale var of sparse_fc in scope.", w_name)); + auto* Y_t = Y_v->GetMutable(); + int x_num_col_dims = + op_desc.HasAttr("x_num_col_dims") + ? BOOST_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")) + : (op_desc.HasAttr("in_num_col_dims") + ? BOOST_GET_CONST(int, op_desc.GetAttr("in_num_col_dims")) + : 1); + const std::string activation_type = + op_desc.HasAttr("activation_type") + ? BOOST_GET_CONST(std::string, op_desc.GetAttr("activation_type")) + : ""; + // This may trigger a GPU->CPU copy, because TRT's weight can only be + // assigned from CPU memory, which can't be avoided. 
+ float* weight_data = nullptr; + bool enable_int8 = op_desc.HasAttr("enable_int8"); + bool support_int8 = false; + if (op_desc.HasAttr("support_int8")) { + support_int8 = BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")); + } + float in_scale = 0; + if (enable_int8 || support_int8) { + if (enable_int8) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); + } else { + // attr X is generated by add_support_int8_pass + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X")); + } + engine_->SetTensorDynamicRange(X, in_scale); + } + weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t); + + PADDLE_ENFORCE_EQ( + Y_t->dims().size(), 2UL, + platform::errors::InvalidArgument( + "The sparse_fc's weight should be a matrix with 2 dims, but " + "it's %d-dimensional.", + Y_t->dims().size())); // a matrix + int m = Y_t->dims()[0]; + int n = Y_t->dims()[1]; + auto tranpose_weight = [](const float* src, float* dst, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + dst[j * m + i] = src[i * n + j]; + } + } + }; + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, + TensorRTEngine::Weight* weight, + TensorRTEngine::Weight* bias) { + if (enable_int8 || support_int8) { + // add conv layer + float out_scale = 0; + if (enable_int8) { + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in sparse_fc layers in int8 mode")); + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + } else { + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); + } + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + weight, bias, activation_type, nvinfer1::DataType::kINT8, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(inputs); + auto fc_layer_int8 = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + fc_layer_int8->setName( + ("sparse_fc_op_int8: (Output: " + output_name + ")").c_str()); + engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale); + auto* fc_after_reshape_int8 = reshape_after_fc( + fc_layer_int8->getOutput(0), x_dim, x_num_col_dims); + + RreplenishLayerAndOutput(fc_after_reshape_int8, + "sparse_fc_op_int8_reshape_after_fc: Shuffle", + {output_name}, test_mode); + } else { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + weight, bias, activation_type, + with_fp16 ? 
nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(inputs); + auto fc_layer_float = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + fc_layer_float->setName( + ("sparse_fc_op_float: FullyConnected (Output: " + output_name + ")") + .c_str()); + auto* fc_after_reshape_float = reshape_after_fc( + fc_layer_float->getOutput(0), x_dim, x_num_col_dims); + + RreplenishLayerAndOutput(fc_after_reshape_float, + "shuffle_after_sparse_fc", {output_name}, + test_mode); + } + }; + + bool transpose_y = false; + if (op_desc.HasAttr("transpose_Y")) { + transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y")); + } + int weight_w, weight_h; + if (!transpose_y) { + std::vector weight_data_tmp; + weight_data_tmp.reserve(Y_t->numel()); + memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float)); + tranpose_weight(weight_data_tmp.data(), weight_data, m, n); + weight_w = n; + weight_h = m; + } else { + weight_w = m; + weight_h = n; + } + half* half_data = nullptr; + void* w_data = nullptr; + if (with_fp16) { + half_data = new half[Y_t->numel()]; + for (int i = 0; i < Y_t->numel(); i++) { + half_data[i] = static_cast(weight_data[i]); + } + w_data = static_cast(half_data); + } else { + w_data = static_cast(weight_data); + } + size_t n_output = weight_w; + TensorRTEngine::Weight weight{ + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + w_data, static_cast(Y_t->numel())}; + weight.dims.assign({weight_w, weight_h}); + + float* bias_data = nullptr; + int bias_num = 0; + if (with_bias) { + auto* b_v = scope.GetVar(op_desc.Input("Bias").front()); + auto* b_t = b_v->GetMutable(); + bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t); + bias_num = b_t->numel(); + } + TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_num)}; + + // Running the TRT Static Shape mode: x_num_col_dims-1 + if (!engine_->with_dynamic_shape()) { + x_num_col_dims--; + } + // If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can + // not add Shuffle layer in ernie's multihead. + if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && + x_dim.d[3] == 1 && x_num_col_dims == 2) { + if (enable_int8 || support_int8) { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + &weight, &bias, activation_type, nvinfer1::DataType::kINT8, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(X); + auto fc_layer_int8 = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + RreplenishLayerAndOutput(fc_layer_int8, "ernie_sparse_fc_op_int8: ", + {output_name}, test_mode); + } else { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + &weight, &bias, activation_type, + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(X); + auto fc_layer_float = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + RreplenishLayerAndOutput(fc_layer_float, "ernie_sparse_fc_op_float", + {output_name}, test_mode); + } + } else { // need reshape input before and after fc + PADDLE_ENFORCE_GT( + x_dim.nbDims, x_num_col_dims, + platform::errors::InvalidArgument( + "Params and input dims mismatch. 
Paddle-TRT SPARSE_FC " + "converter expects x_dim.nbDims > x_num_col_dims, but " + "x_dim.nbDims : %d, x_num_col_dims : %d.", + x_dim.nbDims, x_num_col_dims)); + auto* reshape_before_fc_layer = + reshape_before_fc(X, x_dim, x_num_col_dims, output_name); + auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); + if (enable_int8 || support_int8) { + engine_->SetTensorDynamicRange(reshape_itensor, in_scale); + } + regist_fc(reshape_itensor, n_output, &weight, &bias); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(sparse_fc, SparseFcOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 79a5e7d7a6a133..856985119b73e9 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -44,6 +44,10 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("reshape2"); int8_teller_set.insert("reshape"); int8_teller_set.insert("reshape2"); +#endif +#if IS_TRT_VERSION_GE(8000) + teller_set.insert("sparse_fc"); + int8_teller_set.insert("sparse_fc"); #endif } @@ -1751,6 +1755,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } +#if IS_TRT_VERSION_GE(8000) + if (op_type == "sparse_fc") { + if (!with_dynamic_shape) { + VLOG(3) << "the sparse_fc does not support static shape yet"; + return false; + } + } +#endif + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index ee1d6c1dc7d7ed..e5f2bd429f4701 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,22 +1,19 @@ +list(APPEND TRT_FILES trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu + gelu_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu + instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu + qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu + hard_swish_op_plugin.cu stack_op_plugin.cu anchor_generator_op_plugin.cu + yolo_box_op_plugin.cu yolo_box_head_op_plugin.cu + roi_align_op_plugin.cu gather_nd_op_plugin.cu mish_op_plugin.cu pool3d_op_plugin.cu + deformable_conv_op_plugin.cu matmul_op_int8_plugin.cu transformer_input_convert_plugin.cu + remove_padding_plugin.cu recover_padding_plugin.cu) + +if (CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) + list(APPEND TRT_FILES spmm_plugin.cu) +endif() + nv_library(tensorrt_plugin - SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu - prelu_op_plugin.cu gelu_op_plugin.cu - pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu - instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu - qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu - hard_swish_op_plugin.cu stack_op_plugin.cu - anchor_generator_op_plugin.cu - yolo_box_op_plugin.cu - yolo_box_head_op_plugin.cu - roi_align_op_plugin.cu - gather_nd_op_plugin.cu - mish_op_plugin.cu - pool3d_op_plugin.cu - deformable_conv_op_plugin.cu - matmul_op_int8_plugin.cu - transformer_input_convert_plugin.cu - remove_padding_plugin.cu - recover_padding_plugin.cu + SRCS ${TRT_FILES} DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu 
b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu new file mode 100644 index 00000000000000..b9cc7e55b7d2af --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -0,0 +1,882 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +nvinfer1::PluginFieldCollection SpmmPluginDynamicCreator::field_collection_{}; +std::vector SpmmPluginDynamicCreator::plugin_attr_; + +inline int getElementSize(nvinfer1::DataType type) { + switch (type) { + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + case nvinfer1::DataType::kINT8: + return 1; + default: + PADDLE_THROW(paddle::platform::errors::Fatal( + "getElementSize only supports [FLOAT|HALF|INT8]")); + } +} + +inline cudaDataType_t convertTrtType(nvinfer1::DataType type) { + switch (type) { + case nvinfer1::DataType::kFLOAT: + return CUDA_R_32F; + case nvinfer1::DataType::kHALF: + return CUDA_R_16F; + case nvinfer1::DataType::kINT8: + return CUDA_R_8I; + default: + PADDLE_THROW(paddle::platform::errors::Fatal( + "getElementSize only supports [FLOAT|HALF|INT8]")); + } +} + +inline void deserialize_value_size(void const** buffer, size_t* buffer_size, + void* value, size_t value_size) { + PADDLE_ENFORCE_GE( + *buffer_size, value_size, + platform::errors::InvalidArgument("buffer_size must >= value_size")); + memcpy(value, *buffer, value_size); + reinterpret_cast(*buffer) += value_size; + *buffer_size -= value_size; +} + +inline float round_scale(float x) { return std::floor(x + 0.5f); } + +inline void convertAndCopy(const nvinfer1::Weights& src, + nvinfer1::DataType type, void* dest) { + PADDLE_ENFORCE_EQ(src.type == nvinfer1::DataType::kFLOAT || + src.type == nvinfer1::DataType::kHALF, + true, + platform::errors::InvalidArgument( + "convertAndCopy only supports src type [FLOAT|HALF]")); + PADDLE_ENFORCE_EQ( + type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF, + true, platform::errors::InvalidArgument( + "convertAndCopy only supports src type [FLOAT|HALF]")); + + if (type == nvinfer1::DataType::kFLOAT) { + if (src.type == nvinfer1::DataType::kFLOAT) { + std::copy_n(static_cast(src.values), src.count, + static_cast(dest)); + } else { + for (int i = 0; i < src.count; ++i) { + static_cast(dest)[i] = + static_cast(static_cast(src.values)[i]); + } + } + } else { + if (src.type == nvinfer1::DataType::kHALF) { + std::copy_n(static_cast(src.values), src.count, + static_cast<__half*>(dest)); + } else { + for (int i = 0; i < src.count; ++i) { + static_cast<__half*>(dest)[i] = + static_cast<__half>(static_cast(src.values)[i]); + } + } + } +} + +SpmmPluginDynamic::cusparseLtContext::cusparseLtContext() { + paddle::platform::dynload::cusparseLtInit(&handle); +} + +SpmmPluginDynamic::cusparseLtContext::~cusparseLtContext() { + 
paddle::platform::dynload::cusparseLtDestroy(&handle); +} + +void SpmmPluginDynamic::cusparseLtContext::init( + int m, int n, int k, cudaDataType_t type, void* bias_ptr, + SpmmPluginDynamic::Activation activation) { + /* + 1. Init matrix descriptors (matA, matB, matC) + 2. Init matrix multiplication descriptor (matmul) + 3. Set activation and bias attribute of matmul + 4. Init algorithm selection descriptor (alg_sel) + 5. Init plan descriptor (plan) + */ + PADDLE_ENFORCE_EQ( + is_initialized, false, + platform::errors::InvalidArgument( + "Descriptor should be destroyed before calling create")); + constexpr int alignment = 16; + cusparseComputeType compute_type; + switch (type) { + case CUDA_R_32F: + compute_type = CUSPARSE_COMPUTE_TF32; + break; + case CUDA_R_16F: + compute_type = CUSPARSE_COMPUTE_16F; + break; + case CUDA_R_8I: + compute_type = CUSPARSE_COMPUTE_32I; + break; + default: + PADDLE_THROW(paddle::platform::errors::Fatal( + "cusparLtContext only supports data type" + "[CUDA_R_32F|CUDA_R_16F|CUDA_R_8I]")); + } + paddle::platform::dynload::cusparseLtDenseDescriptorInit( + &handle, &matA, m, k, k, alignment, type, CUSPARSE_ORDER_ROW); + paddle::platform::dynload::cusparseLtStructuredDescriptorInit( + &handle, &matB, n, k, k, alignment, type, CUSPARSE_ORDER_ROW, + CUSPARSELT_SPARSITY_50_PERCENT); + paddle::platform::dynload::cusparseLtDenseDescriptorInit( + &handle, &matC, m, n, n, alignment, type, CUSPARSE_ORDER_ROW); + paddle::platform::dynload::cusparseLtMatmulDescriptorInit( + &handle, &matmul, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, &matA, &matB, &matC, &matC, compute_type); + if (activation == SpmmPluginDynamic::Activation::kRelu) { + int true_value = 1; + float relu_upper_bound = std::numeric_limits::max(); + float relu_threshold = 0.0f; + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ACTIVATION_RELU, &true_value, + sizeof(true_value)); + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ACTIVATION_RELU_UPPERBOUND, + &relu_upper_bound, sizeof(relu_upper_bound)); + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ACTIVATION_RELU_THRESHOLD, + &relu_threshold, sizeof(relu_threshold)); + } else if (activation == SpmmPluginDynamic::Activation::kGelu) { + int true_value = 1; + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ACTIVATION_GELU, &true_value, + sizeof(true_value)); + } else { + PADDLE_ENFORCE_EQ( + activation, SpmmPluginDynamic::Activation::kNone, + platform::errors::InvalidArgument("Received unknown activation")); + } + if (bias_ptr != nullptr) { + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_BIAS_POINTER, &bias_ptr, + sizeof(bias_ptr)); + } + paddle::platform::dynload::cusparseLtMatmulAlgSelectionInit( + &handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT); + int alg = 0; + paddle::platform::dynload::cusparseLtMatmulAlgSetAttribute( + &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg)); + paddle::platform::dynload::cusparseLtMatmulGetWorkspace(&handle, &alg_sel, + &workspace_size); + paddle::platform::dynload::cusparseLtMatmulPlanInit(&handle, &plan, &matmul, + &alg_sel, workspace_size); + is_initialized = true; +} + +void SpmmPluginDynamic::cusparseLtContext::setAlgo(int alg) { + PADDLE_ENFORCE_EQ( + is_initialized, true, + platform::errors::InvalidArgument( + 
"Descriptor should be initialized before setting algorithm")); + paddle::platform::dynload::cusparseLtMatmulAlgSetAttribute( + &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg)); + paddle::platform::dynload::cusparseLtMatmulGetWorkspace(&handle, &alg_sel, + &workspace_size); + paddle::platform::dynload::cusparseLtMatmulPlanDestroy(&plan); + paddle::platform::dynload::cusparseLtMatmulPlanInit(&handle, &plan, &matmul, + &alg_sel, workspace_size); +} + +void SpmmPluginDynamic::cusparseLtContext::destroy() { + PADDLE_ENFORCE_EQ(is_initialized, true, + platform::errors::InvalidArgument( + "cusparseLtContext is destroy before init")); + paddle::platform::dynload::cusparseLtMatmulPlanDestroy(&plan); + paddle::platform::dynload::cusparseLtMatDescriptorDestroy(&matC); + paddle::platform::dynload::cusparseLtMatDescriptorDestroy(&matB); + paddle::platform::dynload::cusparseLtMatDescriptorDestroy(&matA); + is_initialized = false; +} + +void SpmmPluginDynamic::cusparseLtContext::compressMatB( + int n, int k, cudaDataType_t type, void* src, void** dest, + size_t* compressed_size) { + PADDLE_ENFORCE_EQ( + is_initialized, false, + platform::errors::InvalidArgument( + "cusparseLtContext should not initialized before compressMatB")); + PADDLE_ENFORCE_EQ(*dest, nullptr, + platform::errors::InvalidArgument( + "before compressMatB *dest must be nullptr")); + constexpr int alignment = 16; + paddle::platform::dynload::cusparseLtStructuredDescriptorInit( + &handle, &matB, n, k, k, alignment, type, CUSPARSE_ORDER_ROW, + CUSPARSELT_SPARSITY_50_PERCENT); + + paddle::platform::dynload::cusparseLtSpMMACompressedSize2(&handle, &matB, + compressed_size); + cudaMalloc(dest, *compressed_size); + paddle::platform::dynload::cusparseLtSpMMACompress2( + &handle, &matB, 0, CUSPARSE_OPERATION_TRANSPOSE, src, *dest, nullptr); + paddle::platform::dynload::cusparseLtMatDescriptorDestroy(&matB); +} + +// Constructor for new plugin +SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, + const nvinfer1::DataType precision, + const int out_dim, + const nvinfer1::Weights& weight, + const nvinfer1::Weights& bias, + Activation activation) + : layer_name_(layer_name), + precision_(precision), + out_dim_(out_dim), + k_(0), + m_max_(0), + is_configured_(false), + optim_alg_(0), + weight_scale_(1.0f), + weight_compressed_(nullptr), + weight_compressed_dev_(nullptr), + compressed_size_(0), + has_bias_(false), + bias_(nullptr), + bias_dev_(nullptr), + activation_(activation) { + /* + 1. Convert weight precision (on host) + 2. (Int8) Calculate scale and scale the weight (on host) + 3. Copy weight to device + 4. Compress the weight (on device) + 5. Copy the compressed weight to host + 6. Convert bias precision and copy (on host) + */ + precision_size_ = getElementSize(precision); + element_size_ = + (precision_ == nvinfer1::DataType::kINT8 ? 
4 : precision_size_); + + PADDLE_ENFORCE_EQ( + weight.count % out_dim, 0, + platform::errors::InvalidArgument( + "The size of weight should be divided by output dimension.")); + k_ = weight.count / out_dim; + PADDLE_ENFORCE_EQ( + weight.type == nvinfer1::DataType::kFLOAT || + weight.type == nvinfer1::DataType::kHALF, + true, platform::errors::InvalidArgument( + "SpmmPluginDynamic only supports weight of type [FLOAT|HALF]")); + nvinfer1::DataType weight_type; + if (precision_ == nvinfer1::DataType::kINT8) { + weight_type = nvinfer1::DataType::kFLOAT; + } else { + weight_type = precision_; + } + std::vector weight_host(element_size_ * out_dim_ * k_); + convertAndCopy(weight, weight_type, weight_host.data()); + void* weight_dev{nullptr}; + cudaMalloc(reinterpret_cast(&weight_dev), + precision_size_ * out_dim_ * k_); + if (precision == nvinfer1::DataType::kINT8) { + float max_weight{0.0f}; + for (int i = 0; i < weight.count; ++i) { + float local_abs = + std::abs(reinterpret_cast(weight_host.data())[i]); + max_weight = std::max(max_weight, local_abs); + } + weight_scale_ = max_weight / 127.0f; + std::vector scale_buffer(weight.count); + for (int i = 0; i < weight.count; ++i) { + scale_buffer[i] = static_cast( + round_scale(reinterpret_cast(weight_host.data())[i] / + weight_scale_)); + } + cudaMemcpy(weight_dev, scale_buffer.data(), precision_size_ * weight.count, + cudaMemcpyHostToDevice); + } else { + cudaMemcpy(weight_dev, weight_host.data(), precision_size_ * weight.count, + cudaMemcpyHostToDevice); + } + + spmm_context_.compressMatB(out_dim_, k_, convertTrtType(precision_), + weight_dev, &weight_compressed_dev_, + &compressed_size_); + weight_compressed_ = new char[compressed_size_]; + cudaMemcpy(weight_compressed_, weight_compressed_dev_, compressed_size_, + cudaMemcpyDeviceToHost); + + has_bias_ = (bias.count != 0); + if (has_bias_) { + if (bias.count != out_dim) { + PADDLE_THROW(paddle::platform::errors::Fatal( + "The dimension of bias should be equal to output dimension")); + } + PADDLE_ENFORCE_EQ(bias.type, nvinfer1::DataType::kFLOAT, + platform::errors::InvalidArgument( + "SpmmPluginDynamic only supports FLOAT bias")); + + bias_ = new float[out_dim_]; + convertAndCopy(bias, nvinfer1::DataType::kFLOAT, bias_); + } + + cudaFree(weight_dev); +} + +// Constructor for clone +SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, + const nvinfer1::DataType precision, + const int out_dim, const int k, + const void* weight_compressed, + size_t compressed_size, const void* bias, + bool is_configured, const int m_max, + const int optim_alg, Activation activation) + : layer_name_(layer_name), + precision_(precision), + out_dim_(out_dim), + k_(k), + m_max_(m_max), + is_configured_(is_configured), + optim_alg_(optim_alg), + weight_scale_(1.0f), + weight_compressed_(nullptr), + weight_compressed_dev_(nullptr), + compressed_size_(compressed_size), + has_bias_(false), + bias_(nullptr), + bias_dev_(nullptr), + activation_(activation) { + /* + 1. Copy the compressed weight (on host) + 2. Copy the compressed weight to device + 3. Copy the bias (on host) + 4. (Configured) Copy the bias to device + 5. (Configured) Init cuSPARSELt descriptors + */ + precision_size_ = getElementSize(precision); + element_size_ = + (precision_ == nvinfer1::DataType::kINT8 ? 
4 : precision_size_); + // Each plugin has a copy of compressed weight + weight_compressed_ = new char[compressed_size]; + std::copy_n(static_cast(weight_compressed), compressed_size, + static_cast(weight_compressed_)); + + has_bias_ = (bias != nullptr); + if (has_bias_) { + // Each plugin has a copy of bias + bias_ = new float[out_dim_]; + std::copy_n(static_cast(bias), sizeof(float) * out_dim_, + static_cast(bias_)); + if (is_configured_) { + cudaMalloc(reinterpret_cast(&bias_dev_), + sizeof(float) * out_dim_); + cudaMemcpy(bias_dev_, bias_, sizeof(float) * out_dim_, + cudaMemcpyHostToDevice); + } + } + + if (is_configured_) { + cudaDataType_t dataType = convertTrtType(precision_); + spmm_context_.init(m_max_, out_dim_, k_, dataType, bias_dev_, activation_); + spmm_context_.setAlgo(optim_alg_); + } +} + +SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, + size_t length) + : layer_name_(name), + weight_compressed_(nullptr), + weight_compressed_dev_(nullptr), + bias_(nullptr), + bias_dev_(nullptr) { + DeserializeValue(&data, &length, &precision_); + DeserializeValue(&data, &length, &precision_size_); + DeserializeValue(&data, &length, &element_size_); + DeserializeValue(&data, &length, &out_dim_); + DeserializeValue(&data, &length, &k_); + DeserializeValue(&data, &length, &m_max_); + DeserializeValue(&data, &length, &is_configured_); + DeserializeValue(&data, &length, &optim_alg_); + DeserializeValue(&data, &length, &weight_scale_); + DeserializeValue(&data, &length, &compressed_size_); + DeserializeValue(&data, &length, &has_bias_); + DeserializeValue(&data, &length, &activation_); + + PADDLE_ENFORCE_EQ(is_configured_, true, + platform::errors::InvalidArgument( + "Deserialize data should be configured")); + weight_compressed_ = new char[compressed_size_]; + deserialize_value_size(&data, &length, weight_compressed_, compressed_size_); + cudaMalloc(reinterpret_cast(&weight_compressed_dev_), + compressed_size_); + cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size_, + cudaMemcpyHostToDevice); + + if (has_bias_) { + bias_ = new float[out_dim_]; + deserialize_value_size(&data, &length, bias_, sizeof(float) * out_dim_); + cudaMalloc(reinterpret_cast(&bias_dev_), sizeof(float) * out_dim_); + cudaMemcpy(bias_dev_, bias_, sizeof(float) * out_dim_, + cudaMemcpyHostToDevice); + } + + if (is_configured_) { + cudaDataType_t dataType = convertTrtType(precision_); + spmm_context_.init(m_max_, out_dim_, k_, dataType, bias_dev_, activation_); + spmm_context_.setAlgo(optim_alg_); + } +} + +nvinfer1::IPluginV2DynamicExt* SpmmPluginDynamic::clone() const noexcept { + try { + auto* p = + new SpmmPluginDynamic(layer_name_, precision_, out_dim_, k_, + weight_compressed_, compressed_size_, bias_, + is_configured_, m_max_, optim_alg_, activation_); + p->weight_scale_ = weight_scale_; + p->setPluginNamespace(namespace_.c_str()); + p->weight_compressed_dev_ = weight_compressed_dev_; + + return p; + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } + return nullptr; +} + +nvinfer1::DimsExprs SpmmPluginDynamic::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept { + int nbDims = inputs[0].nbDims; + try { + PADDLE_ENFORCE_EQ(nbInputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbInputs is invalid")); + PADDLE_ENFORCE_EQ(outputIndex, 0, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's outputIndex is invalid")); + if (nbDims == 5) 
{ + int nbDims = inputs[0].nbDims; + PADDLE_ENFORCE_EQ( + inputs[0].d[3]->getConstantValue(), 1, + platform::errors::InvalidArgument("now the input d[3] should be 1")); + PADDLE_ENFORCE_EQ( + inputs[0].d[4]->getConstantValue(), 1, + platform::errors::InvalidArgument("now the input d[4] should be 1")); + nvinfer1::DimsExprs ret; + ret.nbDims = nbDims; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(out_dim_); + ret.d[3] = exprBuilder.constant(1); + ret.d[4] = exprBuilder.constant(1); + return ret; + } else if (nbDims == 4) { + int nbDims = inputs[0].nbDims; + PADDLE_ENFORCE_EQ( + inputs[0].d[2]->getConstantValue(), 1, + platform::errors::InvalidArgument("now the input d[2] should be 1")); + PADDLE_ENFORCE_EQ( + inputs[0].d[3]->getConstantValue(), 1, + platform::errors::InvalidArgument("now the input d[3] should be 1")); + nvinfer1::DimsExprs ret; + ret.nbDims = nbDims; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = exprBuilder.constant(out_dim_); + ret.d[2] = exprBuilder.constant(1); + ret.d[3] = exprBuilder.constant(1); + + return ret; + } else { + PADDLE_THROW(paddle::platform::errors::Fatal("nbDims should be 4 or 5")); + } + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } + return nvinfer1::DimsExprs{}; +} + +bool SpmmPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, + int nbOutputs) noexcept { + PADDLE_ENFORCE_EQ(nbInputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbInputs should be 1")); + PADDLE_ENFORCE_EQ(nbOutputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbOutputs should be 1")); + + const nvinfer1::PluginTensorDesc& in = inOut[pos]; + if (pos == 0) { + return (in.type == precision_) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + const nvinfer1::PluginTensorDesc& prev = inOut[pos - 1]; + + return in.type == prev.type && in.format == prev.format; +} + +void SpmmPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) noexcept { + /* + The following steps are executed if not configured. + 1. (INT8) Scale the bias (on host) + 2. Copy the bias to device + 3. 
Search the optimal algorithm + */ + try { + PADDLE_ENFORCE_EQ(nbInputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbInputs should be 1")); + PADDLE_ENFORCE_EQ(nbOutputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbOutputs should be 1")); + PADDLE_ENFORCE_EQ(precision_, inputs[0].desc.type, + platform::errors::InvalidArgument( + "precision_ should be equal to inputs[0].desc.type")); + const auto& inDims0 = inputs[0].desc.dims; + PADDLE_ENFORCE_EQ(inDims0.nbDims, 5, platform::errors::InvalidArgument( + "inDims0.nbDims should be 5")); + PADDLE_ENFORCE_EQ(k_, inDims0.d[2], + platform::errors::InvalidArgument( + "inDims0.d[2] should be equals to k")); + PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument( + "inDims0.d[3] should be 1")); + PADDLE_ENFORCE_EQ(inDims0.d[4], 1, platform::errors::InvalidArgument( + "inDims0.d[4] should be 1")); + const int BS = inputs->max.d[0]; + + // The optimal algorighm id is for m = m_max_ + // To Do: configurePlugin takes time when m is changed + if (is_configured_) { + return; + } + + m_max_ = BS; + if (has_bias_) { + if (inputs->desc.type == nvinfer1::DataType::kINT8) { + for (int i = 0; i < out_dim_; ++i) { + static_cast(bias_)[i] = + static_cast(bias_)[i] / outputs->desc.scale; + } + } + cudaMalloc(reinterpret_cast(&bias_dev_), + sizeof(float) * out_dim_); + cudaMemcpy(bias_dev_, bias_, sizeof(float) * out_dim_, + cudaMemcpyHostToDevice); + } + cudaDataType_t dataType = convertTrtType(precision_); + spmm_context_.init(m_max_, out_dim_, k_, dataType, bias_dev_, activation_); + + void* dA; + void* dC; + void* d_workspace; + float alpha{1.0f}; + float beta{0.0f}; + if (precision_ == nvinfer1::DataType::kINT8) { + alpha = inputs->desc.scale * weight_scale_ / outputs->desc.scale; + } + cudaMalloc(reinterpret_cast(&dA), m_max_ * k_ * sizeof(dataType)); + cudaMalloc(reinterpret_cast(&dC), + m_max_ * out_dim_ * sizeof(dataType)); + cudaMalloc(reinterpret_cast(&d_workspace), + spmm_context_.workspace_size); + paddle::platform::dynload::cusparseLtMatmulSearch( + &spmm_context_.handle, &spmm_context_.plan, &alpha, dA, + weight_compressed_dev_, &beta, dC, dC, d_workspace, nullptr, 0); + paddle::platform::dynload::cusparseLtMatmulAlgGetAttribute( + &spmm_context_.handle, &spmm_context_.alg_sel, + CUSPARSELT_MATMUL_ALG_CONFIG_ID, &optim_alg_, sizeof(optim_alg_)); + cudaFree(dA); + cudaFree(dC); + cudaFree(d_workspace); + + is_configured_ = true; + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } +} + +size_t SpmmPluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept { + return spmm_context_.workspace_size; +} + +int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, + void* workSpace, cudaStream_t stream) noexcept { + try { + PADDLE_ENFORCE_EQ(is_configured_, true, + platform::errors::InvalidArgument( + "The plugin is not configured before enqueue")); + PADDLE_ENFORCE_EQ( + k_, inputDesc->dims.d[2], + platform::errors::InvalidArgument("k_ == inputDesc->dims.d[2]")); + float alpha = 1.0f; + float beta = 0.0f; + if (inputDesc->type == nvinfer1::DataType::kFLOAT) { + const auto* const input = static_cast(inputs[0]); + auto* output = static_cast(outputs[0]); + cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( + &spmm_context_.handle, 
&spmm_context_.plan, &alpha, input,
+        weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1);
+    return status != CUSPARSE_STATUS_SUCCESS;
+  } else if (inputDesc->type == nvinfer1::DataType::kHALF) {
+    const auto* const input = static_cast<const half*>(inputs[0]);
+    auto* output = static_cast<half*>(outputs[0]);
+    cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul(
+        &spmm_context_.handle, &spmm_context_.plan, &alpha, input,
+        weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1);
+    return status != CUSPARSE_STATUS_SUCCESS;
+  } else if (inputDesc->type == nvinfer1::DataType::kINT8) {
+    // Fold the input/weight/output quantization scales into alpha for INT8.
+    alpha = inputDesc->scale * weight_scale_ / outputDesc->scale;
+    const auto* const input = static_cast<const int8_t*>(inputs[0]);
+    auto* output = static_cast<int8_t*>(outputs[0]);
+    cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul(
+        &spmm_context_.handle, &spmm_context_.plan, &alpha, input,
+        weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1);
+    return status != CUSPARSE_STATUS_SUCCESS;
+  } else {
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "Unsupported type error, expected [kFLOAT, kHALF, kINT8], but received %d",
+        static_cast<int>(precision_)));
+  }
+  } catch (const std::exception& e) {
+    std::cerr << e.what() << std::endl;
+  }
+  return -1;
+}
+
+nvinfer1::DataType SpmmPluginDynamic::getOutputDataType(
+    int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
+    noexcept {
+  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
+                                  "SpmmPluginDynamic's index should be 0"));
+  PADDLE_ENFORCE_EQ(nbInputs, 1,
+                    platform::errors::InvalidArgument(
+                        "SpmmPluginDynamic's nbInputs should be 1"));
+  PADDLE_ENFORCE_EQ(inputTypes[0] == nvinfer1::DataType::kFLOAT ||
+                        inputTypes[0] == nvinfer1::DataType::kHALF ||
+                        inputTypes[0] == nvinfer1::DataType::kINT8,
+                    true,
+                    platform::errors::InvalidArgument(
+                        "SpmmPluginDynamic does not support this format now"));
+
+  return inputTypes[0];
+}
+
+const char* SpmmPluginDynamic::getPluginType() const noexcept {
+  return "SpmmPluginDynamic";
+}
+
+const char* SpmmPluginDynamic::getPluginVersion() const noexcept { return "1"; }
+
+int SpmmPluginDynamic::getNbOutputs() const noexcept { return 1; }
+
+int SpmmPluginDynamic::initialize() noexcept { return 0; }
+
+void SpmmPluginDynamic::terminate() noexcept {}
+
+size_t SpmmPluginDynamic::getSerializationSize() const noexcept {
+  return compressed_size_ + (has_bias_ ? sizeof(float) * out_dim_ : 0) +
+         sizeof(precision_) + sizeof(precision_size_) + sizeof(element_size_) +
+         sizeof(out_dim_) + sizeof(k_) + sizeof(m_max_) +
+         sizeof(is_configured_) + sizeof(optim_alg_) + sizeof(weight_scale_) +
+         sizeof(compressed_size_) + sizeof(has_bias_) + sizeof(activation_);
+}
+
+void SpmmPluginDynamic::serialize(void* buffer) const noexcept {
+  SerializeValue(&buffer, precision_);
+  SerializeValue(&buffer, precision_size_);
+  SerializeValue(&buffer, element_size_);
+  SerializeValue(&buffer, out_dim_);
+  SerializeValue(&buffer, k_);
+  SerializeValue(&buffer, m_max_);
+  SerializeValue(&buffer, is_configured_);
+  SerializeValue(&buffer, optim_alg_);
+  SerializeValue(&buffer, weight_scale_);
+  SerializeValue(&buffer, compressed_size_);
+  SerializeValue(&buffer, has_bias_);
+  SerializeValue(&buffer, activation_);
+
+  char* d = static_cast<char*>(buffer);
+  std::copy_n(static_cast<const char*>(weight_compressed_), compressed_size_,
+              d);
+  if (has_bias_) {
+    d += compressed_size_;
+    std::copy_n(static_cast<const char*>(bias_), out_dim_ * sizeof(float), d);
+  }
+}
+
+void SpmmPluginDynamic::destroy() noexcept {
+  delete[] reinterpret_cast<char*>(weight_compressed_);
+  if (weight_compressed_dev_) {
+    cudaFree(weight_compressed_dev_);
+    weight_compressed_dev_ = nullptr;
+  }
+  if (has_bias_) {
+    cudaFree(bias_dev_);
+  }
+  if (is_configured_) {
+    spmm_context_.destroy();
+  }
+  delete this;
+}
+
+void SpmmPluginDynamic::setPluginNamespace(const char* libNamespace) noexcept {
+  try {
+    namespace_ = libNamespace;
+  } catch (const std::exception& e) {
+    std::cerr << e.what() << std::endl;
+  }
+}
+
+const char* SpmmPluginDynamic::getPluginNamespace() const noexcept {
+  return namespace_.c_str();
+}
+
+inline nvinfer1::DataType fieldTypeToDataType(
+    const nvinfer1::PluginFieldType ftype) {
+  switch (ftype) {
+    case nvinfer1::PluginFieldType::kFLOAT32:
+      return nvinfer1::DataType::kFLOAT;
+    case nvinfer1::PluginFieldType::kFLOAT16:
+      return nvinfer1::DataType::kHALF;
+    case nvinfer1::PluginFieldType::kINT32:
+      return nvinfer1::DataType::kINT32;
+    case nvinfer1::PluginFieldType::kINT8:
+      return nvinfer1::DataType::kINT8;
+    default:
+      PADDLE_THROW(paddle::platform::errors::Fatal(
+          "No corresponding datatype for plugin field type"));
+  }
+}
+
+SpmmPluginDynamicCreator::SpmmPluginDynamicCreator() {
+  plugin_attr_.emplace_back(nvinfer1::PluginField(
+      "type_id", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
+  plugin_attr_.emplace_back(nvinfer1::PluginField(
+      "out_dim", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
+  plugin_attr_.emplace_back(nvinfer1::PluginField(
+      "weight", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1));
+  plugin_attr_.emplace_back(nvinfer1::PluginField(
+      "bias", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1));
+  plugin_attr_.emplace_back(nvinfer1::PluginField(
+      "activation_id", nullptr, nvinfer1::PluginFieldType::kINT8, 1));
+
+  field_collection_.nbFields = plugin_attr_.size();
+  field_collection_.fields = plugin_attr_.data();
+}
+
+const char* SpmmPluginDynamicCreator::getPluginName() const noexcept {
+  return "SpmmPluginDynamic";
+}
+
+const char* SpmmPluginDynamicCreator::getPluginVersion() const noexcept {
+  return "1";
+}
+
+const nvinfer1::PluginFieldCollection*
+SpmmPluginDynamicCreator::getFieldNames() noexcept {
+  return &field_collection_;
+}
+
+nvinfer1::IPluginV2* SpmmPluginDynamicCreator::createPlugin(
+    const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept {
+  try {
+    int type_id = -1;
+    int out_dim = 0;
+    nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, nullptr, 0ll};
+    nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0ll};
+    int activation_id = -1;
+
+    for (int i = 0; i < fc->nbFields; i++) {
+      std::string field_name(fc->fields[i].name);
+      if (field_name.compare("type_id") == 0) {
+        type_id = static_cast<const int*>(fc->fields[i].data)[0];
+      } else if (field_name.compare("out_dim") == 0) {
+        out_dim = static_cast<const int*>(fc->fields[i].data)[0];
+      } else if (field_name.compare("weight") == 0) {
+        weight.type = fieldTypeToDataType(fc->fields[i].type);
+        weight.values = fc->fields[i].data;
+        weight.count = fc->fields[i].length;
+      } else if (field_name.compare("bias") == 0) {
+        bias.type = fieldTypeToDataType(fc->fields[i].type);
+        bias.values = fc->fields[i].data;
+        bias.count = fc->fields[i].length;
+      } else if (field_name.compare("activation_id") == 0) {
+        activation_id = static_cast<const int*>(fc->fields[i].data)[0];
+      } else {
+        PADDLE_THROW(paddle::platform::errors::Fatal("Unsupported plugin field"));
+      }
+    }
+
+    PADDLE_ENFORCE_NE(
+        type_id, -1,
+        platform::errors::InvalidArgument(
+            "SpmmPluginDynamicCreator's type_id should not be -1"));
+    PADDLE_ENFORCE_NE(
+        out_dim, 0, platform::errors::InvalidArgument(
+                        "SpmmPluginDynamicCreator's out_dim should not be 0"));
+    PADDLE_ENFORCE_NE(
+        weight.count, 0,
+        platform::errors::InvalidArgument(
+            "SpmmPluginDynamicCreator's weight size should not be 0"));
+    PADDLE_ENFORCE_NE(
+        activation_id, -1,
+        platform::errors::InvalidArgument(
+            "SpmmPluginDynamicCreator's activation_id should not be -1"));
+    nvinfer1::DataType type = static_cast<nvinfer1::DataType>(type_id);
+    SpmmPluginDynamic::Activation activation =
+        static_cast<SpmmPluginDynamic::Activation>(activation_id);
+    return new SpmmPluginDynamic(name, type, out_dim, weight, bias, activation);
+  } catch (const std::exception& e) {
+    std::cerr << e.what() << std::endl;
+  }
+  return nullptr;
+}
+
+nvinfer1::IPluginV2* SpmmPluginDynamicCreator::deserializePlugin(
+    const char* name, const void* serialData, size_t serialLength) noexcept {
+  // This object will be deleted when the network is destroyed, which will
+  // call SpmmPluginDynamic::destroy()
+  try {
+    return new SpmmPluginDynamic(name, serialData, serialLength);
+  } catch (const std::exception& e) {
+    std::cerr << e.what() << std::endl;
+  }
+  return nullptr;
+}
+
+void SpmmPluginDynamicCreator::setPluginNamespace(
+    const char* libNamespace) noexcept {
+  try {
+    namespace_ = libNamespace;
+  } catch (const std::exception& e) {
+    std::cerr << e.what() << std::endl;
+  }
+}
+
+const char* SpmmPluginDynamicCreator::getPluginNamespace() const noexcept {
+  return namespace_.c_str();
+}
+
+} // namespace plugin
+} // namespace tensorrt
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h
new file mode 100644
index 00000000000000..a7edb8dedfa7fd
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include "NvInfer.h" +#include "NvInferPlugin.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/dynload/cusparseLt.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class SpmmPluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + enum class Activation { kNone, kRelu, kGelu }; + SpmmPluginDynamic(const std::string& name, const nvinfer1::DataType precision, + const int out_dim, const nvinfer1::Weights& weight, + const nvinfer1::Weights& bias, Activation activation); + // The second constructor is for clone member function + SpmmPluginDynamic(const std::string& name, const nvinfer1::DataType precision, + const int out_dim, const int k, const void* weight, + size_t compressed_size, const void* bias, + bool is_configured, const int m_max, const int optim_alg, + Activation activation); + SpmmPluginDynamic(const std::string name, const void* data, size_t length); + SpmmPluginDynamic() = delete; + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) noexcept override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const noexcept override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const noexcept override; + const char* getPluginType() const noexcept override; + const char* getPluginVersion() const noexcept override; + int getNbOutputs() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(const char* pluginNamespace) noexcept override; + const char* getPluginNamespace() const noexcept override; + + private: + struct cusparseLtContext { + cusparseLtHandle_t handle; + cusparseLtMatDescriptor_t matA; + 
cusparseLtMatDescriptor_t matB; + cusparseLtMatDescriptor_t matC; + cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulAlgSelection_t alg_sel; + cusparseLtMatmulPlan_t plan; + cusparseLtContext(); + ~cusparseLtContext(); + size_t workspace_size{0}; + bool is_initialized{false}; + int activation{0}; + float relu_upper_bound{0}; + float relu_threshold{0}; + void init(int m, int n, int k, cudaDataType_t type, void* bias_ptr, + SpmmPluginDynamic::Activation activation); + void setAlgo(int id); + void destroy(); + void compressMatB(int n, int k, cudaDataType_t type, void* src, void** dest, + size_t* compressed_size); + }; // struct SpmmPluginDynamic::cusparseLtContext + const std::string layer_name_; + std::string namespace_; + nvinfer1::DataType precision_; + size_t precision_size_; + size_t + element_size_; // size of weight (float if INT8 or FLOAT; half if HALF) + int out_dim_; + int k_; + int m_max_; + bool is_configured_; // already get m, scale bias, and search the optim alg + // or not + int optim_alg_; // the index of optimal algorithm + float weight_scale_; // record the weight scale from constructor + void* weight_compressed_; // host compressed weight + void* weight_compressed_dev_; // device compressed weight + size_t compressed_size_; // size of compressed weight + bool has_bias_; // there is bias or not + void* bias_; // host bias + void* bias_dev_; // device bias + Activation activation_; // record the activation type + cusparseLtContext spmm_context_; +}; // class SpmmPluginDynamic + +class SpmmPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + SpmmPluginDynamicCreator(); + const char* getPluginName() const noexcept override; + const char* getPluginVersion() const noexcept override; + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::IPluginV2* createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) noexcept override; + void setPluginNamespace(const char* pluginNamespace) noexcept override; + const char* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection field_collection_; + static std::vector plugin_attr_; + std::string namespace_; +}; // class SpmmPluginDynamicCreator + +REGISTER_TRT_PLUGIN_V2(SpmmPluginDynamicCreator); +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 1f95e121271041..07b6f2ead26447 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -29,6 +29,10 @@ if (TENSORRT_FOUND) list(APPEND CUDA_SRCS tensorrt.cc) endif() +if (CUSPARSELT_FOUND) + list(APPEND CUDA_SRCS cusparseLt.cc) +endif() + configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h) if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) diff --git a/paddle/fluid/platform/dynload/cusparseLt.cc b/paddle/fluid/platform/dynload/cusparseLt.cc new file mode 100644 index 00000000000000..ae2aec012b7b77 --- /dev/null +++ b/paddle/fluid/platform/dynload/cusparseLt.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/dynload/cusparseLt.h" + +namespace paddle { +namespace platform { +namespace dynload { + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +#ifdef CUSPARSELT_ROUTINE_EACH +CUSPARSELT_ROUTINE_EACH(DEFINE_WRAP); +#endif + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cusparseLt.h b/paddle/fluid/platform/dynload/cusparseLt.h new file mode 100644 index 00000000000000..72e5fa4f88706b --- /dev/null +++ b/paddle/fluid/platform/dynload/cusparseLt.h @@ -0,0 +1,59 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include +#include // NOLINT + +#include "paddle/phi/backends/dynload/cusparseLt.h" + +namespace paddle { +namespace platform { +namespace dynload { + +#define PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP(__name) \ + using DynLoad__##__name = phi::dynload::DynLoad__##__name; \ + extern DynLoad__##__name __name + +#if defined(PADDLE_WITH_CUDA) +#if CUDA_VERSION >= 11020 +#define CUSPARSELT_ROUTINE_EACH(__macro) \ + __macro(cusparseLtInit); \ + __macro(cusparseLtDestroy); \ + __macro(cusparseLtDenseDescriptorInit); \ + __macro(cusparseLtStructuredDescriptorInit); \ + __macro(cusparseLtMatmulDescriptorInit); \ + __macro(cusparseLtMatmulDescSetAttribute); \ + __macro(cusparseLtMatmulAlgSelectionInit); \ + __macro(cusparseLtMatmulAlgSetAttribute); \ + __macro(cusparseLtMatmulGetWorkspace); \ + __macro(cusparseLtMatmulPlanInit); \ + __macro(cusparseLtMatDescriptorDestroy); \ + __macro(cusparseLtSpMMACompressedSize2); \ + __macro(cusparseLtSpMMACompress2); \ + __macro(cusparseLtMatmulSearch); \ + __macro(cusparseLtMatmulAlgGetAttribute); \ + __macro(cusparseLtMatmulPlanDestroy); \ + __macro(cusparseLtMatmul); \ + __macro(cusparseGetErrorString); + +CUSPARSELT_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP); +#endif +#endif + +#undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 5ce63b244efde5..5ef9616ab4364a 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -71,6 +71,10 @@ void* GetCUFFTDsoHandle() { return phi::dynload::GetCUFFTDsoHandle(); } void* GetMKLRTDsoHandle() { return phi::dynload::GetMKLRTDsoHandle(); } +void* GetCusparseLtDsoHandle() { + return phi::dynload::GetCusparseLtDsoHandle(); +} + } // namespace dynload } // namespace platform } // 
namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index ca60cd76a59e10..50714dfb302ebf 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -46,6 +46,7 @@ void* GetNvtxDsoHandle(); void* GetCUFFTDsoHandle(); void* GetMKLRTDsoHandle(); void* GetROCFFTDsoHandle(); +void* GetCusparseLtDsoHandle(); void SetPaddleLibPath(const std::string&); } // namespace dynload diff --git a/paddle/phi/backends/dynload/CMakeLists.txt b/paddle/phi/backends/dynload/CMakeLists.txt index bc5ef3cd5c0787..36408c46e516dc 100644 --- a/paddle/phi/backends/dynload/CMakeLists.txt +++ b/paddle/phi/backends/dynload/CMakeLists.txt @@ -29,6 +29,10 @@ if (TENSORRT_FOUND) list(APPEND CUDA_SRCS tensorrt.cc) endif() +if (CUSPARSELT_FOUND) + list(APPEND CUDA_SRCS cusparseLt.cc) +endif() + configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h) if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) diff --git a/paddle/phi/backends/dynload/cusparseLt.cc b/paddle/phi/backends/dynload/cusparseLt.cc new file mode 100644 index 00000000000000..9025a1b82ca3f2 --- /dev/null +++ b/paddle/phi/backends/dynload/cusparseLt.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/dynload/cusparseLt.h" + +namespace phi { +namespace dynload { + +std::once_flag cusparselt_dso_flag; +void *cusparselt_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +CUSPARSELT_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace phi diff --git a/paddle/phi/backends/dynload/cusparseLt.h b/paddle/phi/backends/dynload/cusparseLt.h new file mode 100644 index 00000000000000..b67858cd72fb02 --- /dev/null +++ b/paddle/phi/backends/dynload/cusparseLt.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include +#include // NOLINT + +#include "paddle/phi/backends/dynload/dynamic_loader.h" +#include "paddle/phi/backends/dynload/port.h" + +namespace phi { +namespace dynload { + +extern std::once_flag cusparselt_dso_flag; +extern void *cusparselt_dso_handle; + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load cupti routine + * via operator overloading. 
+ *
+ * note: default dynamic linked libs
+ */
+#define DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP(__name)                      \
+  struct DynLoad__##__name {                                              \
+    template <typename... Args>                                           \
+    cusparseStatus_t operator()(Args... args) {                           \
+      using cusparseltFunc = decltype(&::__name);                         \
+      std::call_once(cusparselt_dso_flag, []() {                          \
+        cusparselt_dso_handle = phi::dynload::GetCusparseLtDsoHandle();   \
+      });                                                                  \
+      static void *p_##__name = dlsym(cusparselt_dso_handle, #__name);    \
+      return reinterpret_cast<cusparseltFunc>(p_##__name)(args...);       \
+    }                                                                      \
+  };                                                                       \
+  extern DynLoad__##__name __name
+#if defined(PADDLE_WITH_CUDA)
+#if CUDA_VERSION >= 11020
+#define CUSPARSELT_ROUTINE_EACH(__macro)        \
+  __macro(cusparseLtInit);                      \
+  __macro(cusparseLtDestroy);                   \
+  __macro(cusparseLtDenseDescriptorInit);       \
+  __macro(cusparseLtStructuredDescriptorInit);  \
+  __macro(cusparseLtMatmulDescriptorInit);      \
+  __macro(cusparseLtMatmulDescSetAttribute);    \
+  __macro(cusparseLtMatmulAlgSelectionInit);    \
+  __macro(cusparseLtMatmulAlgSetAttribute);     \
+  __macro(cusparseLtMatmulGetWorkspace);        \
+  __macro(cusparseLtMatmulPlanInit);            \
+  __macro(cusparseLtMatDescriptorDestroy);      \
+  __macro(cusparseLtSpMMACompressedSize2);      \
+  __macro(cusparseLtSpMMACompress2);            \
+  __macro(cusparseLtMatmulSearch);              \
+  __macro(cusparseLtMatmulAlgGetAttribute);     \
+  __macro(cusparseLtMatmulPlanDestroy);         \
+  __macro(cusparseLtMatmul);                    \
+  __macro(cusparseGetErrorString);
+
+CUSPARSELT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP);
+#endif
+#endif
+
+#undef DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP
+} // namespace dynload
+} // namespace phi
diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc
index 2f35e22a18f820..36a78695959235 100644
--- a/paddle/phi/backends/dynload/dynamic_loader.cc
+++ b/paddle/phi/backends/dynload/dynamic_loader.cc
@@ -76,6 +76,8 @@ DEFINE_string(mkl_dir,
 DEFINE_string(op_dir, "", "Specify path for loading user-defined op library.");
+DEFINE_string(cusparselt_dir, "", "Specify path for loading libcusparseLt.so.");
+
 #ifdef PADDLE_WITH_HIP
 DEFINE_string(miopen_dir,
@@ -578,5 +580,18 @@ void* GetMKLRTDsoHandle() {
 #endif
 }
+void* GetCusparseLtDsoHandle() {
+// APIs available after CUDA 11.2
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020
+  return GetDsoHandleFromSearchPath(FLAGS_cusparselt_dir, "libcusparseLt.so");
+#else
+  std::string warning_msg(
+      "Your CUDA version is less than 11.2, so cusparseLt is not supported. "
+      "If you want to use cusparseLt, please upgrade CUDA and rebuild "
+      "PaddlePaddle.");
+  return nullptr;
+#endif
+}
+
 } // namespace dynload
 } // namespace phi
diff --git a/paddle/phi/backends/dynload/dynamic_loader.h b/paddle/phi/backends/dynload/dynamic_loader.h
index 942a635b649bcd..642535fc50cf3e 100644
--- a/paddle/phi/backends/dynload/dynamic_loader.h
+++ b/paddle/phi/backends/dynload/dynamic_loader.h
@@ -45,6 +45,7 @@ void* GetNvtxDsoHandle();
 void* GetCUFFTDsoHandle();
 void* GetMKLRTDsoHandle();
 void* GetROCFFTDsoHandle();
+void* GetCusparseLtDsoHandle();
 void SetPaddleLibPath(const std::string&);