2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/graph.h
@@ -371,6 +371,8 @@ class Graph {
     PADDLE_ENFORCE_LT(
         idx, sub_graphs_.size(),
         platform::errors::InvalidArgument("Invalid sub_graph index"));
+    VLOG(3) << "sub_graph nodes num:"
+            << sub_graphs_.at(idx).get()->Nodes().size();
     return sub_graphs_.at(idx).get();
   }

6 changes: 5 additions & 1 deletion paddle/fluid/pybind/ir.cc
@@ -125,7 +125,11 @@ void BindGraph(py::module *m) {
            return_value_policy::reference)
       .def("resolve_hazard", &Graph::ResolveHazard)
       .def("origin_program_desc", &Graph::OriginProgram,
-           return_value_policy::reference);
+           return_value_policy::reference)
+      .def("sub_graph_size", &Graph::SubGraphsSize)
+      .def("get_sub_graph", [](Graph &self, int i) {
+        return std::shared_ptr<Graph>(self.GetSubGraph(i), [](Graph *) {});
Contributor: Is the [](Graph *) {} there to prevent destruction from outside? Couldn't we just return the raw pointer directly?

Contributor (Author): This prevents a double free. The intent is that the pointer is destructed only on the Python side; otherwise it would be destructed once in C++ and once more in Python. A comment explaining this was added later.

Contributor: Does return_value_policy::reference work for that?
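
The ownership contract described above can be sketched from the Python side. A minimal sketch, assuming a multi-block main_program like the one built in the test below:

    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph

    graph = IrGraph(core.Graph(main_program.desc))  # owns the underlying C++ Graph
    sub = graph.get_sub_graph(0)  # non-owning view: the no-op deleter keeps
                                  # Python from ever freeing this pointer
    # Keep `graph` referenced for as long as `sub` is in use; once `graph` is
    # collected, the C++ Graph and all of its sub-graphs are destructed.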

+      });
 }

 void BindNode(py::module *m) {
15 changes: 15 additions & 0 deletions python/paddle/fluid/framework.py
@@ -3956,6 +3956,21 @@ def all_op_nodes(self):
         """
         return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}

+    def all_sub_graphs(self):
+        """
+        Return all sub-graphs contained in the main graph as a set.
+        """
+        return {
+            IrGraph(self.graph.get_sub_graph(i))
+            for i in range(self.graph.sub_graph_size())
+        }
+
+    def get_sub_graph(self, i):
+        """
+        Return the i-th sub-graph of the main graph.
+        """
+        return IrGraph(self.graph.get_sub_graph(i))
+
     def create_persistable_node(self, name, var_type, shape, var_dtype):
         """
         Create a persistable variable node in the graph. In IrGraph,
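
A minimal usage sketch for the two new accessors, assuming main_program is an existing static Program with control-flow sub-blocks (e.g. a while_loop, as in the test below):

    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph

    graph = IrGraph(core.Graph(main_program.desc))
    print(graph.graph.sub_graph_size())  # number of sub-graphs (one per block)
    first = graph.get_sub_graph(0)       # one sub-graph, wrapped as an IrGraph
    for sub in graph.all_sub_graphs():   # all sub-graphs, as a set of IrGraphs
        print(len(sub.all_op_nodes()))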
@@ -0,0 +1,105 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import paddle.fluid as fluid
import six

from paddle.fluid.framework import IrGraph
from paddle.fluid.framework import IrNode
from paddle.fluid import core
from paddle.fluid.framework import Program, program_guard, default_startup_program
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

paddle.enable_static()


class TestQuantizationSubGraph(unittest.TestCase):
    def build_graph_with_sub_graph(self):
        def linear_fc(num):
            data = fluid.layers.data(
                name='image', shape=[1, 32, 32], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            hidden = data
            for _ in six.moves.xrange(num):
                hidden = fluid.layers.fc(hidden, size=128, act='relu')
            loss = fluid.layers.cross_entropy(input=hidden, label=label)
            loss = fluid.layers.mean(loss)
            return loss

        main_program = Program()
        startup_program = Program()

        def cond(i, ten):
            return i < ten

        def body(i, ten):
            i = i + 1
            return [i, ten]

        main_program = paddle.static.default_main_program()
        startup_program = paddle.static.default_startup_program()
        with paddle.static.program_guard(main_program, startup_program):
            i = paddle.full(
                shape=[1], fill_value=0, dtype='int64')  # loop counter
            ten = paddle.full(
                shape=[1], fill_value=10, dtype='int64')  # loop length
            i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])
            # loss = linear_fc(3)
            # opt = fluid.optimizer.Adam(learning_rate=0.001)
            # opt.minimize(loss)
        core_graph = core.Graph(main_program.desc)
Contributor: I think main_graph should be passed in rather than created inside this function. A temporary variable is destructed once it goes out of scope, and Python is presumably no exception.
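
A sketch of this suggestion, with the owning graph created by the caller and passed in (hypothetical signature, not part of this PR):

    def build_graph_with_sub_graph(self, main_graph):
        # main_graph is owned by the caller, so it outlives this function
        # along with the sub-graph views derived from it.
        all_sub_graphs = main_graph.all_sub_graphs()
        return all_sub_graphs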

        print("block_size = ", len(main_program.blocks))
        print(core_graph)
        graph = IrGraph(core_graph)
        sub_graph = graph.get_sub_graph(0)
        all_sub_graphs = graph.all_sub_graphs()
        print(graph)
        print("all_sub_graphs:", all_sub_graphs)
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=fluid.CUDAPlace(0),
            activation_quantize_type='abs_max',
            weight_quantize_type='range_abs_max')
        print("sub_graph_size:", len(all_sub_graphs))
        for sub_graph in all_sub_graphs:
            print("sub_graph:", sub_graph)
            nodes = sub_graph.all_op_nodes()
            print("all nodes inside build graph func:",
                  nodes)  # all_op_nodes is non-empty here, so the quantization pass can be applied
            transform_pass.apply(sub_graph)
        print("Done: quant pass applied")
        return all_sub_graphs
Contributor: After this returns, main_graph should already have been destructed, and the sub_graphs are really just members of main_graph.
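
An alternative sketch for the same lifetime problem: return the owning IrGraph together with the views, so the caller's reference keeps the underlying C++ Graph alive (hypothetical shape of a fix, not part of this PR):

    def build_graph_with_sub_graph(self):
        # ... build main_program as above ...
        graph = IrGraph(core.Graph(main_program.desc))
        # Returning the owner alongside the views keeps the C++ Graph, and
        # hence its sub-graphs, alive for the caller.
        return graph, graph.all_sub_graphs()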


    def test_quant_sub_graphs(self):
        sub_graphs = self.build_graph_with_sub_graph()
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=fluid.CUDAPlace(0),
            activation_quantize_type='abs_max',
            weight_quantize_type='range_abs_max')
        print("sub_graph_size:", len(sub_graphs))
        for sub_graph in sub_graphs:
            print("sub_graph:", sub_graph)
            # The object addresses are the same as before the return, but
            # all_op_nodes() now comes back as an empty set().
            nodes = sub_graph.all_op_nodes()
            print("all nodes outside build graph func:", nodes)
            # The pass fails here because all_op_nodes cannot be retrieved.
            transform_pass.apply(sub_graph)
            # program = sub_graph.to_program()
            # print("sub_program", program)


if __name__ == '__main__':
unittest.main()