10 changes: 9 additions & 1 deletion paddle/fluid/pybind/ir.cc
@@ -125,7 +125,15 @@ void BindGraph(py::module *m) {
                           return_value_policy::reference)
      .def("resolve_hazard", &Graph::ResolveHazard)
      .def("origin_program_desc", &Graph::OriginProgram,
-          return_value_policy::reference);
+          return_value_policy::reference)
.def("sub_graph_size", &Graph::SubGraphsSize)
.def("get_sub_graph", [](Graph &self, int i) {
/* Here we use a lambda function as an empty deleter to avoid the double
free of smart pointer.
Otherwise, this shared pointer will be free both in python and
cpp scope, which will lead a core dumped. */
return std::shared_ptr<Graph>(self.GetSubGraph(i), [](Graph *) {});
Contributor:
Is the [](Graph *) {} there to prevent destruction from outside? Couldn't we just return a raw pointer?

Contributor Author:
This prevents a double free. The goal is to have the pointer destructed only on the Python side; otherwise it would be destructed once in C++ and once in Python. A comment explaining this was added afterwards.

Contributor:
Does return_value_policy::reference work for that?

      });
}

void BindNode(py::module *m) {
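For context, a minimal sketch (illustrative only, not part of the diff) of how the two new bindings are driven from Python. Because the binding attaches a no-op deleter, the returned sub graph is a borrowed view and the parent graph must stay alive while it is used:

import paddle
from paddle.fluid import core

paddle.enable_static()

# Any Program works here; ones built with control flow (e.g. layers.cond)
# contain more than one sub graph.
prog = paddle.static.default_main_program()

g = core.Graph(prog.desc)        # owning C++ Graph
for i in range(g.sub_graph_size()):
    sub = g.get_sub_graph(i)     # borrowed view: the no-op deleter means
                                 # Python never frees the underlying memory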
26 changes: 23 additions & 3 deletions python/paddle/fluid/framework.py
@@ -3956,6 +3956,23 @@ def all_op_nodes(self):
"""
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}

    def all_sub_graphs(self, for_test=False):
        """
        Return all the sub graphs included in the main graph as a list.
        """
        return [
            IrGraph(
                self.graph.get_sub_graph(i), for_test=for_test)
            for i in range(self.graph.sub_graph_size())
        ]

    def get_sub_graph(self, i, for_test=False):
        """
        Return the i-th sub graph of the main graph.
        """
        return IrGraph(self.graph.get_sub_graph(i), for_test=for_test)

    def create_persistable_node(self, name, var_type, shape, var_dtype):
        """
        Create a persistable variable node in the graph. In IrGraph,
@@ -4102,8 +4119,10 @@ def link_to(self, node_in, node_out):
            node_in(IrNode): the input node.
            node_out(IrNode): the output node.
        """
-        assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
-            'The two arguments(node_in&node_out) must be in the graph nodes.'
+        assert node_in.node in self.graph.nodes(), (
+            'node_in(%s) must be in the graph nodes.' % node_in.node.name())
+        assert node_out.node in self.graph.nodes(), (
+            'node_out(%s) must be in the graph nodes.' % node_out.node.name())
        node_in.append_output(node_out)
        node_out.append_input(node_in)

@@ -4265,7 +4284,8 @@ def _find_node_by_name(self, nodes, node_name):
        for n in nodes:
            if n.name() == node_name:
                target_node = n
-        assert target_node is not None, "Cannot find the target node in the giving set."
+        assert target_node is not None, (
+            "Cannot find the target node (%s) in the given set." % node_name)
        return target_node

    def _update_desc_attr(self, desc, name, val):
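A short, self-contained usage sketch of the new IrGraph helpers (illustrative only; the cond-based program mirrors the new test below). The sub graphs share the main graph's memory, so the owning graph has to be kept referenced while they are in use:

import paddle
import paddle.fluid.layers as layers
from paddle.fluid import core
from paddle.fluid.framework import IrGraph, Program, program_guard

paddle.enable_static()

# Build a program with sub blocks via layers.cond.
main_program, startup_program = Program(), Program()
with program_guard(main_program, startup_program):
    x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
    y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
    layers.cond(layers.less_than(y, x),
                lambda: layers.fill_constant(
                    shape=[1], dtype='float32', value=1.0),
                lambda: layers.fill_constant(
                    shape=[1], dtype='float32', value=2.0))

graph = IrGraph(core.Graph(main_program.desc), for_test=True)
sub_graphs = graph.all_sub_graphs(for_test=True)  # list of IrGraph views
first = graph.get_sub_graph(0, for_test=True)     # single view by index

# Keep `graph` referenced while the sub graphs are in use: because of the
# no-op deleter above, they do not own the underlying C++ memory.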
@@ -0,0 +1,96 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import paddle.fluid as fluid
import six

from paddle.fluid.framework import IrGraph
from paddle.fluid.framework import IrNode
from paddle.fluid.tests.unittests.op_test import OpTestTool
from paddle.fluid import core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard, default_startup_program
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

paddle.enable_static()


class TestQuantizationSubGraph(unittest.TestCase):
    def build_graph_with_sub_graph(self):
        def linear_fc(num):
            data = fluid.layers.data(
                name='image', shape=[1, 32, 32], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            hidden = data
            for _ in six.moves.xrange(num):
                hidden = fluid.layers.fc(hidden, size=128, act='relu')
            loss = fluid.layers.cross_entropy(input=hidden, label=label)
            loss = fluid.layers.mean(loss)
            return loss

        main_program = Program()
        startup_program = Program()

        def true_func():
            return linear_fc(3)

        def false_func():
            return linear_fc(5)

        with program_guard(main_program, startup_program):
            x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
            y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
            pred = layers.less_than(y, x)
            out = layers.cond(pred, true_func, false_func)

        core_graph = core.Graph(main_program.desc)
Contributor:
I think main_graph should be passed in rather than created inside the function; a temporary variable is destructed once it goes out of scope, and Python should be no exception.

        # The IrGraph must be created with for_test=True; otherwise the pass
        # throws an error that it cannot find the "STEP_COUNTER" node.
        graph = IrGraph(core_graph, for_test=True)
        sub_graph = graph.get_sub_graph(0)  # exercises the single-graph getter
        all_sub_graphs = graph.all_sub_graphs(
            for_test=True)  # same reason as above for for_test=True
        # Return the graph and the sub graphs together. If only the sub graphs
        # were returned, the main graph would be destructed and the sub graphs
        # would be left empty.
        return graph, all_sub_graphs

    def _check_quant_sub_graphs(self, use_cuda=False):
        graph, sub_graphs = self.build_graph_with_sub_graph()
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=place,
            activation_quantize_type='abs_max',
            weight_quantize_type='range_abs_max')
        found_inserted_quant_op = False
        for sub_graph in sub_graphs:
            transform_pass.apply(sub_graph)
            for op in sub_graph.all_op_nodes():
                if 'quantize' in op.name():
                    found_inserted_quant_op = True
        self.assertTrue(found_inserted_quant_op)

    def test_quant_sub_graphs_cpu(self):
        self._check_quant_sub_graphs(use_cuda=False)

    @OpTestTool.skip_if(not paddle.is_compiled_with_cuda(),
                        "Paddle is not compiled with CUDA")
    def test_quant_sub_graphs_gpu(self):
        self._check_quant_sub_graphs(use_cuda=True)


if __name__ == '__main__':
    unittest.main()