From 4167ec4bc62d7e20d3afa7402d9fc31c697defc7 Mon Sep 17 00:00:00 2001 From: huangxu96 Date: Sun, 26 Sep 2021 12:30:00 +0000 Subject: [PATCH 1/6] add python interface of sub_graph --- paddle/fluid/pybind/ir.cc | 6 +- python/paddle/fluid/framework.py | 15 ++++ .../ir/test_ir_subgraph_python_interface.py | 89 +++++++++++++++++++ 3 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index e27e3674eeeb5b..b3c01f0e2db853 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -125,7 +125,11 @@ void BindGraph(py::module *m) { return_value_policy::reference) .def("resolve_hazard", &Graph::ResolveHazard) .def("origin_program_desc", &Graph::OriginProgram, - return_value_policy::reference); + return_value_policy::reference) + .def("sub_graph_size", &Graph::SubGraphsSize) + .def("get_sub_graph", [](Graph &self, int i) { + return std::shared_ptr(self.GetSubGraph(i), [](Graph *) {}); + }); } void BindNode(py::module *m) { diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 11e7e7c2f7c08c..c2e7c389ea5415 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3956,6 +3956,21 @@ def all_op_nodes(self): """ return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()} + def all_sub_graphs(self): + """ + Return all sub_graphs included in the main graph as a set. + """ + return { + IrGraph(self.graph.get_sub_graph(i)) + for i in range(self.graph.sub_graph_size()) + } + + def get_sub_graph(self, i): + """ + Return i-th sub_graph in the main graph. + """ + return IrGraph(self.graph.get_sub_graph(i)) + def create_persistable_node(self, name, var_type, shape, var_dtype): """ Create a persistable variable node in the graph. In IrGraph, diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py new file mode 100644 index 00000000000000..35552a0ddd1785 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import paddle.fluid as fluid +import six + +from paddle.fluid.framework import IrGraph +from paddle.fluid.framework import IrNode +from paddle.fluid import core +from paddle.fluid.framework import Program, program_guard, default_startup_program +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass + +paddle.enable_static() + + +class TestQuantizationSubGraph(unittest.TestCase): + def build_graph_with_sub_graph(self): + def linear_fc(num): + data = fluid.layers.data( + name='image', shape=[1, 32, 32], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = data + for _ in six.moves.xrange(num): + hidden = fluid.layers.fc(hidden, size=128, act='relu') + loss = fluid.layers.cross_entropy(input=hidden, label=label) + loss = fluid.layers.mean(loss) + return loss + + main_program = Program() + startup_program = Program() + + def cond(i, ten): + return i < ten + + def body(i, ten): + i = i + 1 + return [i, ten] + + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + with paddle.static.program_guard(main_program, startup_program): + # i = paddle.full(shape=[1], fill_value=0, dtype='int64') # loop counter + # ten = paddle.full(shape=[1], fill_value=10, dtype='int64') # loop length + # i, ten = paddle.static.nn.while_loop(cond, body, [i, ten]) + loss = linear_fc(3) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + core_graph = core.Graph(main_program.desc) + print("block_size = ", len(main_program.blocks)) + print(core_graph) + graph = IrGraph(core_graph) + sub_graph = graph.get_sub_graph(0) + all_sub_graphs = graph.all_sub_graphs() + print(graph) + # print(sub_graph.all_nodes()) + + #return [graph] + return all_sub_graphs + + def test_quant_sub_graphs(self): + sub_graphs = self.build_graph_with_sub_graph() + transform_pass = QuantizationTransformPass( + scope=fluid.global_scope(), + place=fluid.CUDAPlace(0), + activation_quantize_type='abs_max', + weight_quantize_type='range_abs_max') + + for sub_graph in sub_graphs: + print("sub_graph:", sub_graph) + transform_pass.apply(sub_graph) + # program = sub_graph.to_program() + # print("sub_program",program) + + +if __name__ == '__main__': + unittest.main() From 7c029ddf7655ab07209108ad86095de9059c7fff Mon Sep 17 00:00:00 2001 From: huangxu96 Date: Mon, 27 Sep 2021 09:20:39 +0000 Subject: [PATCH 2/6] for Debug --- paddle/fluid/framework/ir/graph.h | 2 + .../ir/test_ir_subgraph_python_interface.py | 44 +++++++++++++------ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 21e743e3587d80..3198a385a004f9 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -371,6 +371,8 @@ class Graph { PADDLE_ENFORCE_LT( idx, sub_graphs_.size(), platform::errors::InvalidArgument("Invalid sub_graph index")); + VLOG(3) << "sub_graph nodes num:" + << sub_graphs_.at(idx).get()->Nodes().size(); return sub_graphs_.at(idx).get(); } diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py index 35552a0ddd1785..78d23c54546d10 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py @@ -52,12 +52,14 @@ def body(i, ten): main_program = paddle.static.default_main_program() startup_program = paddle.static.default_startup_program() with paddle.static.program_guard(main_program, startup_program): - # i = paddle.full(shape=[1], fill_value=0, dtype='int64') # loop counter - # ten = paddle.full(shape=[1], fill_value=10, dtype='int64') # loop length - # i, ten = paddle.static.nn.while_loop(cond, body, [i, ten]) - loss = linear_fc(3) - opt = fluid.optimizer.Adam(learning_rate=0.001) - opt.minimize(loss) + i = paddle.full( + shape=[1], fill_value=0, dtype='int64') # loop counter + ten = paddle.full( + shape=[1], fill_value=10, dtype='int64') # loop length + i, ten = paddle.static.nn.while_loop(cond, body, [i, ten]) + # loss = linear_fc(3) + # opt = fluid.optimizer.Adam(learning_rate=0.001) + # opt.minimize(loss) core_graph = core.Graph(main_program.desc) print("block_size = ", len(main_program.blocks)) print(core_graph) @@ -65,9 +67,20 @@ def body(i, ten): sub_graph = graph.get_sub_graph(0) all_sub_graphs = graph.all_sub_graphs() print(graph) - # print(sub_graph.all_nodes()) - - #return [graph] + print("all_sub_graphs[0]", all_sub_graphs[0]) + transform_pass = QuantizationTransformPass( + scope=fluid.global_scope(), + place=fluid.CUDAPlace(0), + activation_quantize_type='abs_max', + weight_quantize_type='range_abs_max') + print("sub_graph_size:", len(all_sub_graphs)) + for sub_graph in all_sub_graphs: + print("sub_graph:", sub_graph) # subgraph_ + nodes = sub_graph.all_op_nodes() + print("all nodes inside build graph func:", + nodes) #这里 all nodes 不为空,所以可以作量化Pass + transform_pass.apply(sub_graph) + print("Done quant pass applied ") return all_sub_graphs def test_quant_sub_graphs(self): @@ -77,12 +90,15 @@ def test_quant_sub_graphs(self): place=fluid.CUDAPlace(0), activation_quantize_type='abs_max', weight_quantize_type='range_abs_max') - + print("sub_graph_size:", len(sub_graphs)) for sub_graph in sub_graphs: - print("sub_graph:", sub_graph) - transform_pass.apply(sub_graph) - # program = sub_graph.to_program() - # print("sub_program",program) + print("sub_graph:", + sub_graph) # 这里object地址都和返回前一样,但是拿到的all_op_nodes为空set() + nodes = sub_graph.all_op_nodes() + print("all nodes outside build graph func:", nodes) + transform_pass.apply(sub_graph) #因为拿不到 all_op_nodes 所以pass会失败 + #program = sub_graph.to_program() + # print("sub_program",program) if __name__ == '__main__': From 088225df243d90f3079e207b0e8d2282d38062b7 Mon Sep 17 00:00:00 2001 From: huangxu96 Date: Mon, 27 Sep 2021 13:40:29 +0000 Subject: [PATCH 3/6] update Debug info --- .../ir/test_ir_subgraph_python_interface.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py index 78d23c54546d10..3c3c4e64237f65 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py @@ -75,26 +75,31 @@ def body(i, ten): weight_quantize_type='range_abs_max') print("sub_graph_size:", len(all_sub_graphs)) for sub_graph in all_sub_graphs: - print("sub_graph:", sub_graph) # subgraph_ - nodes = sub_graph.all_op_nodes() + print("sub_graph_dict:", sub_graph.__dict__) # subgraph_ + op_nodes = sub_graph.all_op_nodes() + var_nodes = sub_graph.all_var_nodes() + print("inside func:ir graph's core graph", sub_graph.graph) print("all nodes inside build graph func:", - nodes) #这里 all nodes 不为空,所以可以作量化Pass + op_nodes) #这里 all nodes 不为空,所以可以作量化Pass + print("all var_nodes inside build graph func:", var_nodes) transform_pass.apply(sub_graph) + print("Done quant pass applied ") return all_sub_graphs - + #return [graph] def test_quant_sub_graphs(self): sub_graphs = self.build_graph_with_sub_graph() + print("num of sub graphs:", len(sub_graphs)) transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), place=fluid.CUDAPlace(0), activation_quantize_type='abs_max', weight_quantize_type='range_abs_max') - print("sub_graph_size:", len(sub_graphs)) for sub_graph in sub_graphs: - print("sub_graph:", - sub_graph) # 这里object地址都和返回前一样,但是拿到的all_op_nodes为空set() - nodes = sub_graph.all_op_nodes() + print("sub_graph_dict_outside:", sub_graph. + __dict__) # 这里object地址都和返回前一样,但是拿到的all_op_nodes为空set() + nodes = sub_graph.all_nodes() + print("ir graph's core graph", sub_graph.graph) print("all nodes outside build graph func:", nodes) transform_pass.apply(sub_graph) #因为拿不到 all_op_nodes 所以pass会失败 #program = sub_graph.to_program() From e65daeb57096602d0524d4fce26471f7d8098d7f Mon Sep 17 00:00:00 2001 From: huangxu96 Date: Tue, 28 Sep 2021 12:42:16 +0000 Subject: [PATCH 4/6] fix a bug in subgraph irGraph, add the arguement of for_test in subgraph --- paddle/fluid/pybind/ir.cc | 5 ++ python/paddle/fluid/framework.py | 23 +++--- .../ir/test_ir_subgraph_python_interface.py | 76 +++++++------------ 3 files changed, 45 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index b3c01f0e2db853..25dc3163bf211b 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -128,6 +128,11 @@ void BindGraph(py::module *m) { return_value_policy::reference) .def("sub_graph_size", &Graph::SubGraphsSize) .def("get_sub_graph", [](Graph &self, int i) { + /* Here we use a lambda function as an empty deleter to avoid the double + free of smart pointer. + Otherwise, this shared pointer will free itself both in python and cpp + scope, which will lead + a core dumped. */ return std::shared_ptr(self.GetSubGraph(i), [](Graph *) {}); }); } diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c2e7c389ea5415..07a2e763e383a2 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3956,20 +3956,22 @@ def all_op_nodes(self): """ return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()} - def all_sub_graphs(self): + def all_sub_graphs(self, for_test=False): """ Return all sub_graphs included in the main graph as a set. """ - return { - IrGraph(self.graph.get_sub_graph(i)) + + return [ + IrGraph( + self.graph.get_sub_graph(i), for_test=for_test) for i in range(self.graph.sub_graph_size()) - } + ] - def get_sub_graph(self, i): + def get_sub_graph(self, i, for_test=False): """ Return i-th sub_graph in the main graph. """ - return IrGraph(self.graph.get_sub_graph(i)) + return IrGraph(self.graph.get_sub_graph(i), for_test=for_test) def create_persistable_node(self, name, var_type, shape, var_dtype): """ @@ -4117,8 +4119,10 @@ def link_to(self, node_in, node_out): node_in(IrNode): the input node. node_out(IrNode): the output node. """ - assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \ - 'The two arguments(node_in&node_out) must be in the graph nodes.' + assert node_in.node in self.graph.nodes(), ( + 'node_in(%s) must be in the graph nodes.' % node_in.node.name()) + assert node_out.node in self.graph.nodes(), ( + 'node_out(%s) must be in the graph nodes.' % node_out.node.name()) node_in.append_output(node_out) node_out.append_input(node_in) @@ -4280,7 +4284,8 @@ def _find_node_by_name(self, nodes, node_name): for n in nodes: if n.name() == node_name: target_node = n - assert target_node is not None, "Cannot find the target node in the giving set." + assert target_node is not None, ( + "Cannot find the target node (%s)in the giving set." % node_name) return target_node def _update_desc_attr(self, desc, name, val): diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py index 3c3c4e64237f65..9046bb713b90d1 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py @@ -20,6 +20,7 @@ from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrNode from paddle.fluid import core +import paddle.fluid.layers as layers from paddle.fluid.framework import Program, program_guard, default_startup_program from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass @@ -42,68 +43,43 @@ def linear_fc(num): main_program = Program() startup_program = Program() - def cond(i, ten): - return i < ten + def true_func(): + return linear_fc(3) - def body(i, ten): - i = i + 1 - return [i, ten] + def false_func(): + return linear_fc(5) + + with program_guard(main_program, startup_program): + x = layers.fill_constant(shape=[1], dtype='float32', value=0.1) + y = layers.fill_constant(shape=[1], dtype='float32', value=0.23) + pred = layers.less_than(y, x) + out = layers.cond(pred, true_func, false_func) - main_program = paddle.static.default_main_program() - startup_program = paddle.static.default_startup_program() - with paddle.static.program_guard(main_program, startup_program): - i = paddle.full( - shape=[1], fill_value=0, dtype='int64') # loop counter - ten = paddle.full( - shape=[1], fill_value=10, dtype='int64') # loop length - i, ten = paddle.static.nn.while_loop(cond, body, [i, ten]) - # loss = linear_fc(3) - # opt = fluid.optimizer.Adam(learning_rate=0.001) - # opt.minimize(loss) core_graph = core.Graph(main_program.desc) - print("block_size = ", len(main_program.blocks)) - print(core_graph) - graph = IrGraph(core_graph) + # We should create graph for test, otherwise it will throw a + # error that it cannot find the node of "STEP_COUNTER" + graph = IrGraph(core_graph, for_test=True) sub_graph = graph.get_sub_graph(0) - all_sub_graphs = graph.all_sub_graphs() - print(graph) - print("all_sub_graphs[0]", all_sub_graphs[0]) - transform_pass = QuantizationTransformPass( - scope=fluid.global_scope(), - place=fluid.CUDAPlace(0), - activation_quantize_type='abs_max', - weight_quantize_type='range_abs_max') - print("sub_graph_size:", len(all_sub_graphs)) - for sub_graph in all_sub_graphs: - print("sub_graph_dict:", sub_graph.__dict__) # subgraph_ - op_nodes = sub_graph.all_op_nodes() - var_nodes = sub_graph.all_var_nodes() - print("inside func:ir graph's core graph", sub_graph.graph) - print("all nodes inside build graph func:", - op_nodes) #这里 all nodes 不为空,所以可以作量化Pass - print("all var_nodes inside build graph func:", var_nodes) - transform_pass.apply(sub_graph) + all_sub_graphs = graph.all_sub_graphs( + for_test=True) # same reason for subgraph + # Should return graph and sub_graphs at the same time. If only return sub_graph, the graph will + # be destructed and the sub_graphs will be empty. + return graph, all_sub_graphs - print("Done quant pass applied ") - return all_sub_graphs - #return [graph] def test_quant_sub_graphs(self): - sub_graphs = self.build_graph_with_sub_graph() - print("num of sub graphs:", len(sub_graphs)) + graph, sub_graphs = self.build_graph_with_sub_graph() transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), place=fluid.CUDAPlace(0), activation_quantize_type='abs_max', weight_quantize_type='range_abs_max') + Find_inserted_quant_op = False for sub_graph in sub_graphs: - print("sub_graph_dict_outside:", sub_graph. - __dict__) # 这里object地址都和返回前一样,但是拿到的all_op_nodes为空set() - nodes = sub_graph.all_nodes() - print("ir graph's core graph", sub_graph.graph) - print("all nodes outside build graph func:", nodes) - transform_pass.apply(sub_graph) #因为拿不到 all_op_nodes 所以pass会失败 - #program = sub_graph.to_program() - # print("sub_program",program) + transform_pass.apply(sub_graph) + for op in sub_graph.all_op_nodes(): + if 'quantize' in op.name(): + Find_inserted_quant_op = True + self.assertTrue(Find_inserted_quant_op) if __name__ == '__main__': From 69b73e63e9d9a55e13fa100cf2100e79df0a4f3e Mon Sep 17 00:00:00 2001 From: huangxu96 Date: Tue, 28 Sep 2021 12:44:22 +0000 Subject: [PATCH 5/6] remove some Debug info --- paddle/fluid/framework/ir/graph.h | 2 -- paddle/fluid/pybind/ir.cc | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 3198a385a004f9..21e743e3587d80 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -371,8 +371,6 @@ class Graph { PADDLE_ENFORCE_LT( idx, sub_graphs_.size(), platform::errors::InvalidArgument("Invalid sub_graph index")); - VLOG(3) << "sub_graph nodes num:" - << sub_graphs_.at(idx).get()->Nodes().size(); return sub_graphs_.at(idx).get(); } diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 25dc3163bf211b..050bfc967daa10 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -130,9 +130,8 @@ void BindGraph(py::module *m) { .def("get_sub_graph", [](Graph &self, int i) { /* Here we use a lambda function as an empty deleter to avoid the double free of smart pointer. - Otherwise, this shared pointer will free itself both in python and cpp - scope, which will lead - a core dumped. */ + Otherwise, this shared pointer will be free both in python and + cpp scope, which will lead a core dumped. */ return std::shared_ptr(self.GetSubGraph(i), [](Graph *) {}); }); } From b2a088eadc2660ee4f8d94bf7c7c1a119fac7601 Mon Sep 17 00:00:00 2001 From: huangxu96 Date: Wed, 29 Sep 2021 03:14:22 +0000 Subject: [PATCH 6/6] add CPU test case for Windows CI --- .../ir/test_ir_subgraph_python_interface.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py index 9046bb713b90d1..49ca89a35f4ac7 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_subgraph_python_interface.py @@ -19,6 +19,7 @@ from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrNode +from paddle.fluid.tests.unittests.op_test import OpTestTool from paddle.fluid import core import paddle.fluid.layers as layers from paddle.fluid.framework import Program, program_guard, default_startup_program @@ -66,11 +67,12 @@ def false_func(): # be destructed and the sub_graphs will be empty. return graph, all_sub_graphs - def test_quant_sub_graphs(self): + def test_quant_sub_graphs(self, use_cuda=False): graph, sub_graphs = self.build_graph_with_sub_graph() + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), - place=fluid.CUDAPlace(0), + place=place, activation_quantize_type='abs_max', weight_quantize_type='range_abs_max') Find_inserted_quant_op = False @@ -81,6 +83,14 @@ def test_quant_sub_graphs(self): Find_inserted_quant_op = True self.assertTrue(Find_inserted_quant_op) + def test_quant_sub_graphs_cpu(self): + self.test_quant_sub_graphs(use_cuda=False) + + @OpTestTool.skip_if(not paddle.is_compiled_with_cuda(), + "Not GPU version paddle") + def test_quant_sub_graphs_gpu(self): + self.test_quant_sub_graphs(use_cuda=True) + if __name__ == '__main__': unittest.main()