Commit 29e540a

add python interface of sub_graph (#36120) (#38235)
Add python interface of subgraph: 1. all_sub_graphs() 2. get_sub_graph(idx)
1 parent d70a06c commit 29e540a
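
A minimal usage sketch of the new interface (it mirrors the test added below; sub-graphs only exist when the program contains control-flow sub-blocks such as those built by layers.cond):

    import paddle
    import paddle.fluid.layers as layers
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph, Program, program_guard

    paddle.enable_static()

    main_program = Program()
    with program_guard(main_program, Program()):
        x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
        y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
        # cond() creates sub-blocks, which appear as sub-graphs of the IR graph
        out = layers.cond(layers.less_than(y, x), lambda: x + y, lambda: x - y)

    graph = IrGraph(core.Graph(main_program.desc), for_test=True)

    subs = graph.all_sub_graphs(for_test=True)   # all sub-graphs, as a list of IrGraph
    first = graph.get_sub_graph(0)               # the i-th sub-graph
    print(graph.graph.sub_graph_size())          # sub-graph count, via the C++ binding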

File tree

3 files changed: +128 −4 lines

paddle/fluid/pybind/ir.cc

Lines changed: 9 additions & 1 deletion

@@ -125,7 +125,15 @@ void BindGraph(py::module *m) {
            return_value_policy::reference)
       .def("resolve_hazard", &Graph::ResolveHazard)
       .def("origin_program_desc", &Graph::OriginProgram,
-           return_value_policy::reference);
+           return_value_policy::reference)
+      .def("sub_graph_size", &Graph::SubGraphsSize)
+      .def("get_sub_graph", [](Graph &self, int i) {
+        /* Here we use a lambda as an empty deleter to avoid a double free
+           of the smart pointer. Otherwise the shared pointer would be
+           freed in both the Python and C++ scopes, which would lead to a
+           core dump. */
+        return std::shared_ptr<Graph>(self.GetSubGraph(i), [](Graph *) {});
+      });
 }

 void BindNode(py::module *m) {
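
The empty-deleter trick above is a general C++ idiom: wrapping a raw pointer whose lifetime is owned elsewhere in a std::shared_ptr with a no-op deleter gives pybind11 a reference-counted handle without ever transferring ownership. A standalone sketch of the idiom (plain C++, no pybind11; this Graph struct is a stand-in, not Paddle's):

    #include <iostream>
    #include <memory>

    struct Graph {
        ~Graph() { std::cout << "Graph destroyed\n"; }
    };

    int main() {
        Graph graph;  // lifetime owned here, on the stack

        {
            // Non-owning view: the lambda deleter does nothing, so when the
            // last shared_ptr copy goes away, no delete is attempted.
            std::shared_ptr<Graph> view(&graph, [](Graph *) {});
            std::cout << "use_count = " << view.use_count() << "\n";  // prints 1
        }  // `view` is destroyed here; the no-op deleter runs, freeing nothing

        std::cout << "view gone, graph still alive\n";
        return 0;
    }  // "Graph destroyed" prints exactly once, here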

python/paddle/fluid/framework.py

Lines changed: 23 additions & 3 deletions

@@ -3960,6 +3960,23 @@ def all_op_nodes(self):
         """
         return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}

+    def all_sub_graphs(self, for_test=False):
+        """
+        Return all sub_graphs included in the main graph as a list.
+        """
+
+        return [
+            IrGraph(
+                self.graph.get_sub_graph(i), for_test=for_test)
+            for i in range(self.graph.sub_graph_size())
+        ]
+
+    def get_sub_graph(self, i, for_test=False):
+        """
+        Return the i-th sub_graph in the main graph.
+        """
+        return IrGraph(self.graph.get_sub_graph(i), for_test=for_test)
+
     def create_persistable_node(self, name, var_type, shape, var_dtype):
         """
         Create a persistable variable node in the graph. In IrGraph,

@@ -4106,8 +4123,10 @@ def link_to(self, node_in, node_out):
             node_in(IrNode): the input node.
             node_out(IrNode): the output node.
         """
-        assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
-            'The two arguments(node_in&node_out) must be in the graph nodes.'
+        assert node_in.node in self.graph.nodes(), (
+            'node_in(%s) must be in the graph nodes.' % node_in.node.name())
+        assert node_out.node in self.graph.nodes(), (
+            'node_out(%s) must be in the graph nodes.' % node_out.node.name())
         node_in.append_output(node_out)
         node_out.append_input(node_in)

@@ -4269,7 +4288,8 @@ def _find_node_by_name(self, nodes, node_name):
         for n in nodes:
             if n.name() == node_name:
                 target_node = n
-        assert target_node is not None, "Cannot find the target node in the giving set."
+        assert target_node is not None, (
+            "Cannot find the target node (%s) in the given set." % node_name)
         return target_node

     def _update_desc_attr(self, desc, name, val):
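
The split assertions change what a failing link_to reports: instead of one combined message, the error now names the offending node. A hedged sketch of the difference (a contrived two-graph setup built only for illustration):

    import paddle
    import paddle.fluid.layers as layers
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph, Program, program_guard

    paddle.enable_static()

    prog_a, prog_b = Program(), Program()
    with program_guard(prog_a, Program()):
        layers.fill_constant(shape=[1], dtype='float32', value=0.1)
    with program_guard(prog_b, Program()):
        layers.fill_constant(shape=[1], dtype='float32', value=0.2)

    graph_a = IrGraph(core.Graph(prog_a.desc), for_test=True)
    graph_b = IrGraph(core.Graph(prog_b.desc), for_test=True)

    node_in = next(iter(graph_a.all_op_nodes()))
    foreign = next(iter(graph_b.all_op_nodes()))  # not a node of graph_a

    try:
        graph_a.link_to(node_in, foreign)
    except AssertionError as err:
        # Before: 'The two arguments(node_in&node_out) must be in the graph nodes.'
        # After:  'node_out(<name>) must be in the graph nodes.'
        print(err)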
Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import paddle.fluid as fluid
+import six
+
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.framework import IrNode
+from paddle.fluid.tests.unittests.op_test import OpTestTool
+from paddle.fluid import core
+import paddle.fluid.layers as layers
+from paddle.fluid.framework import Program, program_guard, default_startup_program
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
+
+paddle.enable_static()
+
+
+class TestQuantizationSubGraph(unittest.TestCase):
+    def build_graph_with_sub_graph(self):
+        def linear_fc(num):
+            data = fluid.layers.data(
+                name='image', shape=[1, 32, 32], dtype='float32')
+            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+            hidden = data
+            for _ in six.moves.xrange(num):
+                hidden = fluid.layers.fc(hidden, size=128, act='relu')
+            loss = fluid.layers.cross_entropy(input=hidden, label=label)
+            loss = fluid.layers.mean(loss)
+            return loss
+
+        main_program = Program()
+        startup_program = Program()
+
+        def true_func():
+            return linear_fc(3)
+
+        def false_func():
+            return linear_fc(5)
+
+        with program_guard(main_program, startup_program):
+            x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
+            y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
+            pred = layers.less_than(y, x)
+            out = layers.cond(pred, true_func, false_func)
+
+        core_graph = core.Graph(main_program.desc)
+        # The graph must be created with for_test=True; otherwise an error
+        # is thrown that the node "STEP_COUNTER" cannot be found.
+        graph = IrGraph(core_graph, for_test=True)
+        sub_graph = graph.get_sub_graph(0)
+        all_sub_graphs = graph.all_sub_graphs(
+            for_test=True)  # for_test=True for the same reason
+        # Return the graph and the sub_graphs together; if only the sub_graphs
+        # were returned, the graph would be destructed and they would be empty.
+        return graph, all_sub_graphs
+
+    def test_quant_sub_graphs(self, use_cuda=False):
+        graph, sub_graphs = self.build_graph_with_sub_graph()
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        transform_pass = QuantizationTransformPass(
+            scope=fluid.global_scope(),
+            place=place,
+            activation_quantize_type='abs_max',
+            weight_quantize_type='range_abs_max')
+        found_inserted_quant_op = False
+        for sub_graph in sub_graphs:
+            transform_pass.apply(sub_graph)
+            for op in sub_graph.all_op_nodes():
+                if 'quantize' in op.name():
+                    found_inserted_quant_op = True
+        self.assertTrue(found_inserted_quant_op)
+
+    def test_quant_sub_graphs_cpu(self):
+        self.test_quant_sub_graphs(use_cuda=False)
+
+    @OpTestTool.skip_if(not paddle.is_compiled_with_cuda(),
+                        "Not GPU version paddle")
+    def test_quant_sub_graphs_gpu(self):
+        self.test_quant_sub_graphs(use_cuda=True)
+
+
+if __name__ == '__main__':
+    unittest.main()
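
One detail the test encodes deserves emphasis for users of the new API: on the C++ side, get_sub_graph returns a shared_ptr with a no-op deleter, so sub-graphs are non-owning views and the main graph must stay alive for as long as any sub-graph is in use. A hedged sketch of the safe pattern (assumes a Program main_program containing control flow, as in the usage example near the top; imports as in the test above):

    def graph_with_sub_graphs(program):
        # Return the owning IrGraph together with its sub-graph views. If only
        # the views were returned, the main graph could be garbage-collected
        # and the sub-graphs left empty (see the comment in the test above).
        graph = IrGraph(core.Graph(program.desc), for_test=True)
        return graph, graph.all_sub_graphs(for_test=True)

    graph, subs = graph_with_sub_graphs(main_program)  # keep `graph` alive while using `subs`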
