Skip to content

Commit a29ff4c

Browse files
authored
add python interface of sub_graph (#36120)
Add the Python interface of subgraphs: 1. all_sub_graphs() 2. get_sub_graph(idx)
1 parent 1bd9cfe commit a29ff4c

File tree

3 files changed

+128
-4
lines changed

3 files changed

+128
-4
lines changed

paddle/fluid/pybind/ir.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,15 @@ void BindGraph(py::module *m) {
125125
return_value_policy::reference)
126126
.def("resolve_hazard", &Graph::ResolveHazard)
127127
.def("origin_program_desc", &Graph::OriginProgram,
128-
return_value_policy::reference);
128+
return_value_policy::reference)
129+
.def("sub_graph_size", &Graph::SubGraphsSize)
130+
.def("get_sub_graph", [](Graph &self, int i) {
131+
/* Here we use a lambda function as an empty deleter to avoid the double
132+
free of smart pointer.
133+
Otherwise, this shared pointer will be freed both in Python and
134+
cpp scope, which will lead to a core dump. */
135+
return std::shared_ptr<Graph>(self.GetSubGraph(i), [](Graph *) {});
136+
});
129137
}
130138

131139
void BindNode(py::module *m) {

python/paddle/fluid/framework.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3956,6 +3956,23 @@ def all_op_nodes(self):
39563956
"""
39573957
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
39583958

3959+
def all_sub_graphs(self, for_test=False):
    """
    Return all sub_graphs included in the main graph as a list.

    Args:
        for_test(bool): whether each returned ``IrGraph`` wrapper is
            constructed for test. Default: False.

    Returns:
        list(IrGraph): the i-th element wraps the i-th sub graph of
            the underlying core graph.
    """
    # NOTE: the original docstring said "as a set", but a list is
    # built and returned, preserving sub-graph index order.
    return [
        IrGraph(
            self.graph.get_sub_graph(i), for_test=for_test)
        for i in range(self.graph.sub_graph_size())
    ]
3969+
3970+
def get_sub_graph(self, i, for_test=False):
    """
    Return the i-th sub_graph of the main graph, wrapped as an IrGraph.

    Args:
        i(int): index of the sub graph to fetch.
        for_test(bool): whether the returned IrGraph is constructed
            for test. Default: False.

    Returns:
        IrGraph: wrapper around the i-th core sub graph.
    """
    core_sub_graph = self.graph.get_sub_graph(i)
    return IrGraph(core_sub_graph, for_test=for_test)
3975+
39593976
def create_persistable_node(self, name, var_type, shape, var_dtype):
39603977
"""
39613978
Create a persistable variable node in the graph. In IrGraph,
@@ -4102,8 +4119,10 @@ def link_to(self, node_in, node_out):
41024119
node_in(IrNode): the input node.
41034120
node_out(IrNode): the output node.
41044121
"""
4105-
assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
4106-
'The two arguments(node_in&node_out) must be in the graph nodes.'
4122+
assert node_in.node in self.graph.nodes(), (
4123+
'node_in(%s) must be in the graph nodes.' % node_in.node.name())
4124+
assert node_out.node in self.graph.nodes(), (
4125+
'node_out(%s) must be in the graph nodes.' % node_out.node.name())
41074126
node_in.append_output(node_out)
41084127
node_out.append_input(node_in)
41094128

@@ -4265,7 +4284,8 @@ def _find_node_by_name(self, nodes, node_name):
42654284
for n in nodes:
42664285
if n.name() == node_name:
42674286
target_node = n
4268-
assert target_node is not None, "Cannot find the target node in the giving set."
4287+
assert target_node is not None, (
4288+
"Cannot find the target node (%s)in the giving set." % node_name)
42694289
return target_node
42704290

42714291
def _update_desc_attr(self, desc, name, val):
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import paddle
17+
import paddle.fluid as fluid
18+
import six
19+
20+
from paddle.fluid.framework import IrGraph
21+
from paddle.fluid.framework import IrNode
22+
from paddle.fluid.tests.unittests.op_test import OpTestTool
23+
from paddle.fluid import core
24+
import paddle.fluid.layers as layers
25+
from paddle.fluid.framework import Program, program_guard, default_startup_program
26+
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
27+
28+
paddle.enable_static()
29+
30+
31+
class TestQuantizationSubGraph(unittest.TestCase):
    """Check that QuantizationTransformPass can be applied to the sub graphs
    created by a control-flow (cond) op."""

    def build_graph_with_sub_graph(self):
        """Build a program whose cond op owns sub blocks, and return the main
        IrGraph together with all of its sub graphs.

        Returns:
            tuple(IrGraph, list(IrGraph)): the main graph and its sub graphs.
        """

        def linear_fc(num):
            # A small FC network; `num` controls the number of FC layers.
            data = fluid.layers.data(
                name='image', shape=[1, 32, 32], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            hidden = data
            for _ in six.moves.xrange(num):
                hidden = fluid.layers.fc(hidden, size=128, act='relu')
            loss = fluid.layers.cross_entropy(input=hidden, label=label)
            loss = fluid.layers.mean(loss)
            return loss

        main_program = Program()
        startup_program = Program()

        def true_func():
            return linear_fc(3)

        def false_func():
            return linear_fc(5)

        with program_guard(main_program, startup_program):
            x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
            y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
            pred = layers.less_than(y, x)
            out = layers.cond(pred, true_func, false_func)

        core_graph = core.Graph(main_program.desc)
        # We should create the graph for test, otherwise it will throw an
        # error that it cannot find the node of "STEP_COUNTER".
        graph = IrGraph(core_graph, for_test=True)
        sub_graph = graph.get_sub_graph(0)
        all_sub_graphs = graph.all_sub_graphs(
            for_test=True)  # same reason as above
        # Should return the graph and the sub_graphs at the same time. If only
        # the sub_graphs were returned, the main graph would be destructed and
        # the sub_graphs would be empty.
        return graph, all_sub_graphs

    def _check_quant_sub_graphs(self, use_cuda):
        """Apply the quantization pass to every sub graph and assert that at
        least one quantize op was inserted.

        NOTE: the original version named this ``test_quant_sub_graphs`` with a
        ``use_cuda=False`` default, so unittest discovered and ran it directly,
        duplicating the CPU case triggered by ``test_quant_sub_graphs_cpu``.
        Renamed to a private helper so each configuration runs exactly once.
        """
        graph, sub_graphs = self.build_graph_with_sub_graph()
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=place,
            activation_quantize_type='abs_max',
            weight_quantize_type='range_abs_max')
        found_inserted_quant_op = False
        for sub_graph in sub_graphs:
            transform_pass.apply(sub_graph)
            for op in sub_graph.all_op_nodes():
                if 'quantize' in op.name():
                    found_inserted_quant_op = True
        self.assertTrue(found_inserted_quant_op)

    def test_quant_sub_graphs_cpu(self):
        self._check_quant_sub_graphs(use_cuda=False)

    @OpTestTool.skip_if(not paddle.is_compiled_with_cuda(),
                        "Not GPU version paddle")
    def test_quant_sub_graphs_gpu(self):
        self._check_quant_sub_graphs(use_cuda=True)
93+
94+
95+
if __name__ == '__main__':
    # Run the test suite when this file is executed directly.
    unittest.main()

0 commit comments

Comments
 (0)