Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 68 additions & 12 deletions paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ inline bool VarNodeHasDtype(Node* var_node) {
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
(type == VarType::VOCAB) || (type == VarType::SPARSE_COO) ||
(type == VarType::SPARSE_CSR);
}

inline bool IsFP32(VarType::Type type) { return type == VarType::FP32; }
Expand All @@ -123,12 +124,21 @@ void DoInsertCastOp(Graph* graph,
const std::string& x_name,
const std::string& out_name,
const int in_dtype,
const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
const int out_dtype,
const VarType::Type t) {
if (t == VarType::SPARSE_COO || t == VarType::SPARSE_CSR) {
desc.SetType("sparse_cast");
desc.SetInput("x", {x_name});
desc.SetOutput("out", {out_name});
desc.SetAttr("index_dtype", -1);
desc.SetAttr("value_dtype", to_type);
} else {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
}
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
Expand All @@ -140,17 +150,21 @@ void DoInsertCastOp(Graph* graph,
std::string cast_output_name = var_node->Var()->Name() +
"_cast_auto_mixed.tmp_" +
std::to_string((*suffix)++);
VarType::Type var_type = var_node->Var()->GetType();
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
static_cast<int>(to_type),
var_type);
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetType(var_type);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(var_node->Var()->GetShape());
cast_output_vardesc->Flush();
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*cache)[var_node] = cast_output_node;
Expand Down Expand Up @@ -452,16 +466,18 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
}
}

// if op's input var and output var is not dense tensor, the op should
// not run at low precision.
// op's input var and output var only support
// dense/sparse_coo/sparse_csr tensor.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
if (real_in_var_node->Var()->Persistable()) continue;

support_low_precision =
support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR ||
real_in_var_node->Var()->GetType() == VarType::SPARSE_COO ||
real_in_var_node->Var()->GetType() == VarType::SPARSE_CSR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
Expand All @@ -470,7 +486,9 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {

support_low_precision =
support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR ||
real_out_var_node->Var()->GetType() == VarType::SPARSE_COO ||
real_out_var_node->Var()->GetType() == VarType::SPARSE_CSR);
}
}

Expand Down Expand Up @@ -634,6 +652,23 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "sparse_batch_norm") {
auto vecs = op_desc->Input("bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("mean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("variance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
Expand Down Expand Up @@ -728,6 +763,27 @@ bool AutoMixedPrecisionPass::OutputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "sparse_batch_norm") {
auto vecs = op_desc->Output("mean_out");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("variance_out");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("saved_mean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("saved_variance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("reserve_space");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}

if (backend_ == phi::Backend::XPU) {
Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ static DDim GetDimsDebug(const Scope& scope,
}
} else if (var->IsType<Strings>()) {
return DDim({static_cast<int64_t>(var->Get<Strings>().size())});
} else if (var->IsType<phi::SparseCooTensor>()) {
const phi::SparseCooTensor& tensor = var->Get<phi::SparseCooTensor>();
return tensor.dims();
} else if (var->IsType<phi::SparseCsrTensor>()) {
const phi::SparseCsrTensor& tensor = var->Get<phi::SparseCsrTensor>();
return tensor.dims();
} else {
return DDim({-1});
}
Expand Down Expand Up @@ -128,6 +134,18 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
}
} else if (var->IsType<Strings>()) {
return "strings";
} else if (var->IsType<phi::SparseCooTensor>()) {
const phi::SparseCooTensor& tensor = var->Get<phi::SparseCooTensor>();
if (UNLIKELY(!tensor.initialized())) {
return "";
}
return DataTypeToString(framework::TransToProtoVarType(tensor.dtype()));
} else if (var->IsType<phi::SparseCsrTensor>()) {
const phi::SparseCsrTensor& tensor = var->Get<phi::SparseCsrTensor>();
if (UNLIKELY(!tensor.initialized())) {
return "";
}
return DataTypeToString(framework::TransToProtoVarType(tensor.dtype()));
} else {
return "";
}
Expand Down
3 changes: 1 addition & 2 deletions paddle/phi/api/yaml/sparse_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@
args : (Tensor x, DataType index_dtype=DataType::UNDEFINED, DataType value_dtype=DataType::UNDEFINED)
output : Tensor(out)
infer_meta :
func : CastInferMeta
param: [x, value_dtype]
func : sparse::CastInferMeta
kernel :
func : cast_coo{sparse_coo -> sparse_coo},
cast_csr{sparse_csr -> sparse_csr}
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/core/tensor_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct SparseTensorMeta {
bool valid() const noexcept;

DDim dims;
DataType dtype;
DataType dtype{DataType::UNDEFINED};
DataLayout layout{DataLayout::NCHW};
};

Expand Down
16 changes: 16 additions & 0 deletions paddle/phi/infermeta/sparse/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,21 @@ void ValuesInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_layout(x.layout());
}

