From dc4604078810c730f1df88f9fe86997c3361f23d Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Fri, 20 May 2022 11:40:26 +0000 Subject: [PATCH 01/11] fix precision-is-non error --- .../inference/tensorrt/plugin/spmm_plugin.cu | 61 +++++++++++++------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu index b9cc7e55b7d2af..389a0807a5d3da 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -373,6 +373,11 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, std::copy_n(static_cast(weight_compressed), compressed_size, static_cast(weight_compressed_)); + cudaMalloc(reinterpret_cast(&weight_compressed_dev_), + compressed_size); + cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size, + cudaMemcpyHostToDevice); + has_bias_ = (bias != nullptr); if (has_bias_) { // Each plugin has a copy of bias @@ -447,7 +452,6 @@ nvinfer1::IPluginV2DynamicExt* SpmmPluginDynamic::clone() const noexcept { is_configured_, m_max_, optim_alg_, activation_); p->weight_scale_ = weight_scale_; p->setPluginNamespace(namespace_.c_str()); - p->weight_compressed_dev_ = weight_compressed_dev_; return p; } catch (const std::exception& e) { @@ -548,24 +552,38 @@ void SpmmPluginDynamic::configurePlugin( platform::errors::InvalidArgument( "precision_ should be equal to inputs[0].desc.type")); const auto& inDims0 = inputs[0].desc.dims; - PADDLE_ENFORCE_EQ(inDims0.nbDims, 5, platform::errors::InvalidArgument( - "inDims0.nbDims should be 5")); - PADDLE_ENFORCE_EQ(k_, inDims0.d[2], - platform::errors::InvalidArgument( - "inDims0.d[2] should be equals to k")); - PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument( - "inDims0.d[3] should be 1")); - PADDLE_ENFORCE_EQ(inDims0.d[4], 1, platform::errors::InvalidArgument( - "inDims0.d[4] should be 1")); - const int BS = inputs->max.d[0]; - + if (inDims0.nbDims==5) { + PADDLE_ENFORCE_EQ(inDims0.nbDims, 5, platform::errors::InvalidArgument( + "inDims0.nbDims should be 5")); + PADDLE_ENFORCE_EQ(k_, inDims0.d[2], + platform::errors::InvalidArgument( + "inDims0.d[2] should be equals to k")); + PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument( + "inDims0.d[3] should be 1")); + PADDLE_ENFORCE_EQ(inDims0.d[4], 1, platform::errors::InvalidArgument( + "inDims0.d[4] should be 1")); + const int BS = inputs->max.d[0]; + const int Seq = inputs->max.d[1]; + m_max_ = BS * Seq; + } else if (inDims0.nbDims==4) { + PADDLE_ENFORCE_EQ(inDims0.nbDims, 4, platform::errors::InvalidArgument( + "inDims0.nbDims should be 4")); + PADDLE_ENFORCE_EQ(k_, inDims0.d[1], + platform::errors::InvalidArgument( + "inDims0.d[1] should be equals to k")); + PADDLE_ENFORCE_EQ(inDims0.d[2], 1, platform::errors::InvalidArgument( + "inDims0.d[2] should be 1")); + PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument( + "inDims0.d[3] should be 1")); + const int BS_Seq = inputs->max.d[0]; + m_max_ = BS_Seq; + } // The optimal algorighm id is for m = m_max_ // To Do: configurePlugin takes time when m is changed if (is_configured_) { return; } - m_max_ = BS; if (has_bias_) { if (inputs->desc.type == nvinfer1::DataType::kINT8) { for (int i = 0; i < out_dim_; ++i) { @@ -624,9 +642,15 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, PADDLE_ENFORCE_EQ(is_configured_, true, platform::errors::InvalidArgument( "The plugin is not configured before enqueue")); - PADDLE_ENFORCE_EQ( - k_, inputDesc->dims.d[2], - platform::errors::InvalidArgument("k_ == inputDesc->dims.d[2]")); + if (inputDesc->dims.nbDims==5){ + PADDLE_ENFORCE_EQ( + k_, inputDesc->dims.d[2], + platform::errors::InvalidArgument("k_ == inputDesc->dims.d[2]")); + } else if (inputDesc->dims.nbDims==4) { + PADDLE_ENFORCE_EQ( + k_, inputDesc->dims.d[1], + platform::errors::InvalidArgument("k_ == inputDesc->dims.d[1]")); + } float alpha = 1.0f; float beta = 0.0f; if (inputDesc->type == nvinfer1::DataType::kFLOAT) { @@ -725,10 +749,7 @@ void SpmmPluginDynamic::serialize(void* buffer) const noexcept { void SpmmPluginDynamic::destroy() noexcept { delete[] reinterpret_cast(weight_compressed_); - if (weight_compressed_dev_) { - cudaFree(weight_compressed_dev_); - weight_compressed_dev_ = nullptr; - } + cudaFree(weight_compressed_dev_); if (has_bias_) { cudaFree(bias_dev_); } From 9ad5f6467c73a0d7cd68cc64ec6077fc72f68fdf Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Mon, 23 May 2022 04:52:53 +0000 Subject: [PATCH 02/11] develop finished --- .../framework/ir/graph_pattern_detector.cc | 29 ++ .../framework/ir/graph_pattern_detector.h | 18 + ...dense_multihead_matmul_with_sparse_pass.cc | 123 +++++ ..._dense_multihead_matmul_with_sparse_pass.h | 45 ++ .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/api/paddle_pass_builder.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 2 +- .../convert/sparse_multihead_matmul_op.cc | 456 ++++++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 6 +- 9 files changed, 678 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc create mode 100644 paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h create mode 100644 paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index e752cead428a81..dcf58185930bb7 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3449,6 +3449,35 @@ PDNode *patterns::DenseFC::operator()() { return fc_out; } +PDNode *patterns::DenseMultiheadMatmul::operator()() { + auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr())->assert_is_op("multihead_matmul"); + // Input + auto *multihead_matmul_input = pattern->NewNode(multihead_matmul_input_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "Input"); + // Filter + auto *multihead_matmul_weights = pattern->NewNode(multihead_matmul_weights_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "W"); + // Bias + auto *multihead_matmul_bias = pattern->NewNode(multihead_matmul_bias_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "Bias"); + // BiasQK + auto *multihead_matmul_biasqk = pattern->NewNode(multihead_matmul_biasqk_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "BiasQK"); + // Output + auto *multihead_matmul_out = pattern->NewNode(multihead_matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("multihead_matmul", "Out") + ->assert_is_only_output_of_op("fc"); + + multihead_matmul->LinksFrom({multihead_matmul_input, multihead_matmul_weights, multihead_matmul_bias, multihead_matmul_biasqk}).LinksTo({multihead_matmul_out}); + + return multihead_matmul_out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 538adcf4510784..4be8eb9140df5e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1958,6 +1958,24 @@ struct DenseFC : public PatternBase { PATTERN_DECL_NODE(fc_bias); }; +// +// \brief Pattern looking for dense multihead matmul fc. +// +struct DenseMultiheadMatmul : public PatternBase { + DenseMultiheadMatmul(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dense_multihead_matmul") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(matmul); + PATTERN_DECL_NODE(matmul_out); + PATTERN_DECL_NODE(matmul_input); + PATTERN_DECL_NODE(matmul_weights); + PATTERN_DECL_NODE(matmul_bias); + PATTERN_DECL_NODE(matmul_biasqk); +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc new file mode 100644 index 00000000000000..15a1ca15c61641 --- /dev/null +++ b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +ReplaceDenseMultiheadMatmulWithSparsePass::ReplaceDenseMultiheadMatmulWithSparsePass() { + AddOpCompat(OpCompat("multihead_matmul")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddInput("BiasQK") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + +void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + + std::string name_scope = "replace_dense_multihead_matmul_with_sparse_pass"; + FusePassBase::Init(name_scope, graph); + GraphPatternDetector gpd; + + patterns::DenseMultiheadMatmul dense_multihead_matmul_pattern(gpd.mutable_pattern(), + "dense_multihead_matmul_replace_pass"); + dense_multihead_matmul_pattern(); + int found_dense_multihead_matmul_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "Replace dense multihead matmul with sparse_multihead_matmul."; + + /* if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + }*/ + + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, dense_multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, dense_multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, dense_multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights, multihead_matmul_weights, dense_multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias, dense_multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk, dense_multihead_matmul_pattern); + + auto *multihead_matmul_op = multihead_matmul->Op(); + auto w_name = multihead_matmul_op->Input("W")[0]; + // recognize sparse op by name + if (w_name.find("sparse_2_4") != w_name.npos) { + // fake op + OpDesc desc(multihead_matmul_op->Block()); + desc.SetType("sparse_multihead_matmul"); + desc.SetInput("Input", {multihead_matmul_input->Name()}); + desc.SetInput("W", {multihead_matmul_weights->Name()}); + desc.SetInput("Bias", {multihead_matmul_bias->Name()}); + desc.SetInput("BiasQK", {multihead_matmul_biasqk->Name()}); + desc.SetOutput("Out", {multihead_matmul_out->Name()}); + + // copy all attr + desc.SetAttr("alpha", multihead_matmul_op->GetAttr("alpha")); + desc.SetAttr("head_number", multihead_matmul_op->GetAttr("head_number")); + if (multihead_matmul_op->HasAttr("Input_scale")) { + desc.SetAttr("Input_scale", multihead_matmul_op->GetAttr("Input_scale")); + } + if (multihead_matmul_op->HasAttr("fc_out_threshold")) { + desc.SetAttr("fc_out_threshold", multihead_matmul_op->GetAttr("fc_out_threshold")); + } + if (multihead_matmul_op->HasAttr("qkv2context_plugin_int8")) { + desc.SetAttr("qkv2context_plugin_int8", multihead_matmul_op->GetAttr("qkv2context_plugin_int8")); + } + if (multihead_matmul_op->HasAttr("dp_probs")) { + desc.SetAttr("dp_probs", multihead_matmul_op->GetAttr("dp_probs")); + } + if (multihead_matmul_op->HasAttr("out_threshold")) { + desc.SetAttr("out_threshold", multihead_matmul_op->GetAttr("out_threshold")); + } + desc.Flush(); + GraphSafeRemoveNodes(g, {multihead_matmul}); + auto sparse_multihead_matmul_node = g->CreateOpNode(&desc); + + IR_NODE_LINK_TO(multihead_matmul_input, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_weights, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_bias, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_biasqk, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(sparse_multihead_matmul_node, multihead_matmul_out); + found_dense_multihead_matmul_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_dense_multihead_matmul_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(replace_dense_multihead_matmul_with_sparse_pass, + paddle::framework::ir::ReplaceDenseMultiheadMatmulWithSparsePass); diff --git a/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h new file mode 100644 index 00000000000000..4c43e9ba4efbb9 --- /dev/null +++ b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/inference/api/paddle_analysis_config.h" + +namespace paddle { +namespace framework { +namespace ir { + +/** + * Replace dense multihead_matmul op with sparse multihead_matmul op + */ +class Graph; + +class ReplaceDenseMultiheadMatmulWithSparsePass : public FusePassBase { + public: + ReplaceDenseMultiheadMatmulWithSparsePass(); + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + const std::string name_scope_{"replace_dense_multihead_matmul_with_sparse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 2447852c69df40..d113590e42d1a3 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1767,6 +1767,7 @@ USE_TRT_CONVERTER(roll) USE_TRT_CONVERTER(strided_slice) #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) USE_TRT_CONVERTER(sparse_fc) +USE_TRT_CONVERTER(sparse_multihead_matmul) #endif #endif diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 2ddb0bc54d2ed0..9d7ff1ff2048d9 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -109,6 +109,7 @@ const std::vector kTRTSubgraphPasses({ "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // "replace_dense_with_sparse_pass", // + "replace_dense_multihead_matmul_with_sparse_pass", // "tensorrt_subgraph_pass", // "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index ebbc93426ba26f..c986710c3c58b1 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -10,7 +10,7 @@ conv3d_op.cc mish_op.cc nearest_interp_v2_op.cc pool3d_op.cc deformable_conv_op. preln_skip_layernorm.cc strided_slice_op.cc roll_op.cc) if (CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) - list(APPEND CONVERT_FILES sparse_fc_op.cc) + list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc) endif() nv_library(tensorrt_converter diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc new file mode 100644 index 00000000000000..3f36a5743f2702 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc @@ -0,0 +1,456 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class SparseMultiheadMatMulOpConverter : public OpConverter { + public: + plugin::SpmmPluginDynamic* new_spmm_plugin(TensorRTEngine::Weight* weight, + TensorRTEngine::Weight* bias, + nvinfer1::DataType type, + int outdim) { + plugin::SpmmPluginDynamic::Activation act = + plugin::SpmmPluginDynamic::Activation::kNone; + return new plugin::SpmmPluginDynamic("CustomSpmmPluginDynamic", type, + outdim, weight->get(), bias->get(), + act); + } + + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid sparse_multihead_matmul op to a corresponding tensorrt " + "network structure"; + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("Input").front()); + + // fc weights and fc bias + auto weight_name = op_desc.Input("W").front(); + auto bias_name = op_desc.Input("Bias").front(); + + auto* weight_v = scope.FindVar(weight_name); + auto* weight_t = weight_v->GetMutable(); + + auto* bias_v = scope.FindVar(bias_name); + auto* bias_t = bias_v->GetMutable(); + + float* weight_data = nullptr; + bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8"); + float in_scale = 0.; + + if (op_desc.HasAttr("Input_scale")) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); + engine_->SetTensorDynamicRange(input, in_scale); + } + weight_data = engine_->GetWeightCPUData(weight_name, weight_t); + + float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); + std::vector weight_data_tmp; + weight_data_tmp.reserve(weight_t->numel()); + memcpy(weight_data_tmp.data(), weight_data, + weight_t->numel() * sizeof(float)); + + // (hidden_in, 3, hidden_out) + auto weight_dims = weight_t->dims(); + + int hidden_in = weight_dims[0]; // channels_in + int three = weight_dims[1]; // channels_out + int hidden_out = weight_dims[2]; // channels_out + int m = hidden_in; + int n = three * hidden_out; + auto tranpose_weight = [](const float* src, float* dst, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + dst[j * m + i] = src[i * n + j]; + } + } + }; + tranpose_weight(weight_data_tmp.data(), weight_data, m, n); + + int head_number = BOOST_GET_CONST(int, op_desc.GetAttr("head_number")); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + + nvinfer1::ILayer* layer = nullptr; + auto output_name = op_desc.Output("Out")[0]; + + if (engine_->with_dynamic_shape()) { + if (engine_->use_oss()) { + if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { + PADDLE_THROW(platform::errors::Fatal( + "use use_oss must be int8 or half, not float32.")); + } + nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + if (engine_->with_interleaved()) { + VLOG(4) << "fused sparse_multihead_matmul op: use_oss and with_interleaved"; + if (!op_desc.HasAttr("Input_scale")) { + PADDLE_THROW( + platform::errors::Fatal("use with_interleaved must be int8.")); + } + nvinfer1::ILayer* fc_layer = nullptr; + float dp_probs = 1.0 / 127.0; + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + &weight, &bias, nvinfer1::DataType::kINT8, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(input); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + fc_layer->setName( + ("SparseMultihead: SPMM: (Output: " + + output_name + ")") + .c_str()); + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out_threshold in multihead layers in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + if (qkv2context_plugin_int8) { + dp_probs = + BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; + } + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "3"); + assert(creator != nullptr); + std::vector fields{ + {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, + 1}}; + if (qkv2context_plugin_int8) { + fields.push_back({"dq_probs", &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, 1}); + } + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast(malloc( + sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); + + std::vector plugin_inputs; + plugin_inputs.emplace_back(fc_layer->getOutput(0)); + if (engine_->Has("ernie_pos_name")) { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->Get("ernie_pos_name"))); + } else { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network() + ->getInput(2) + ->getName())); // cu_seqlens, eval_placeholder_2 + } + auto max_seqlen_tensor = + engine_->GetITensor(engine_->network()->getInput(3)->getName()); + engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f); + auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, + *const_cast(max_seqlen_tensor)); + nvinfer1::Dims shape_dim; + shape_dim.nbDims = 1; + shape_dim.d[0] = -1; + shuffle_layer->setReshapeDimensions(shape_dim); + engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f); + plugin_inputs.emplace_back( + shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + shuffle_layer->setName( + ("Sparse multihead: Shuffle: (Output: " + output_name + ")").c_str()); + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + layer = plugin_layer; + } else { + int head_size = hidden_out / head_number; + // [3, head_number, head_size, hidden_in] -> [head_number, 3, + // head_size, + // hidden_in] + auto transpose_weight_v2 = [](const float* src, float* dst, int three, + int head_number, int head_size, + int hidden_in) { + const int HH = head_size * hidden_in; + for (int i = 0; i < three; ++i) { + for (int n = 0; n < head_number; ++n) { + for (int hh = 0; hh < HH; ++hh) { + dst[n * three * HH + i * HH + hh] = + src[i * head_number * HH + n * HH + hh]; + } + } + } + }; + // [3, head_number, head_size] -> [head_number, 3, head_size] + auto transpose_bias_v2 = [](const float* src, float* dst, int N, + int H) { + for (int i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } + } + }; + memcpy(weight_data_tmp.data(), weight_data, + weight_t->numel() * sizeof(float)); + transpose_weight_v2(weight_data_tmp.data(), weight_data, three, + head_number, head_size, hidden_in); + + half* half_data = nullptr; + void* w_data = nullptr; + if (with_fp16) { + half_data = new half[weight_t->numel()]; + for (int i = 0; i < weight_t->numel(); i++) { + half_data[i] = static_cast(weight_data[i]); + } + w_data = static_cast(half_data); + } else { + w_data = static_cast(weight_data); + } + weight{with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + static_cast(w_data), + static_cast(weight_t->numel())}; + + std::vector bias_data_tmp; + bias_data_tmp.reserve(bias_t->numel()); + memcpy(bias_data_tmp.data(), bias_data, + bias_t->numel() * sizeof(float)); + transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number, + head_size); + + nvinfer1::ILayer* fc_layer = nullptr; + float dp_probs = 1.0 / 127.0; + if (op_desc.HasAttr("Input_scale")) { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + &weight, &bias, nvinfer1::DataType::kINT8, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(input); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + } else { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + &weight, &bias, + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(input); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + } + + if (op_desc.HasAttr("fc_out_threshold")) { + PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in multihead layers " + "in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + if (qkv2context_plugin_int8) { + dp_probs = + BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; + } + } + + auto mask_tensor = engine_->GetITensor("qkv_plugin_mask"); + + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "2"); + assert(creator != nullptr); + int type = static_cast(nvinfer1::DataType::kHALF); + if (qkv2context_plugin_int8 && + (engine_->precision() == AnalysisConfig::Precision::kInt8)) { + type = static_cast(nvinfer1::DataType::kINT8); + } + bool has_mask = true; + int var_seqlen = 1; + std::vector fields{ + {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, + {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, + {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, + {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, + 1}}; + if (qkv2context_plugin_int8) { + fields.push_back({"dq_probs", &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, 1}); + } + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast(malloc( + sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); + + std::vector plugin_inputs; + plugin_inputs.emplace_back(fc_layer->getOutput(0)); + plugin_inputs.emplace_back(mask_tensor); + if (engine_->Has("ernie_pos_name")) { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->Get("ernie_pos_name"))); + } else { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network() + ->getInput(2) + ->getName())); // cu_seqlens, eval_placeholder_2 + } + auto max_seqlen_tensor = + engine_->GetITensor(engine_->network()->getInput(3)->getName()); + auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, + *const_cast(max_seqlen_tensor)); + nvinfer1::Dims shape_dim; + shape_dim.nbDims = 1; + shape_dim.d[0] = -1; + shuffle_layer->setReshapeDimensions(shape_dim); + engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f); + plugin_inputs.emplace_back( + shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + layer = plugin_layer; + } + } else { + PADDLE_ENFORCE_EQ( + input->getDimensions().nbDims, 3, + platform::errors::InvalidArgument( + "The Input dim of the SparseMultiheadMatMul should be 3, " + "but it's (%d) now.", + input->getDimensions().nbDims)); + // transpose weight_data from m * n to n * m + auto* input_bias_qk = + engine_->GetITensor(op_desc.Input("BiasQK").front()); + + half* half_data = nullptr; + void* w_data = nullptr; + if (with_fp16) { + half_data = new half[weight_t->numel()]; + for (int i = 0; i < weight_t->numel(); i++) { + half_data[i] = static_cast(weight_data[i]); + } + w_data = static_cast(half_data); + } else { + w_data = static_cast(weight_data); + } + + TensorRTEngine::Weight weight{with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + static_cast(w_data), + static_cast(weight_t->numel())}; + weight.dims.assign({n, m}); + + TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + + // add shuffle before fc + nvinfer1::Dims reshape_before_fc_dim; + reshape_before_fc_dim.nbDims = 5; + reshape_before_fc_dim.d[0] = 0; + reshape_before_fc_dim.d[1] = 0; + reshape_before_fc_dim.d[2] = 0; + reshape_before_fc_dim.d[3] = 1; + reshape_before_fc_dim.d[4] = 1; + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + if (op_desc.HasAttr("Input_scale")) { + engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), + in_scale); + } + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setName( + ("shuffle_before_sparse_multihead_mamul(Output: " + output_name + ")") + .c_str()); + + // add layer fc + nvinfer1::ILayer* fc_layer = nullptr; + if (op_desc.HasAttr("Input_scale")) { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + &weight, &bias, nvinfer1::DataType::kINT8, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(reshape_before_fc_layer.getOutput(0)); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + } else { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + &weight, &bias, + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(reshape_before_fc_layer.getOutput(0)); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + } + + if (op_desc.HasAttr("fc_out_threshold")) { + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in multihead layers in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + } + fc_layer->setName( + ("sparse_multihead_mamul_fc(Output: " + output_name + ")").c_str()); + + // no need to add shuffle after fc, just change it in + // QkvToContextPluginDynamic + + // add qkv to context + int head_size = hidden_out / head_number; + float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); + + std::vector plugin_inputs; + plugin_inputs.push_back(fc_layer->getOutput(0)); + plugin_inputs.push_back(input_bias_qk); + + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; + } + plugin::DynamicPluginTensorRT* plugin = + new plugin::QkvToContextPluginDynamic(hidden_in, head_number, + head_size, scale, with_fp16); + layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin); + } + } else { + PADDLE_THROW(platform::errors::Fatal( + "You are running the Ernie(Bert) model in static shape mode, which " + "is not supported for the time being.\n" + "You can use the config.SetTRTDynamicShapeInfo(...) interface to set " + "the shape information to run the dynamic shape mode.")); + } + RreplenishLayerAndOutput(layer, "sparse_multihead_matmul", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul, MultiheadMatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6ebca816a40fdd..c60c18e1dbef06 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -48,6 +48,8 @@ struct SimpleOpTypeSetTeller : public Teller { #if IS_TRT_VERSION_GE(8000) teller_set.insert("sparse_fc"); int8_teller_set.insert("sparse_fc"); + teller_set.insert("sparse_multihead_matmul"); + int8_teller_set.insert("sparse_multihead_matmul"); #endif } @@ -1738,9 +1740,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } #if IS_TRT_VERSION_GE(8000) - if (op_type == "sparse_fc") { + if (op_type == "sparse_fc" || op_type == "sparse_multihead_matmul") { if (!with_dynamic_shape) { - VLOG(3) << "the sparse_fc does not support static shape yet"; + VLOG(3) << "the sparse_fc and sparse_multihead_matmul does not support static shape yet"; return false; } } From c8ac573dc1aa96fdb03a034c8ab0d0c9e6a06f3a Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Mon, 23 May 2022 10:42:42 +0000 Subject: [PATCH 03/11] strange error --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/graph_pattern_detector.cc | 4 ++-- .../framework/ir/graph_pattern_detector.h | 20 ++++++++-------- ...dense_multihead_matmul_with_sparse_pass.cc | 24 +++++++++---------- 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a88d9f6fe11bbc..694c46e0b5f267 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -101,6 +101,7 @@ pass_library(matmul_scale_fuse_pass inference) pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(mixed_precision_configure_pass inference) pass_library(replace_dense_with_sparse_pass inference) +pass_library(replace_dense_multihead_matmul_with_sparse_pass inference) pass_library(generate_pass DEPS pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index dcf58185930bb7..493f39897a0dcf 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3449,7 +3449,7 @@ PDNode *patterns::DenseFC::operator()() { return fc_out; } -PDNode *patterns::DenseMultiheadMatmul::operator()() { +PDNode *patterns::MultiheadMatmul::operator()() { auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr())->assert_is_op("multihead_matmul"); // Input auto *multihead_matmul_input = pattern->NewNode(multihead_matmul_input_repr()) @@ -3471,7 +3471,7 @@ PDNode *patterns::DenseMultiheadMatmul::operator()() { auto *multihead_matmul_out = pattern->NewNode(multihead_matmul_out_repr()) ->AsOutput() ->assert_is_op_output("multihead_matmul", "Out") - ->assert_is_only_output_of_op("fc"); + ->assert_is_only_output_of_op("multihead_matmul"); multihead_matmul->LinksFrom({multihead_matmul_input, multihead_matmul_weights, multihead_matmul_bias, multihead_matmul_biasqk}).LinksTo({multihead_matmul_out}); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 4be8eb9140df5e..bcfca14c78a3fe 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1959,21 +1959,21 @@ struct DenseFC : public PatternBase { }; // -// \brief Pattern looking for dense multihead matmul fc. +// \brief Pattern looking for multihead matmul fc. // -struct DenseMultiheadMatmul : public PatternBase { - DenseMultiheadMatmul(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "dense_multihead_matmul") {} +struct MultiheadMatmul : public PatternBase { + MultiheadMatmul(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "multihead_matmul") {} PDNode* operator()(); // declare operator node's name - PATTERN_DECL_NODE(matmul); - PATTERN_DECL_NODE(matmul_out); - PATTERN_DECL_NODE(matmul_input); - PATTERN_DECL_NODE(matmul_weights); - PATTERN_DECL_NODE(matmul_bias); - PATTERN_DECL_NODE(matmul_biasqk); + PATTERN_DECL_NODE(multihead_matmul); + PATTERN_DECL_NODE(multihead_matmul_out); + PATTERN_DECL_NODE(multihead_matmul_input); + PATTERN_DECL_NODE(multihead_matmul_weights); + PATTERN_DECL_NODE(multihead_matmul_bias); + PATTERN_DECL_NODE(multihead_matmul_biasqk); }; } // namespace patterns diff --git a/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc index 15a1ca15c61641..cb4cc14bfe1959 100644 --- a/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc +++ b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc @@ -47,25 +47,25 @@ void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const { FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; - patterns::DenseMultiheadMatmul dense_multihead_matmul_pattern(gpd.mutable_pattern(), + patterns::MultiheadMatmul multihead_matmul_pattern(gpd.mutable_pattern(), "dense_multihead_matmul_replace_pass"); - dense_multihead_matmul_pattern(); - int found_dense_multihead_matmul_count = 0; + multihead_matmul_pattern(); + int found_multihead_matmul_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { - VLOG(4) << "Replace dense multihead matmul with sparse_multihead_matmul."; + VLOG(4) << "Replace dense multihead matmul with sparse multihead matmul."; /* if (!IsCompat(subgraph, g)) { LOG(WARNING) << "Pass in op compat failed."; return; }*/ - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, dense_multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, dense_multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, dense_multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights, multihead_matmul_weights, dense_multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias, dense_multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk, dense_multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights, multihead_matmul_weights, multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias, multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk, multihead_matmul_pattern); auto *multihead_matmul_op = multihead_matmul->Op(); auto w_name = multihead_matmul_op->Input("W")[0]; @@ -107,12 +107,12 @@ void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const { IR_NODE_LINK_TO(multihead_matmul_bias, sparse_multihead_matmul_node); IR_NODE_LINK_TO(multihead_matmul_biasqk, sparse_multihead_matmul_node); IR_NODE_LINK_TO(sparse_multihead_matmul_node, multihead_matmul_out); - found_dense_multihead_matmul_count++; + found_multihead_matmul_count++; } }; gpd(graph, handler); - AddStatis(found_dense_multihead_matmul_count); + AddStatis(found_multihead_matmul_count); } } // namespace ir From e574ad92bd5dc0c46c9c607546dbcd6a6bc44def Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Mon, 23 May 2022 11:28:53 +0000 Subject: [PATCH 04/11] compilation failed --- .../convert/sparse_multihead_matmul_op.cc | 59 ++++++------------- 1 file changed, 18 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc index 3f36a5743f2702..6933eecd5a12df 100644 --- a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See @@ -18,7 +21,7 @@ namespace inference { namespace tensorrt { class SparseMultiheadMatMulOpConverter : public OpConverter { - public: + public: plugin::SpmmPluginDynamic* new_spmm_plugin(TensorRTEngine::Weight* weight, TensorRTEngine::Weight* bias, nvinfer1::DataType type, @@ -100,21 +103,18 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { static_cast(bias_data), static_cast(bias_t->numel())}; if (engine_->with_interleaved()) { - VLOG(4) << "fused sparse_multihead_matmul op: use_oss and with_interleaved"; + VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved"; if (!op_desc.HasAttr("Input_scale")) { PADDLE_THROW( platform::errors::Fatal("use with_interleaved must be int8.")); } nvinfer1::ILayer* fc_layer = nullptr; float dp_probs = 1.0 / 127.0; - plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( - &weight, &bias, nvinfer1::DataType::kINT8, n); - std::vector plugin_inputs; - plugin_inputs.emplace_back(input); - fc_layer = engine_->network()->addPluginV2( - plugin_inputs.data(), plugin_inputs.size(), *plugin); + nvinfer1::DimsHW nv_ksize(1, 1); + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, + nv_ksize, weight, bias); fc_layer->setName( - ("SparseMultihead: SPMM: (Output: " + + ("Multihead: Convolution/FullyConnected: (Output: " + output_name + ")") .c_str()); PADDLE_ENFORCE_EQ( @@ -177,7 +177,7 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { plugin_inputs.emplace_back( shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 shuffle_layer->setName( - ("Sparse multihead: Shuffle: (Output: " + output_name + ")").c_str()); + ("Multihead: Shuffle: (Output: " + output_name + ")").c_str()); auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin); layer = plugin_layer; @@ -215,21 +215,6 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { transpose_weight_v2(weight_data_tmp.data(), weight_data, three, head_number, head_size, hidden_in); - half* half_data = nullptr; - void* w_data = nullptr; - if (with_fp16) { - half_data = new half[weight_t->numel()]; - for (int i = 0; i < weight_t->numel(); i++) { - half_data[i] = static_cast(weight_data[i]); - } - w_data = static_cast(half_data); - } else { - w_data = static_cast(weight_data); - } - weight{with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, - static_cast(w_data), - static_cast(weight_t->numel())}; - std::vector bias_data_tmp; bias_data_tmp.reserve(bias_t->numel()); memcpy(bias_data_tmp.data(), bias_data, @@ -240,20 +225,12 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { nvinfer1::ILayer* fc_layer = nullptr; float dp_probs = 1.0 / 127.0; if (op_desc.HasAttr("Input_scale")) { - plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( - &weight, &bias, nvinfer1::DataType::kINT8, n); - std::vector plugin_inputs; - plugin_inputs.emplace_back(input); - fc_layer = engine_->network()->addPluginV2( - plugin_inputs.data(), plugin_inputs.size(), *plugin); + nvinfer1::DimsHW nv_ksize(1, 1); + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, + nv_ksize, weight, bias); } else { - plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( - &weight, &bias, - with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, n); - std::vector plugin_inputs; - plugin_inputs.emplace_back(input); - fc_layer = engine_->network()->addPluginV2( - plugin_inputs.data(), plugin_inputs.size(), *plugin); + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, + weight, bias); } if (op_desc.HasAttr("fc_out_threshold")) { @@ -392,7 +369,7 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( &weight, &bias, nvinfer1::DataType::kINT8, n); std::vector plugin_inputs; - plugin_inputs.emplace_back(reshape_before_fc_layer.getOutput(0)); + plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0)); fc_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin); } else { @@ -401,7 +378,7 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, n); std::vector plugin_inputs; - plugin_inputs.emplace_back(reshape_before_fc_layer.getOutput(0)); + plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0)); fc_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin); } @@ -453,4 +430,4 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul, MultiheadMatMulOpConverter); +REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul, SparseMultiheadMatMulOpConverter); From 523191806680424aa79031bcdc77fcb274e8b2ac Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Mon, 23 May 2022 14:45:47 +0000 Subject: [PATCH 05/11] style check --- .../framework/ir/graph_pattern_detector.cc | 45 +++++++++++-------- ...dense_multihead_matmul_with_sparse_pass.cc | 38 ++++++++++------ ..._dense_multihead_matmul_with_sparse_pass.h | 3 +- .../inference/api/paddle_pass_builder.cc | 42 ++++++++--------- .../convert/sparse_multihead_matmul_op.cc | 23 +++++----- paddle/fluid/inference/tensorrt/op_teller.cc | 5 ++- .../inference/tensorrt/plugin/spmm_plugin.cu | 14 +++--- 7 files changed, 98 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 493f39897a0dcf..eb15f80aeceb55 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3450,30 +3450,39 @@ PDNode *patterns::DenseFC::operator()() { } PDNode *patterns::MultiheadMatmul::operator()() { - auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr())->assert_is_op("multihead_matmul"); + auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr()) + ->assert_is_op("multihead_matmul"); // Input - auto *multihead_matmul_input = pattern->NewNode(multihead_matmul_input_repr()) - ->AsInput() - ->assert_is_op_input("multihead_matmul", "Input"); + auto *multihead_matmul_input = + pattern->NewNode(multihead_matmul_input_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "Input"); // Filter - auto *multihead_matmul_weights = pattern->NewNode(multihead_matmul_weights_repr()) - ->AsInput() - ->assert_is_op_input("multihead_matmul", "W"); + auto *multihead_matmul_weights = + pattern->NewNode(multihead_matmul_weights_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "W"); // Bias - auto *multihead_matmul_bias = pattern->NewNode(multihead_matmul_bias_repr()) - ->AsInput() - ->assert_is_op_input("multihead_matmul", "Bias"); + auto *multihead_matmul_bias = + pattern->NewNode(multihead_matmul_bias_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "Bias"); // BiasQK - auto *multihead_matmul_biasqk = pattern->NewNode(multihead_matmul_biasqk_repr()) - ->AsInput() - ->assert_is_op_input("multihead_matmul", "BiasQK"); + auto *multihead_matmul_biasqk = + pattern->NewNode(multihead_matmul_biasqk_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "BiasQK"); // Output - auto *multihead_matmul_out = pattern->NewNode(multihead_matmul_out_repr()) - ->AsOutput() - ->assert_is_op_output("multihead_matmul", "Out") - ->assert_is_only_output_of_op("multihead_matmul"); + auto *multihead_matmul_out = + pattern->NewNode(multihead_matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("multihead_matmul", "Out") + ->assert_is_only_output_of_op("multihead_matmul"); - multihead_matmul->LinksFrom({multihead_matmul_input, multihead_matmul_weights, multihead_matmul_bias, multihead_matmul_biasqk}).LinksTo({multihead_matmul_out}); + multihead_matmul + ->LinksFrom({multihead_matmul_input, multihead_matmul_weights, + multihead_matmul_bias, multihead_matmul_biasqk}) + .LinksTo({multihead_matmul_out}); return multihead_matmul_out; } diff --git a/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc index cb4cc14bfe1959..9fdb5e14da1533 100644 --- a/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc +++ b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.cc @@ -20,7 +20,8 @@ namespace paddle { namespace framework { namespace ir { -ReplaceDenseMultiheadMatmulWithSparsePass::ReplaceDenseMultiheadMatmulWithSparsePass() { +ReplaceDenseMultiheadMatmulWithSparsePass:: + ReplaceDenseMultiheadMatmulWithSparsePass() { AddOpCompat(OpCompat("multihead_matmul")) .AddInput("Input") .IsTensor() @@ -47,8 +48,8 @@ void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const { FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; - patterns::MultiheadMatmul multihead_matmul_pattern(gpd.mutable_pattern(), - "dense_multihead_matmul_replace_pass"); + patterns::MultiheadMatmul multihead_matmul_pattern( + gpd.mutable_pattern(), "dense_multihead_matmul_replace_pass"); multihead_matmul_pattern(); int found_multihead_matmul_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, @@ -60,12 +61,19 @@ void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const { return; }*/ - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights, multihead_matmul_weights, multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias, multihead_matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk, multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights, + multihead_matmul_weights, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk, + multihead_matmul_pattern); auto *multihead_matmul_op = multihead_matmul->Op(); auto w_name = multihead_matmul_op->Input("W")[0]; @@ -84,19 +92,23 @@ void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const { desc.SetAttr("alpha", multihead_matmul_op->GetAttr("alpha")); desc.SetAttr("head_number", multihead_matmul_op->GetAttr("head_number")); if (multihead_matmul_op->HasAttr("Input_scale")) { - desc.SetAttr("Input_scale", multihead_matmul_op->GetAttr("Input_scale")); + desc.SetAttr("Input_scale", + multihead_matmul_op->GetAttr("Input_scale")); } if (multihead_matmul_op->HasAttr("fc_out_threshold")) { - desc.SetAttr("fc_out_threshold", multihead_matmul_op->GetAttr("fc_out_threshold")); + desc.SetAttr("fc_out_threshold", + multihead_matmul_op->GetAttr("fc_out_threshold")); } if (multihead_matmul_op->HasAttr("qkv2context_plugin_int8")) { - desc.SetAttr("qkv2context_plugin_int8", multihead_matmul_op->GetAttr("qkv2context_plugin_int8")); + desc.SetAttr("qkv2context_plugin_int8", + multihead_matmul_op->GetAttr("qkv2context_plugin_int8")); } if (multihead_matmul_op->HasAttr("dp_probs")) { desc.SetAttr("dp_probs", multihead_matmul_op->GetAttr("dp_probs")); } if (multihead_matmul_op->HasAttr("out_threshold")) { - desc.SetAttr("out_threshold", multihead_matmul_op->GetAttr("out_threshold")); + desc.SetAttr("out_threshold", + multihead_matmul_op->GetAttr("out_threshold")); } desc.Flush(); GraphSafeRemoveNodes(g, {multihead_matmul}); diff --git a/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h index 4c43e9ba4efbb9..abfbfb6b7b13dc 100644 --- a/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h +++ b/paddle/fluid/framework/ir/replace_dense_multihead_matmul_with_sparse_pass.h @@ -37,7 +37,8 @@ class ReplaceDenseMultiheadMatmulWithSparsePass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const override; - const std::string name_scope_{"replace_dense_multihead_matmul_with_sparse_pass"}; + const std::string name_scope_{ + "replace_dense_multihead_matmul_with_sparse_pass"}; }; } // namespace ir diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 9d7ff1ff2048d9..18a4a0f5f56ae7 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -91,27 +91,27 @@ const std::vector kTRTSubgraphPasses({ "delete_quant_dequant_linear_op_pass", // "add_support_int8_pass", // // "fc_fuse_pass", // - "simplify_with_basic_ops_pass", // - "embedding_eltwise_layernorm_fuse_pass", // - "preln_embedding_eltwise_layernorm_fuse_pass", // - "multihead_matmul_fuse_pass_v2", // - "multihead_matmul_fuse_pass_v3", // - "skip_layernorm_fuse_pass", // - "preln_skip_layernorm_fuse_pass", // - "conv_bn_fuse_pass", // - "unsqueeze2_eltwise_fuse_pass", // - "trt_squeeze2_matmul_fuse_pass", // - "trt_reshape2_matmul_fuse_pass", // - "trt_flatten2_matmul_fuse_pass", // - "trt_map_matmul_v2_to_mul_pass", // - "trt_map_matmul_v2_to_matmul_pass", // - "trt_map_matmul_to_mul_pass", // - "fc_fuse_pass", // - "conv_elementwise_add_fuse_pass", // - "replace_dense_with_sparse_pass", // - "replace_dense_multihead_matmul_with_sparse_pass", // - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + "simplify_with_basic_ops_pass", // + "embedding_eltwise_layernorm_fuse_pass", // + "preln_embedding_eltwise_layernorm_fuse_pass", // + "multihead_matmul_fuse_pass_v2", // + "multihead_matmul_fuse_pass_v3", // + "skip_layernorm_fuse_pass", // + "preln_skip_layernorm_fuse_pass", // + "conv_bn_fuse_pass", // + "unsqueeze2_eltwise_fuse_pass", // + "trt_squeeze2_matmul_fuse_pass", // + "trt_reshape2_matmul_fuse_pass", // + "trt_flatten2_matmul_fuse_pass", // + "trt_map_matmul_v2_to_mul_pass", // + "trt_map_matmul_v2_to_matmul_pass", // + "trt_map_matmul_to_mul_pass", // + "fc_fuse_pass", // + "conv_elementwise_add_fuse_pass", // + "replace_dense_with_sparse_pass", // + "replace_dense_multihead_matmul_with_sparse_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc index 6933eecd5a12df..e1bf903c1f5d7e 100644 --- a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc @@ -35,7 +35,8 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a fluid sparse_multihead_matmul op to a corresponding tensorrt " + VLOG(3) << "convert a fluid sparse_multihead_matmul op to a corresponding " + "tensorrt " "network structure"; framework::OpDesc op_desc(op, nullptr); // Declare inputs @@ -335,9 +336,9 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { w_data = static_cast(weight_data); } - TensorRTEngine::Weight weight{with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, - static_cast(w_data), - static_cast(weight_t->numel())}; + TensorRTEngine::Weight weight{ + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + static_cast(w_data), static_cast(weight_t->numel())}; weight.dims.assign({n, m}); TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, @@ -360,22 +361,23 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { } reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); reshape_before_fc_layer->setName( - ("shuffle_before_sparse_multihead_mamul(Output: " + output_name + ")") + ("shuffle_before_sparse_multihead_mamul(Output: " + output_name + + ")") .c_str()); // add layer fc nvinfer1::ILayer* fc_layer = nullptr; if (op_desc.HasAttr("Input_scale")) { - plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( - &weight, &bias, nvinfer1::DataType::kINT8, n); + plugin::SpmmPluginDynamic* plugin = + new_spmm_plugin(&weight, &bias, nvinfer1::DataType::kINT8, n); std::vector plugin_inputs; plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0)); fc_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin); } else { plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( - &weight, &bias, - with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + &weight, &bias, with_fp16 ? nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT, n); std::vector plugin_inputs; plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0)); @@ -430,4 +432,5 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul, SparseMultiheadMatMulOpConverter); +REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul, + SparseMultiheadMatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index c60c18e1dbef06..5eaa2fd27ba8bc 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -49,7 +49,7 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("sparse_fc"); int8_teller_set.insert("sparse_fc"); teller_set.insert("sparse_multihead_matmul"); - int8_teller_set.insert("sparse_multihead_matmul"); + int8_teller_set.insert("sparse_multihead_matmul"); #endif } @@ -1742,7 +1742,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, #if IS_TRT_VERSION_GE(8000) if (op_type == "sparse_fc" || op_type == "sparse_multihead_matmul") { if (!with_dynamic_shape) { - VLOG(3) << "the sparse_fc and sparse_multihead_matmul does not support static shape yet"; + VLOG(3) << "the sparse_fc and sparse_multihead_matmul does not support " + "static shape yet"; return false; } } diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu index 389a0807a5d3da..bb3ba9b7fe93dc 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -377,7 +377,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, compressed_size); cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size, cudaMemcpyHostToDevice); - + has_bias_ = (bias != nullptr); if (has_bias_) { // Each plugin has a copy of bias @@ -552,7 +552,7 @@ void SpmmPluginDynamic::configurePlugin( platform::errors::InvalidArgument( "precision_ should be equal to inputs[0].desc.type")); const auto& inDims0 = inputs[0].desc.dims; - if (inDims0.nbDims==5) { + if (inDims0.nbDims == 5) { PADDLE_ENFORCE_EQ(inDims0.nbDims, 5, platform::errors::InvalidArgument( "inDims0.nbDims should be 5")); PADDLE_ENFORCE_EQ(k_, inDims0.d[2], @@ -565,7 +565,7 @@ void SpmmPluginDynamic::configurePlugin( const int BS = inputs->max.d[0]; const int Seq = inputs->max.d[1]; m_max_ = BS * Seq; - } else if (inDims0.nbDims==4) { + } else if (inDims0.nbDims == 4) { PADDLE_ENFORCE_EQ(inDims0.nbDims, 4, platform::errors::InvalidArgument( "inDims0.nbDims should be 4")); PADDLE_ENFORCE_EQ(k_, inDims0.d[1], @@ -576,7 +576,7 @@ void SpmmPluginDynamic::configurePlugin( PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument( "inDims0.d[3] should be 1")); const int BS_Seq = inputs->max.d[0]; - m_max_ = BS_Seq; + m_max_ = BS_Seq; } // The optimal algorighm id is for m = m_max_ // To Do: configurePlugin takes time when m is changed @@ -642,15 +642,15 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, PADDLE_ENFORCE_EQ(is_configured_, true, platform::errors::InvalidArgument( "The plugin is not configured before enqueue")); - if (inputDesc->dims.nbDims==5){ + if (inputDesc->dims.nbDims == 5) { PADDLE_ENFORCE_EQ( k_, inputDesc->dims.d[2], platform::errors::InvalidArgument("k_ == inputDesc->dims.d[2]")); - } else if (inputDesc->dims.nbDims==4) { + } else if (inputDesc->dims.nbDims == 4) { PADDLE_ENFORCE_EQ( k_, inputDesc->dims.d[1], platform::errors::InvalidArgument("k_ == inputDesc->dims.d[1]")); - } + } float alpha = 1.0f; float beta = 0.0f; if (inputDesc->type == nvinfer1::DataType::kFLOAT) { From 0af03939376936eb69f1eb9c42ac907f5403355b Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Thu, 26 May 2022 04:36:13 +0000 Subject: [PATCH 06/11] fix error --- paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc | 2 +- .../inference/tensorrt/convert/sparse_multihead_matmul_op.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index a8595d55b31b05..4a5947778056a1 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -864,7 +864,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, auto* mul0_op_desc = mul0->Op(); // all mul op has same input. - if (multihead_op_desc.HasAttr("Input_scale")) { + if (mul0_op_desc->HasAttr("Input_scale")) { multihead_op_desc.SetAttr("Input_scale", mul0_op_desc->GetAttr("Input_scale")); } diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc index e1bf903c1f5d7e..ad7f8e79a75146 100644 --- a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc @@ -423,7 +423,7 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { "You can use the config.SetTRTDynamicShapeInfo(...) interface to set " "the shape information to run the dynamic shape mode.")); } - RreplenishLayerAndOutput(layer, "sparse_multihead_matmul", {output_name}, + RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name}, test_mode); } }; From 3340d6bbf84ba5b370cf97960c0267bc174390f3 Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Thu, 26 May 2022 12:58:00 +0000 Subject: [PATCH 07/11] shared_ptr compilation passed --- .../inference/tensorrt/plugin/spmm_plugin.cu | 35 ++++++++++--------- .../inference/tensorrt/plugin/spmm_plugin.h | 6 +++- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu index bb3ba9b7fe93dc..240ddf407de244 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -63,6 +63,8 @@ inline void deserialize_value_size(void const** buffer, size_t* buffer_size, inline float round_scale(float x) { return std::floor(x + 0.5f); } +inline void cudaFreeFunc(void* p) { if(p) { cudaFree(p); } } + inline void convertAndCopy(const nvinfer1::Weights& src, nvinfer1::DataType type, void* dest) { PADDLE_ENFORCE_EQ(src.type == nvinfer1::DataType::kFLOAT || @@ -252,6 +254,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, weight_scale_(1.0f), weight_compressed_(nullptr), weight_compressed_dev_(nullptr), + weight_compressed_dev_global_(nullptr), compressed_size_(0), has_bias_(false), bias_(nullptr), @@ -310,12 +313,12 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, cudaMemcpy(weight_dev, weight_host.data(), precision_size_ * weight.count, cudaMemcpyHostToDevice); } - spmm_context_.compressMatB(out_dim_, k_, convertTrtType(precision_), weight_dev, &weight_compressed_dev_, &compressed_size_); weight_compressed_ = new char[compressed_size_]; - cudaMemcpy(weight_compressed_, weight_compressed_dev_, compressed_size_, + weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc); + cudaMemcpy(weight_compressed_, weight_compressed_dev_global_.get(), compressed_size_, cudaMemcpyDeviceToHost); has_bias_ = (bias.count != 0); @@ -352,7 +355,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, optim_alg_(optim_alg), weight_scale_(1.0f), weight_compressed_(nullptr), - weight_compressed_dev_(nullptr), + weight_compressed_dev_global_(nullptr), compressed_size_(compressed_size), has_bias_(false), bias_(nullptr), @@ -373,11 +376,6 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, std::copy_n(static_cast(weight_compressed), compressed_size, static_cast(weight_compressed_)); - cudaMalloc(reinterpret_cast(&weight_compressed_dev_), - compressed_size); - cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size, - cudaMemcpyHostToDevice); - has_bias_ = (bias != nullptr); if (has_bias_) { // Each plugin has a copy of bias @@ -403,7 +401,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, size_t length) : layer_name_(name), weight_compressed_(nullptr), - weight_compressed_dev_(nullptr), + weight_compressed_dev_global_(nullptr), bias_(nullptr), bias_dev_(nullptr) { DeserializeValue(&data, &length, &precision_); @@ -424,9 +422,10 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, "Deserialize data should be configured")); weight_compressed_ = new char[compressed_size_]; deserialize_value_size(&data, &length, weight_compressed_, compressed_size_); - cudaMalloc(reinterpret_cast(&weight_compressed_dev_), + //MEM: how to deal with deserialization? + cudaMalloc(reinterpret_cast(weight_compressed_dev_global_.get()), compressed_size_); - cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size_, + cudaMemcpy(weight_compressed_dev_global_.get(), weight_compressed_, compressed_size_, cudaMemcpyHostToDevice); if (has_bias_) { @@ -451,8 +450,8 @@ nvinfer1::IPluginV2DynamicExt* SpmmPluginDynamic::clone() const noexcept { weight_compressed_, compressed_size_, bias_, is_configured_, m_max_, optim_alg_, activation_); p->weight_scale_ = weight_scale_; + p->weight_compressed_dev_global_ = weight_compressed_dev_global_; p->setPluginNamespace(namespace_.c_str()); - return p; } catch (const std::exception& e) { std::cerr << e.what() << std::endl; @@ -614,7 +613,7 @@ void SpmmPluginDynamic::configurePlugin( spmm_context_.workspace_size); paddle::platform::dynload::cusparseLtMatmulSearch( &spmm_context_.handle, &spmm_context_.plan, &alpha, dA, - weight_compressed_dev_, &beta, dC, dC, d_workspace, nullptr, 0); + weight_compressed_dev_global_.get(), &beta, dC, dC, d_workspace, nullptr, 0); paddle::platform::dynload::cusparseLtMatmulAlgGetAttribute( &spmm_context_.handle, &spmm_context_.alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &optim_alg_, sizeof(optim_alg_)); @@ -658,14 +657,14 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, auto* output = static_cast(outputs[0]); cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, - weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1); + weight_compressed_dev_global_.get(), &beta, output, output, workSpace, &stream, 1); return status != CUSPARSE_STATUS_SUCCESS; } else if (inputDesc->type == nvinfer1::DataType::kHALF) { const auto* const input = static_cast(inputs[0]); auto* output = static_cast(outputs[0]); cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, - weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1); + weight_compressed_dev_global_.get(), &beta, output, output, workSpace, &stream, 1); return status != CUSPARSE_STATUS_SUCCESS; } else if (inputDesc->type == nvinfer1::DataType::kINT8) { alpha = inputDesc->scale * weight_scale_ / outputDesc->scale; @@ -673,7 +672,7 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, auto* output = static_cast(outputs[0]); cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, - weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1); + weight_compressed_dev_global_.get(), &beta, output, output, workSpace, &stream, 1); return status != CUSPARSE_STATUS_SUCCESS; } else { PADDLE_THROW(paddle::platform::errors::Fatal( @@ -749,7 +748,9 @@ void SpmmPluginDynamic::serialize(void* buffer) const noexcept { void SpmmPluginDynamic::destroy() noexcept { delete[] reinterpret_cast(weight_compressed_); - cudaFree(weight_compressed_dev_); + //MEM: + // cudaFree(weight_compressed_dev_); + weight_compressed_dev_global_.reset(); if (has_bias_) { cudaFree(bias_dev_); } diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h index a7edb8dedfa7fd..404fbff18b8c2e 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h @@ -39,6 +39,8 @@ #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/platform/dynload/cusparseLt.h" +using namespace std; + namespace paddle { namespace inference { namespace tensorrt { @@ -77,6 +79,7 @@ class SpmmPluginDynamic : public nvinfer1::IPluginV2DynamicExt { const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; @@ -128,7 +131,8 @@ class SpmmPluginDynamic : public nvinfer1::IPluginV2DynamicExt { int optim_alg_; // the index of optimal algorithm float weight_scale_; // record the weight scale from constructor void* weight_compressed_; // host compressed weight - void* weight_compressed_dev_; // device compressed weight + void* weight_compressed_dev_; // device compressed weight + shared_ptr weight_compressed_dev_global_; // shared pointer to the device compressed weight size_t compressed_size_; // size of compressed weight bool has_bias_; // there is bias or not void* bias_; // host bias From 9001a73b346fc3af2a214261c59286583223032c Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Fri, 27 May 2022 03:12:36 +0000 Subject: [PATCH 08/11] shared_ptr is nullptr in enqueue --- .../inference/tensorrt/plugin/spmm_plugin.cu | 52 +++++++++++++++++-- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu index 240ddf407de244..3aecd3795cea2a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -118,6 +118,7 @@ void SpmmPluginDynamic::cusparseLtContext::init( 4. Init algorithm selection descriptor (alg_sel) 5. Init plan descriptor (plan) */ + std::cout << "init context" << std::endl; PADDLE_ENFORCE_EQ( is_initialized, false, platform::errors::InvalidArgument( @@ -204,6 +205,7 @@ void SpmmPluginDynamic::cusparseLtContext::setAlgo(int alg) { } void SpmmPluginDynamic::cusparseLtContext::destroy() { + std::cout << "destroy context" << std::endl; PADDLE_ENFORCE_EQ(is_initialized, true, platform::errors::InvalidArgument( "cusparseLtContext is destroy before init")); @@ -217,6 +219,7 @@ void SpmmPluginDynamic::cusparseLtContext::destroy() { void SpmmPluginDynamic::cusparseLtContext::compressMatB( int n, int k, cudaDataType_t type, void* src, void** dest, size_t* compressed_size) { + std::cout << "compress matB" << std::endl; PADDLE_ENFORCE_EQ( is_initialized, false, platform::errors::InvalidArgument( @@ -268,6 +271,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, 5. Copy the compressed weight to host 6. Convert bias precision and copy (on host) */ + std::cout << "new plugin" << std::endl; precision_size_ = getElementSize(precision); element_size_ = (precision_ == nvinfer1::DataType::kINT8 ? 4 : precision_size_); @@ -318,8 +322,14 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, &compressed_size_); weight_compressed_ = new char[compressed_size_]; weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc); + std::cout << "initial count: " << weight_compressed_dev_global_.use_count() << std::endl; cudaMemcpy(weight_compressed_, weight_compressed_dev_global_.get(), compressed_size_, cudaMemcpyDeviceToHost); + std::cout << "compressed weight:"; + for(int i=0; i<10; i++) { + std::cout << " " << static_cast(reinterpret_cast(weight_compressed_)[i]); + } + std::cout << std::endl; has_bias_ = (bias.count != 0); if (has_bias_) { @@ -368,6 +378,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, 4. (Configured) Copy the bias to device 5. (Configured) Init cuSPARSELt descriptors */ + std::cout << "clone plugin" << std::endl; precision_size_ = getElementSize(precision); element_size_ = (precision_ == nvinfer1::DataType::kINT8 ? 4 : precision_size_); @@ -404,6 +415,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, weight_compressed_dev_global_(nullptr), bias_(nullptr), bias_dev_(nullptr) { + std::cout << "deserialization" << std::endl; DeserializeValue(&data, &length, &precision_); DeserializeValue(&data, &length, &precision_size_); DeserializeValue(&data, &length, &element_size_); @@ -423,11 +435,18 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, weight_compressed_ = new char[compressed_size_]; deserialize_value_size(&data, &length, weight_compressed_, compressed_size_); //MEM: how to deal with deserialization? - cudaMalloc(reinterpret_cast(weight_compressed_dev_global_.get()), + auto* p_tmp = weight_compressed_dev_global_.get(); + cudaMalloc(reinterpret_cast(&p_tmp), compressed_size_); cudaMemcpy(weight_compressed_dev_global_.get(), weight_compressed_, compressed_size_, cudaMemcpyHostToDevice); + std::cout << "compressed weight:"; + for(int i=0; i<10; i++) { + std::cout << " " << static_cast(reinterpret_cast(weight_compressed_)[i]); + } + std::cout << std::endl; + if (has_bias_) { bias_ = new float[out_dim_]; deserialize_value_size(&data, &length, bias_, sizeof(float) * out_dim_); @@ -540,6 +559,7 @@ void SpmmPluginDynamic::configurePlugin( 2. Copy the bias to device 3. Search the optimal algorithm */ + std::cout << "configure plugin" << std::endl; try { PADDLE_ENFORCE_EQ(nbInputs, 1, platform::errors::InvalidArgument( @@ -638,6 +658,7 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const void* const* inputs, void* const* outputs, void* workSpace, cudaStream_t stream) noexcept { try { + std::cout << "enqueue" << std::endl; PADDLE_ENFORCE_EQ(is_configured_, true, platform::errors::InvalidArgument( "The plugin is not configured before enqueue")); @@ -655,16 +676,34 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, if (inputDesc->type == nvinfer1::DataType::kFLOAT) { const auto* const input = static_cast(inputs[0]); auto* output = static_cast(outputs[0]); + auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); + char* test_weight = new char[compressed_size_]; + cudaMemcpy(weight_compressed_dev_global_.get(), test_weight, compressed_size_, + cudaMemcpyHostToDevice); + std::cout << "compressed weight:"; + for(int i=0; i<10; i++) { + std::cout << " " << static_cast(reinterpret_cast(weight_compressed_)[i]); + } + std::cout << std::endl; + + std::cout << "weight from shared ptr:"; + for(int i=0; i<10; i++) { + std::cout << " " << static_cast(reinterpret_cast(test_weight)[i]); + } + std::cout << std::endl; + + cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, - weight_compressed_dev_global_.get(), &beta, output, output, workSpace, &stream, 1); + weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, 1); return status != CUSPARSE_STATUS_SUCCESS; } else if (inputDesc->type == nvinfer1::DataType::kHALF) { const auto* const input = static_cast(inputs[0]); auto* output = static_cast(outputs[0]); + auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, - weight_compressed_dev_global_.get(), &beta, output, output, workSpace, &stream, 1); + weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, 1); return status != CUSPARSE_STATUS_SUCCESS; } else if (inputDesc->type == nvinfer1::DataType::kINT8) { alpha = inputDesc->scale * weight_scale_ / outputDesc->scale; @@ -736,7 +775,7 @@ void SpmmPluginDynamic::serialize(void* buffer) const noexcept { SerializeValue(&buffer, compressed_size_); SerializeValue(&buffer, has_bias_); SerializeValue(&buffer, activation_); - + std::cout << "serialize" << std::endl; char* d = static_cast(buffer); std::copy_n(static_cast(weight_compressed_), compressed_size_, d); @@ -747,10 +786,13 @@ void SpmmPluginDynamic::serialize(void* buffer) const noexcept { } void SpmmPluginDynamic::destroy() noexcept { + std::cout << "destroy plugin" << std::endl; delete[] reinterpret_cast(weight_compressed_); //MEM: // cudaFree(weight_compressed_dev_); - weight_compressed_dev_global_.reset(); + // std::cout << "current use cout before this destroy: " << weight_compressed_dev_global_.use_count() << std::endl; + // weight_compressed_dev_global_.reset(); + std::cout << "current use cout after this destroy: " << weight_compressed_dev_global_.use_count() << std::endl; if (has_bias_) { cudaFree(bias_dev_); } From 291bc0b7da5ad3d1638224e7f66e85d8d0e59327 Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Fri, 27 May 2022 03:19:27 +0000 Subject: [PATCH 09/11] shared_ptr is nullptr in enqueue --- paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu index 3aecd3795cea2a..913d2cf24a9add 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -678,8 +678,8 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, auto* output = static_cast(outputs[0]); auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); char* test_weight = new char[compressed_size_]; - cudaMemcpy(weight_compressed_dev_global_.get(), test_weight, compressed_size_, - cudaMemcpyHostToDevice); + cudaMemcpy(test_weight, weight_compressed_dev_global_.get(), compressed_size_, + cudaMemcpyDeviceToHost); std::cout << "compressed weight:"; for(int i=0; i<10; i++) { std::cout << " " << static_cast(reinterpret_cast(weight_compressed_)[i]); @@ -691,8 +691,6 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, std::cout << " " << static_cast(reinterpret_cast(test_weight)[i]); } std::cout << std::endl; - - cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, 1); From 865b673b66833bc6f6ec766ab5339e396828ea22 Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Fri, 27 May 2022 03:52:02 +0000 Subject: [PATCH 10/11] UT passed --- .../inference/tensorrt/plugin/spmm_plugin.cu | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu index 913d2cf24a9add..28c9415d8f2a33 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -412,6 +412,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, size_t length) : layer_name_(name), weight_compressed_(nullptr), + weight_compressed_dev_(nullptr), weight_compressed_dev_global_(nullptr), bias_(nullptr), bias_dev_(nullptr) { @@ -435,18 +436,26 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, weight_compressed_ = new char[compressed_size_]; deserialize_value_size(&data, &length, weight_compressed_, compressed_size_); //MEM: how to deal with deserialization? - auto* p_tmp = weight_compressed_dev_global_.get(); - cudaMalloc(reinterpret_cast(&p_tmp), - compressed_size_); - cudaMemcpy(weight_compressed_dev_global_.get(), weight_compressed_, compressed_size_, - cudaMemcpyHostToDevice); + cudaMalloc(reinterpret_cast(&weight_compressed_dev_), compressed_size_); + cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size_, cudaMemcpyHostToDevice); + weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc); - std::cout << "compressed weight:"; + char* test_weight = new char[compressed_size_]; + cudaMemcpy(test_weight, weight_compressed_dev_global_.get(), compressed_size_, + cudaMemcpyDeviceToHost); + std::cout << "compressed weight in deserial:"; for(int i=0; i<10; i++) { std::cout << " " << static_cast(reinterpret_cast(weight_compressed_)[i]); } std::cout << std::endl; + std::cout << "weight from shared ptr in deserial:"; + for(int i=0; i<10; i++) { + std::cout << " " << static_cast(reinterpret_cast(test_weight)[i]); + } + std::cout << std::endl; + + if (has_bias_) { bias_ = new float[out_dim_]; deserialize_value_size(&data, &length, bias_, sizeof(float) * out_dim_); From e9ffce211c80a3f87fa764554eaf89bbde013643 Mon Sep 17 00:00:00 2001 From: minghaoBD Date: Mon, 30 May 2022 04:01:13 +0000 Subject: [PATCH 11/11] nan fp16 output --- .../inference/tensorrt/plugin/spmm_plugin.cu | 42 ++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu index 28c9415d8f2a33..7116643060e72a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -327,7 +327,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, cudaMemcpyDeviceToHost); std::cout << "compressed weight:"; for(int i=0; i<10; i++) { - std::cout << " " << static_cast(reinterpret_cast(weight_compressed_)[i]); + std::cout << " " << static_cast(reinterpret_cast<__half*>(weight_compressed_)[i]); } std::cout << std::endl; @@ -342,7 +342,12 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, "SpmmPluginDynamic only supports FLOAT bias")); bias_ = new float[out_dim_]; - convertAndCopy(bias, nvinfer1::DataType::kFLOAT, bias_); + std::cout << "overwriting bias!!!!!!!!" << std::endl; + for (int i=0; i(bias_)[i] = 0.0; + } + // std::copy_n(static_cast(bias.values), bias.count, static_cast(bias_)); + // convertAndCopy(bias, nvinfer1::DataType::kFLOAT, bias_); } cudaFree(weight_dev); @@ -445,13 +450,13 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, cudaMemcpyDeviceToHost); std::cout << "compressed weight in deserial:"; for(int i=0; i<10; i++) { - std::cout << " " << static_cast(reinterpret_cast(weight_compressed_)[i]); + std::cout << " " << static_cast(reinterpret_cast<__half*>(weight_compressed_)[i]); } std::cout << std::endl; std::cout << "weight from shared ptr in deserial:"; for(int i=0; i<10; i++) { - std::cout << " " << static_cast(reinterpret_cast(test_weight)[i]); + std::cout << " " << static_cast(reinterpret_cast<__half*>(test_weight)[i]); } std::cout << std::endl; @@ -555,7 +560,7 @@ bool SpmmPluginDynamic::supportsFormatCombination( (in.format == nvinfer1::TensorFormat::kLINEAR); } const nvinfer1::PluginTensorDesc& prev = inOut[pos - 1]; - + return in.type == prev.type && in.format == prev.format; } @@ -689,6 +694,7 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, char* test_weight = new char[compressed_size_]; cudaMemcpy(test_weight, weight_compressed_dev_global_.get(), compressed_size_, cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); std::cout << "compressed weight:"; for(int i=0; i<10; i++) { std::cout << " " << static_cast(reinterpret_cast(weight_compressed_)[i]); @@ -708,9 +714,35 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const auto* const input = static_cast(inputs[0]); auto* output = static_cast(outputs[0]); auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); + char* output_host = new char[512]; + cudaDeviceSynchronize(); cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, 1); + cudaDeviceSynchronize(); + cudaMemcpy(output_host, output, 512, + cudaMemcpyDeviceToHost); + std::cout << "output:"; + for (int i=0; i<20; i++) { + std::cout << " " << static_cast(reinterpret_cast<__half*>(output_host)[i]); + } + std::cout << std::endl; + + char* test_weight = new char[compressed_size_]; + cudaMemcpy(test_weight, weight_compressed_dev_global_.get(), compressed_size_, + cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + std::cout << "compressed weight:"; + for(int i=0; i<10; i++) { + std::cout << " " << static_cast(reinterpret_cast<__half*>(weight_compressed_)[i]); + } + std::cout << std::endl; + + std::cout << "weight from shared ptr:"; + for(int i=0; i<10; i++) { + std::cout << " " << static_cast(reinterpret_cast<__half*>(test_weight)[i]); + } + std::cout << std::endl; return status != CUSPARSE_STATUS_SUCCESS; } else if (inputDesc->type == nvinfer1::DataType::kINT8) { alpha = inputDesc->scale * weight_scale_ / outputDesc->scale;