diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
index a05a096daf928c..d5acfcc0ec7757 100644
--- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -96,7 +96,8 @@ inline bool VarNodeHasDtype(Node* var_node) {
   auto type = var_node->Var()->GetType();
   return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
          (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
-         (type == VarType::VOCAB);
+         (type == VarType::VOCAB) || (type == VarType::SPARSE_COO) ||
+         (type == VarType::SPARSE_CSR);
 }
 
 inline bool IsFP32(VarType::Type type) { return type == VarType::FP32; }
@@ -123,12 +124,21 @@ void DoInsertCastOp(Graph* graph,
                               const std::string& x_name,
                               const std::string& out_name,
                               const int in_dtype,
-                              const int out_dtype) {
-    desc.SetType("cast");
-    desc.SetInput("X", {x_name});
-    desc.SetOutput("Out", {out_name});
-    desc.SetAttr("in_dtype", in_dtype);
-    desc.SetAttr("out_dtype", out_dtype);
+                              const int out_dtype,
+                              const VarType::Type t) {
+    if (t == VarType::SPARSE_COO || t == VarType::SPARSE_CSR) {
+      desc.SetType("sparse_cast");
+      desc.SetInput("x", {x_name});
+      desc.SetOutput("out", {out_name});
+      desc.SetAttr("index_dtype", -1);
+      desc.SetAttr("value_dtype", out_dtype);
+    } else {
+      desc.SetType("cast");
+      desc.SetInput("X", {x_name});
+      desc.SetOutput("Out", {out_name});
+      desc.SetAttr("in_dtype", in_dtype);
+      desc.SetAttr("out_dtype", out_dtype);
+    }
     desc.SetAttr("use_mkldnn", false);
     desc.SetAttr("with_quant_attr", false);
     desc.Flush();
@@ -140,17 +150,21 @@ void DoInsertCastOp(Graph* graph,
       std::string cast_output_name = var_node->Var()->Name() +
                                      "_cast_auto_mixed.tmp_" +
                                      std::to_string((*suffix)++);
+      VarType::Type var_type = var_node->Var()->GetType();
       framework::OpDesc cast_op_desc(block_desc);
       update_cast_desc(cast_op_desc,
                        cast_input_name,
                        cast_output_name,
                        static_cast<int>(from_type),
-                       static_cast<int>(to_type));
+                       static_cast<int>(to_type),
+                       var_type);
       auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
       auto* cast_output_vardesc = block_desc->Var(cast_output_name);
+      cast_output_vardesc->SetType(var_type);
       cast_output_vardesc->SetPersistable(false);
       cast_output_vardesc->SetDataType(to_type);
       cast_output_vardesc->SetShape(var_node->Var()->GetShape());
+      cast_output_vardesc->Flush();
       auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
       IR_NODE_LINK_TO(cast_op_node, cast_output_node);
       (*cache)[var_node] = cast_output_node;
@@ -452,8 +466,8 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
       }
     }
 
-    // if op's input var and output var is not dense tensor, the op should
-    // not run at low precision.
+    // If any input or output var of the op is not a dense, sparse_coo, or
+    // sparse_csr tensor, the op should not run at low precision.
     for (auto* in_var_node : op_node->inputs) {
       CHECK_EQ(in_var_node->IsVar(), true);
       auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
@@ -461,7 +475,9 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
 
       support_low_precision =
           support_low_precision &&
-          (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+          (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR ||
+           real_in_var_node->Var()->GetType() == VarType::SPARSE_COO ||
+           real_in_var_node->Var()->GetType() == VarType::SPARSE_CSR);
     }
     for (auto* out_var_node : op_node->outputs) {
       CHECK_EQ(out_var_node->IsVar(), true);
@@ -470,7 +486,9 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
 
       support_low_precision =
           support_low_precision &&
-          (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+          (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR ||
+           real_out_var_node->Var()->GetType() == VarType::SPARSE_COO ||
+           real_out_var_node->Var()->GetType() == VarType::SPARSE_CSR);
     }
   }
 
@@ -634,6 +652,23 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
     if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
       return true;
     }
+  } else if (GetOpOriginalType(op_desc->Type()) == "sparse_batch_norm") {
+    auto vecs = op_desc->Input("bias");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("mean");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("scale");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("variance");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
   } else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
     auto vecs = op_desc->Input("Bias");
     if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
       return true;
     }
@@ -728,6 +763,27 @@ bool AutoMixedPrecisionPass::OutputVarsNotConvert(
     if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
       return true;
     }
+  } else if (GetOpOriginalType(op_desc->Type()) == "sparse_batch_norm") {
+    auto vecs = op_desc->Output("mean_out");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Output("variance_out");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Output("saved_mean");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Output("saved_variance");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Output("reserve_space");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
   }
 
   if (backend_ == phi::Backend::XPU) {
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 55fc19ad2be1c3..84b39dce2f7ab9 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -96,6 +96,12 @@ static DDim GetDimsDebug(const Scope& scope,
     }
   } else if (var->IsType<Strings>()) {
     return DDim({static_cast<int64_t>(var->Get<Strings>().size())});
+  } else if (var->IsType<phi::SparseCooTensor>()) {
+    const phi::SparseCooTensor& tensor = var->Get<phi::SparseCooTensor>();
+    return tensor.dims();
+  } else if (var->IsType<phi::SparseCsrTensor>()) {
+    const phi::SparseCsrTensor& tensor = var->Get<phi::SparseCsrTensor>();
+    return tensor.dims();
   } else {
     return DDim({-1});
   }
@@ -128,6 +134,18 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
     }
   } else if (var->IsType<Strings>()) {
     return "strings";
+  } else if (var->IsType<phi::SparseCooTensor>()) {
+    const phi::SparseCooTensor& tensor = var->Get<phi::SparseCooTensor>();
+    if (UNLIKELY(!tensor.initialized())) {
+      return "";
+    }
+    return DataTypeToString(framework::TransToProtoVarType(tensor.dtype()));
+  } else if (var->IsType<phi::SparseCsrTensor>()) {
+    const phi::SparseCsrTensor& tensor = var->Get<phi::SparseCsrTensor>();
+    if (UNLIKELY(!tensor.initialized())) {
+      return "";
+    }
+    return DataTypeToString(framework::TransToProtoVarType(tensor.dtype()));
   } else {
     return "";
   }
diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml
index fdebffcc4f06c7..56e952623a1500 100644
--- a/paddle/phi/api/yaml/sparse_ops.yaml
+++ b/paddle/phi/api/yaml/sparse_ops.yaml
@@ -102,8 +102,7 @@
   args : (Tensor x, DataType index_dtype=DataType::UNDEFINED, DataType value_dtype=DataType::UNDEFINED)
   output : Tensor(out)
   infer_meta :
-    func : CastInferMeta
-    param: [x, value_dtype]
+    func : sparse::CastInferMeta
   kernel :
     func : cast_coo{sparse_coo -> sparse_coo},
            cast_csr{sparse_csr -> sparse_csr}
diff --git a/paddle/phi/core/tensor_meta.h b/paddle/phi/core/tensor_meta.h
index 4c7c9ace49d321..f493e0249d7bf4 100644
--- a/paddle/phi/core/tensor_meta.h
+++ b/paddle/phi/core/tensor_meta.h
@@ -121,7 +121,7 @@ struct SparseTensorMeta {
   bool valid() const noexcept;
 
   DDim dims;
-  DataType dtype;
+  DataType dtype{DataType::UNDEFINED};
   DataLayout layout{DataLayout::NCHW};
 };
 
diff --git a/paddle/phi/infermeta/sparse/unary.cc b/paddle/phi/infermeta/sparse/unary.cc
index f80f18bbba857a..01da3ae08eb742 100644
--- a/paddle/phi/infermeta/sparse/unary.cc
+++ b/paddle/phi/infermeta/sparse/unary.cc
@@ -36,5 +36,21 @@ void ValuesInferMeta(const MetaTensor& x, MetaTensor* out) {
   out->set_layout(x.layout());
 }
 
+void CastInferMeta(const MetaTensor& x,
+                   DataType index_dtype,
+                   DataType out_dtype,
+                   MetaTensor* out) {
+  out->set_dims(x.dims());
+  out->set_layout(x.layout());
+  out->share_lod(x);
+  // In the inplace case, setting the dtype of out would also reset the
+  // dtype of x, which causes bugs, so setting the dtype of out is deferred
+  // to the kernel.
+
+  if (!(out->is_same_tensor(x))) {
+    out->set_dtype(out_dtype);
+  }
+}
+
 }  // namespace sparse
 }  // namespace phi
diff --git a/paddle/phi/infermeta/sparse/unary.h b/paddle/phi/infermeta/sparse/unary.h
index 880e90b7ae697f..5ee7f054143c08 100644
--- a/paddle/phi/infermeta/sparse/unary.h
+++ b/paddle/phi/infermeta/sparse/unary.h
@@ -24,5 +24,10 @@ void IndicesInferMeta(const MetaTensor& x, MetaTensor* out);
 
 void ValuesInferMeta(const MetaTensor& x, MetaTensor* out);
 
+void CastInferMeta(const MetaTensor& x,
+                   DataType index_dtype,
+                   DataType out_dtype,
+                   MetaTensor* out);
+
 }  // namespace sparse
 }  // namespace phi
diff --git a/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu b/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu
index 472777d7f35151..7ae8814470f41c 100644
--- a/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu
@@ -132,7 +132,8 @@ PD_REGISTER_KERNEL(addmm_coo_dense,
                    ALL_LAYOUT,
                    phi::sparse::AddmmCooDenseKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
@@ -141,6 +142,7 @@ PD_REGISTER_KERNEL(addmm_csr_dense,
                    ALL_LAYOUT,
                    phi::sparse::AddmmCsrDenseKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
 }
diff --git a/test/ir/inference/test_auto_mixed_precision_pass_for_sparse_op.py b/test/ir/inference/test_auto_mixed_precision_pass_for_sparse_op.py
new file mode 100644
index 00000000000000..adb128c9863328
--- /dev/null
+++ b/test/ir/inference/test_auto_mixed_precision_pass_for_sparse_op.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from inference_pass_test import InferencePassTest
+
+import paddle
+from paddle.inference import Config, PrecisionType, create_predictor
+
+
+class TestNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.sp_conv = paddle.sparse.nn.SubmConv2D(
+            3,
+            3,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False,
+            key=None,
+        )
+        self.sp_bn = paddle.sparse.nn.BatchNorm(
+            3, epsilon=1e-3, momentum=1 - 0.01, data_format='NHWC'
+        )
+        self.relu = paddle.sparse.nn.ReLU()
+
+    def forward(self, indices, values):
+        x = paddle.sparse.sparse_coo_tensor(
+            indices=indices,
+            values=values,
+            shape=[1, 32, 32, 3],
+            dtype='float32',
+        )
+        x = self.sp_conv(x)
+        x = self.sp_bn(x)
+        x = self.relu(x)
+        return x.to_dense()
+
+
+class AutoMixedPrecisionPassForSparseOp(InferencePassTest):
+    def setUp(self):
+        paddle.disable_static()
+        self.test_model = TestNet()
+        self.values = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]).astype(
+            'float32'
+        )
+        self.indices = np.array([[0, 0, 0], [0, 16, 16], [0, 20, 8]]).astype(
+            "int32"
+        )
+        self.path_prefix = (
+            "inference_test_models/auto_mixed_precision_pass_for_sparse_op_test"
+        )
+        paddle.jit.save(
+            self.test_model,
+            self.path_prefix,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[3, -1], dtype='int32', name="indices"
+                ),
+                paddle.static.InputSpec(
+                    shape=[-1, 3], dtype='float32', name="values"
+                ),
+            ],
+        )
+
+    def test_check_output(self):
+        fp32_out = self.inference("fp32")
+        fp16_out = self.inference("fp16")
+        np.testing.assert_allclose(fp32_out, fp16_out, rtol=1e-5, atol=1e-2)
+
+    def inference(self, precision):
+        # Config
+        config = Config(
+            self.path_prefix + ".pdmodel", self.path_prefix + ".pdiparams"
+        )
+        if precision == "fp16":
+            config.enable_use_gpu(100, 0, PrecisionType.Half)
+            white_list = ["sparse_batch_norm", "sparse_relu"]
+            config.exp_enable_mixed_precision_ops(set(white_list))
+        else:
+            config.enable_use_gpu(100, 0, PrecisionType.Float32)
+
+        # predictor
+        predictor = create_predictor(config)
+
+        # inference
+        indices_tensor = predictor.get_input_handle("indices")
+        indices_tensor.reshape(self.indices.shape)
+        indices_tensor.copy_from_cpu(self.indices.copy())
+        values_tensor = predictor.get_input_handle("values")
+        values_tensor.reshape(self.values.shape)
+        values_tensor.copy_from_cpu(self.values.copy())
+        predictor.run()
+        output_tensor = predictor.get_output_handle(
+            predictor.get_output_names()[0]
+        )
+        out = output_tensor.copy_to_cpu()
+        out = np.array(out).flatten()
+        return out
+
+
+if __name__ == "__main__":
+    unittest.main()
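
Note: the pass above inserts a sparse_cast op with index_dtype set to -1 (i.e.
UNDEFINED), so only the values of the sparse tensor are cast while the indices
keep their integer dtype. A minimal dynamic-graph sketch of that behavior,
assuming the public paddle.sparse.cast API (the tensor shape and values here
are illustrative, not taken from the patch):

    import paddle

    # Illustrative 2x3 COO tensor with two float32 non-zeros at (0, 1), (1, 2).
    indices = [[0, 1], [1, 2]]
    values = [1.0, 2.0]
    x = paddle.sparse.sparse_coo_tensor(indices, values, shape=[2, 3])

    # Cast only the values to float16; leaving index_dtype unset mirrors the
    # index_dtype = UNDEFINED attribute set by the pass, so the indices keep
    # their original integer dtype.
    y = paddle.sparse.cast(x, value_dtype=paddle.float16)
    print(y.dtype)  # paddle.float16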