void CastInferMeta(const MetaTensor& x,
DataType index_dtype,
DataType out_dtype,
MetaTensor* out) {
out->set_dims(x.dims());
out->set_layout(x.layout());
out->share_lod(x);
// In inplace case, setting the dtype of out will reset the dtype of x at the
// same time, which will cause bugs, so move the dtype setting of out to the
// kernel

if (!(out->is_same_tensor(x))) {
out->set_dtype(out_dtype);
}
}

} // namespace sparse
} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/sparse/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,10 @@ void IndicesInferMeta(const MetaTensor& x, MetaTensor* out);

void ValuesInferMeta(const MetaTensor& x, MetaTensor* out);

void CastInferMeta(const MetaTensor& x,
DataType index_dtype,
DataType out_dtype,
MetaTensor* out);

} // namespace sparse
} // namespace phi
6 changes: 4 additions & 2 deletions paddle/phi/kernels/sparse/gpu/addmm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ PD_REGISTER_KERNEL(addmm_coo_dense,
ALL_LAYOUT,
phi::sparse::AddmmCooDenseKernel,
float,
double) {
double,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

Expand All @@ -141,6 +142,7 @@ PD_REGISTER_KERNEL(addmm_csr_dense,
ALL_LAYOUT,
phi::sparse::AddmmCsrDenseKernel,
float,
double) {
double,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
117 changes: 117 additions & 0 deletions test/ir/inference/test_auto_mixed_precision_pass_for_sparse_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from inference_pass_test import InferencePassTest

import paddle
from paddle.inference import Config, PrecisionType, create_predictor


class TestNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.sp_conv = paddle.sparse.nn.SubmConv2D(
3,
3,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False,
key=None,
)
self.sp_bn = paddle.sparse.nn.BatchNorm(
3, epsilon=1e-3, momentum=1 - 0.01, data_format='NHWC'
)
self.relu = paddle.sparse.nn.ReLU()

def forward(self, indices, values):
x = paddle.sparse.sparse_coo_tensor(
indices=indices,
values=values,
shape=[1, 32, 32, 3],
dtype='float32',
)
x = self.sp_conv(x)
x = self.sp_bn(x)
x = self.relu(x)
return x.to_dense()


class AutoMixedPrecisionPassForSparseOp(InferencePassTest):
def setUp(self):
paddle.disable_static()
self.test_model = TestNet()
self.values = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]).astype(
'float32'
)
self.indices = np.array([[0, 0, 0], [0, 16, 16], [0, 20, 8]]).astype(
"int32"
)
self.path_prefix = (
"inference_test_models/auto_mixed_precision_pass_for_sparse_op_test"
)
paddle.jit.save(
self.test_model,
self.path_prefix,
input_spec=[
paddle.static.InputSpec(
shape=[3, -1], dtype='int32', name="indices"
),
paddle.static.InputSpec(
shape=[-1, 3], dtype='float32', name="values"
),
],
)

def test_check_output(self):
fp32_out = self.inference("fp32")
fp16_out = self.inference("fp16")
np.testing.assert_allclose(fp32_out, fp16_out, rtol=1e-5, atol=1e-2)

def inference(self, precision):
# Config
config = Config(
self.path_prefix + ".pdmodel", self.path_prefix + ".pdiparams"
)
if precision == "fp16":
config.enable_use_gpu(100, 0, PrecisionType.Half)
white_list = ["sparse_batch_norm", "sparse_relu"]
config.exp_enable_mixed_precision_ops(set(white_list))
else:
config.enable_use_gpu(100, 0, PrecisionType.Float32)

# predictor
predictor = create_predictor(config)

# inference
indices_tensor = predictor.get_input_handle("indices")
indices_tensor.reshape(self.indices.shape)
indices_tensor.copy_from_cpu(self.indices.copy())
values_tensor = predictor.get_input_handle("values")
values_tensor.reshape(self.values.shape)
values_tensor.copy_from_cpu(self.values.copy())
predictor.run()
output_tensor = predictor.get_output_handle(
predictor.get_output_names()[0]
)
out = output_tensor.copy_to_cpu()
out = np.array(out).flatten()
return out


if __name__ == "__main__":
unittest.main()