@@ -60,7 +60,7 @@ class ProgramWriter {
 
   Json WriteProgram(const pir::Program* program);
   Json WriteRegion(const pir::Region* region, const std::string& region_name);
-  Json WriteBlock(const pir::Block* block, const std::string& block_name);
+  Json WriteBlock(pir::Block* block, const std::string& block_name);
   Json WriteOp(const pir::Operation& op);
   Json WriteBlockArg(const pir::Value& value);
   Json WriteValue(const pir::Value& value);
paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc (32 additions & 4 deletions)
@@ -17,6 +17,8 @@
 #include "paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h"
 #include "paddle/pir/include/core/dialect.h"
 #include "paddle/pir/include/core/operation.h"
+#include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h"
+#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
 
 namespace pir {
 
@@ -55,11 +57,11 @@ Json ProgramWriter::WriteRegion(const pir::Region* region,
   return region_json;
 }
 
-Json ProgramWriter::WriteBlock(const pir::Block* block,
+Json ProgramWriter::WriteBlock(pir::Block* block,
                                const std::string& block_name) {
   Json block_json;
   block_json[ID] = block_name;
-
+  VLOG(4) << "Begin write " << block_name << ".";
   Json args_json = Json::array();
   for (auto arg : block->args()) {
     auto arg_json = WriteBlockArg(arg);
@@ -81,13 +83,38 @@ Json ProgramWriter::WriteBlock(const pir::Block* block,
   }
 
   Json ops_json = Json::array();
+
+  /* delete cf.stack_create / cf.tuple_push */
+  std::vector<pir::Operation*> delete_ops;
+  for (auto op : block->ops()) {
+    if (op->isa<pir::StackCreateOp>()) {
+      delete_ops.push_back(op);
+    }
+  }
+  VLOG(6) << "program before delete stack op :" << *(block->parent_program());
+  for (auto op : delete_ops) {
+    VLOG(0) << "Delete cf.stack_create / cf.tuple_push.";
+    auto stack_op = op->dyn_cast<pir::StackCreateOp>();
+    if (stack_op.inlet().HasOneUse()) {
+      auto tuple_push_op = stack_op.tuple_push_op();
+      auto block_in = tuple_push_op->GetParent();
+      block_in->erase(*tuple_push_op);
+    }
+    if (stack_op.outlet().HasOneUse()) {
+      auto tuple_pop_op = stack_op.tuple_pop_op();
+      auto block_in = tuple_pop_op->GetParent();
+      block_in->erase(*tuple_pop_op);
+    }
+    block->erase(*op);
+  }
+  VLOG(6) << "program after delete stack op :" << *(block->parent_program());
   for (auto op : block->ops()) {
     auto op_json = WriteOp(*op);
     ops_json.emplace_back(op_json);
   }
   block_json[BLOCKOPS] = ops_json;
 
-  VLOG(6) << "Finish write " << block_name << ".";
+  VLOG(4) << "Finish write " << block_name << ".";
   return block_json;
 }
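Note: WriteBlock now mutates the block it serializes, stripping cf.stack_create together with its paired cf.tuple_push / cf.tuple_pop before the ops are written out; that is also why the const qualifier disappears from the declaration in the header. The ops are gathered in a first pass and erased in a second, presumably so that nothing is erased while the iteration over block->ops() is still in flight. A self-contained sketch of that collect-then-erase pattern, with a plain std::list of op names standing in for the PIR block (illustrative only, not Paddle code):

#include <iostream>
#include <list>
#include <string>
#include <vector>

int main() {
  // Stand-in for a block's op list; the real code walks pir::Block::ops().
  std::list<std::string> ops = {"cf.stack_create", "pd_op.add",
                                "cf.tuple_push", "cf.tuple_pop"};

  // Pass 1: mark candidates without touching the container.
  std::vector<std::string> delete_ops;
  for (const auto& op : ops) {
    if (op.rfind("cf.", 0) == 0) delete_ops.push_back(op);  // prefix check
  }

  // Pass 2: erase only after the first iteration has finished.
  for (const auto& op : delete_ops) ops.remove(op);

  for (const auto& op : ops) std::cout << op << "\n";  // prints: pd_op.add
  return 0;
}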

@@ -126,6 +153,7 @@ Json ProgramWriter::WriteOp(const pir::Operation& op) {
   Json op_json = Json::object();
   op_json[ID] = op.name();
   // serialize opoperands
+  VLOG(4) << "Begin write Operation " << op.name() << ".";
   Json operands_json = Json::array();
   for (auto operand : op.operands()) {
     auto operand_json = WriteOpOperand(operand);
@@ -159,7 +187,7 @@
     op_json[OPRESULTS_ATTRS] = WriteAttributesMapOther(op.attributes());
   }
 
-  VLOG(6) << "Finish write Operation " << op.name() << ".";
+  VLOG(4) << "Finish write Operation " << op.name() << ".";
   return op_json;
 }
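Note on the logging tweaks in this file: with glog-style verbose logging, VLOG(n) emits only when the verbosity flag --v is at least n, so moving the begin/finish messages from VLOG(6) to VLOG(4) makes them visible at a lower verbosity setting, while the VLOG(0) used for the deletion message is effectively always on. A minimal stand-in (not glog itself) illustrating the assumed semantics:

#include <iostream>

static int v_flag = 4;  // stands in for the real --v command-line flag

// Toy macro: emit only when the requested level is within the threshold.
#define VLOG(n) \
  if ((n) <= v_flag) std::cout

int main() {
  VLOG(0) << "always shown\n";
  VLOG(4) << "shown when --v >= 4\n";
  VLOG(6) << "hidden when --v < 6\n";
  return 0;
}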
paddle/fluid/pybind/control_flow_api.cc (31 additions & 0 deletions)
@@ -22,6 +22,7 @@
 #include <unordered_set>
 #include <vector>
 
+#include "paddle/common/ddim.h"
 #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
 #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
 #include "paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.h"
@@ -327,6 +328,36 @@ std::vector<Value> PyWhileOp::OptimizeUpdate() {
   for (uint32_t i = 0; i < num_results(); ++i) {
     res.push_back(result(i));
   }
+  for (size_t operand_index = 1u, arg_index = 0u; operand_index < operand_num;
+       ++operand_index, ++arg_index) {
+    if (!body_block.arg(arg_index).type().isa<pir::DenseTensorType>()) {
+      continue;
+    }
+
+    auto l_type =
+        body_block.arg(arg_index).type().dyn_cast<pir::DenseTensorType>();
+    auto r_type = yield_op.operand_source(operand_index)
+                      .type()
+                      .dyn_cast<pir::DenseTensorType>();
+    if (l_type.dims().size() == r_type.dims().size() &&
+        l_type.dims() != r_type.dims()) {
+      VLOG(4) << "while op input " << operand_index
+              << " has dynamic shape, origin shape is: " << l_type.dims()
+              << ", new shape is: " << r_type.dims();
+      auto dim = common::ComputeCompatibleDim(l_type.dims(), r_type.dims());
+      auto new_type = pir::DenseTensorType::get(operation_->ir_context(),
+                                                l_type.dtype(),
+                                                dim,
+                                                l_type.data_layout(),
+                                                l_type.lod(),
+                                                l_type.offset());
+      body_block.arg(arg_index).set_type(new_type);
+      yield_op.operand_source(operand_index).set_type(new_type);
+      result(arg_index).set_type(new_type);
+      VLOG(4) << "change shape as: " << new_type.dims();
+    }
+  }
+
   for (size_t operand_index = 1u, arg_index = 0u; operand_index < operand_num;
        ++operand_index) {
     if (yield_op.operand_source(operand_index) == body_block.arg(arg_index)) {
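Note: the first loop added above reconciles the while op's loop-carried types. When a body block argument and the value yielded back into the same slot agree in rank but not in extents, all three types (block argument, yielded value, and op result) are rewritten with a compatible shape. A self-contained sketch of what common::ComputeCompatibleDim is assumed to compute; the real function operates on common::DDim, and its exact semantics are an assumption here:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for common::ComputeCompatibleDim: keep extents on
// which both sides agree, mark disagreeing extents as dynamic (-1).
std::vector<int64_t> ComputeCompatibleDim(const std::vector<int64_t>& l,
                                          const std::vector<int64_t>& r) {
  assert(l.size() == r.size());  // caller already checked equal rank
  std::vector<int64_t> out(l.size());
  for (size_t i = 0; i < l.size(); ++i) {
    out[i] = (l[i] == r[i]) ? l[i] : -1;  // -1 = unknown until runtime
  }
  return out;
}

int main() {
  // Mirrors the updated while test: pooling a [1, 3, 8, 8] input inside the
  // loop produces [1, 3, 4, 4], so the carried value becomes [1, 3, -1, -1].
  auto dims = ComputeCompatibleDim({1, 3, 8, 8}, {1, 3, 4, 4});
  assert((dims == std::vector<int64_t>{1, 3, -1, -1}));
  return 0;
}

This is also why the test's WhileNet below switches its loop condition from shape(x)[0] (the batch axis, which avg_pool2d never changes) to shape(x)[2]: the loop now actually iterates and the shape of x changes across iterations, exercising the dynamic-shape path.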
python/paddle/jit/dy2static/function_spec.py (3 additions & 1 deletion)
@@ -25,6 +25,7 @@
 from paddle.distributed.auto_parallel.placement_type import (
     to_placements,
 )
+from paddle.jit.pir_translated_layer import PirTranslatedLayer
 from paddle.jit.translated_layer import TranslatedLayer
 from paddle.nn.layer import layers
 
@@ -58,7 +59,8 @@ def __init__(self, function, input_spec=None):
         # parse *args
         self.varargs_name = parse_varargs_name(function)
         if self.varargs_name is not None and isinstance(
-            getattr(function, '__self__', None), TranslatedLayer
+            getattr(function, '__self__', None),
+            (TranslatedLayer, PirTranslatedLayer),
         ):
             self._arg_names += function.__self__._input_args_names
@@ -16,8 +16,11 @@
 import tempfile
 import unittest
 
+import numpy as np
+
 import paddle
 import paddle.nn.functional as F
+from paddle.pir_utils import test_with_dygraph_pir
 
 
 def getModelOp(model_path):
@@ -28,25 +31,33 @@ def getModelOp(model_path):
 
     result = set()
     for i in range(0, size):
-        # print(main_block.op(i).type())
         result.add(main_block.op(i).type())
 
     return result
 
 
+def GetPirModelOp(model_path):
+    recover_program = paddle.static.Program()
+    paddle.base.core.deserialize_pir_program(
+        model_path, recover_program, 1  # pir_version
+    )
+
+    return recover_program
+
+
 class WhileNet(paddle.nn.Layer):
     def __init__(self):
         super().__init__()
 
     def forward(self, x):
         y = paddle.rand(shape=[1, 3, 4, 4])
 
-        w1 = paddle.shape(y)[0]
-        w2 = paddle.assign(paddle.shape(x)[0])
+        w1 = paddle.shape(y)[2]
+        w2 = paddle.assign(paddle.shape(x)[2])
 
         while w2 != w1:
             x = F.avg_pool2d(x, kernel_size=3, padding=1, stride=2)
-            w2 = paddle.shape(x)[0]
+            w2 = paddle.shape(x)[2]
 
         return x + y
 
@@ -78,6 +89,7 @@ def forward(self, x):
 
 
 class TestConditionalOp(unittest.TestCase):
+    @test_with_dygraph_pir
     def test_while_op(self):
         paddle.disable_static()
         net = WhileNet()
@@ -90,24 +102,36 @@ def test_while_op(self):
         )
         root_path = tempfile.TemporaryDirectory()
         model_file = os.path.join(root_path.name, "while_net")
+        x = paddle.to_tensor(np.random.random((1, 3, 8, 8)).astype('float32'))
         paddle.jit.save(net, model_file)
 
-        right_pdmodel = {
-            "uniform_random",
-            "shape",
-            "slice",
-            "not_equal",
-            "while",
-            "elementwise_add",
-        }
         paddle.enable_static()
-        pdmodel = getModelOp(model_file + ".pdmodel")
-        self.assertTrue(
-            len(right_pdmodel.difference(pdmodel)) == 0,
-            "The while op is pruned by mistake.",
-        )
+        if paddle.framework.use_pir_api():
+            program = GetPirModelOp(model_file + ".json")
+            self.assertEqual(program.global_block().ops[-4].name(), "pd_op.add")
+            self.assertEqual(
+                program.global_block().ops[-5].result(1).shape, [1, 3, -1, -1]
+            )
+            self.assertEqual(
+                program.global_block().ops[-5].name(), "pd_op.while"
+            )
+        else:
+            right_pdmodel = {
+                "uniform_random",
+                "shape",
+                "slice",
+                "not_equal",
+                "while",
+                "elementwise_add",
+            }
+            pdmodel = getModelOp(model_file + ".pdmodel")
+            self.assertTrue(
+                len(right_pdmodel.difference(pdmodel)) == 0,
+                "The while op is pruned by mistake.",
+            )
         root_path.cleanup()
 
+    @test_with_dygraph_pir
     def test_for_op(self):
         paddle.disable_static()
         net = ForNet()
@@ -120,22 +144,31 @@ def test_for_op(self):
         model_file = os.path.join(root_path.name, "for_net")
         paddle.jit.save(net, model_file)
 
-        right_pdmodel = {
-            "randint",
-            "fill_constant",
-            "cast",
-            "less_than",
-            "while",
-            "elementwise_add",
-        }
         paddle.enable_static()
-        pdmodel = getModelOp(model_file + ".pdmodel")
-        self.assertTrue(
-            len(right_pdmodel.difference(pdmodel)) == 0,
-            "The for op is pruned by mistake.",
-        )
+        if paddle.framework.use_pir_api():
+            program = GetPirModelOp(model_file + ".json")
+            self.assertEqual(program.global_block().ops[-4].name(), "pd_op.add")
+            self.assertEqual(
+                program.global_block().ops[-5].name(), "pd_op.while"
+            )
+        else:
+            right_pdmodel = {
+                "randint",
+                "fill_constant",
+                "cast",
+                "less_than",
+                "while",
+                "elementwise_add",
+            }
+
+            pdmodel = getModelOp(model_file + ".pdmodel")
+            self.assertTrue(
+                len(right_pdmodel.difference(pdmodel)) == 0,
+                "The for op is pruned by mistake.",
+            )
         root_path.cleanup()
 
+    @test_with_dygraph_pir
     def test_if_op(self):
         paddle.disable_static()
         net = IfElseNet()
@@ -148,20 +181,39 @@
         model_file = os.path.join(root_path.name, "if_net")
         paddle.jit.save(net, model_file)
 
-        right_pdmodel = {
-            "assign_value",
-            "greater_than",
-            "cast",
-            "conditional_block",
-            "logical_not",
-            "select_input",
-        }
         paddle.enable_static()
-        pdmodel = getModelOp(model_file + ".pdmodel")
-        self.assertTrue(
-            len(right_pdmodel.difference(pdmodel)) == 0,
-            "The if op is pruned by mistake.",
-        )
+        if paddle.framework.use_pir_api():
+            program = GetPirModelOp(model_file + ".json")
+            op_list = [
+                "pd_op.data",
+                "pd_op.full",
+                "pd_op.assign_value_",
+                "pd_op.cast",
+                "pd_op.greater_than",
+                "pd_op.if",
+                "pd_op.full",
+                "pd_op.scale",
+                "pd_op.fetch",
+            ]
+            i = 0
+            for op in program.global_block().ops:
+                self.assertEqual(op.name(), op_list[i])
+                i = i + 1
+        else:
+            right_pdmodel = {
+                "assign_value",
+                "greater_than",
+                "cast",
+                "conditional_block",
+                "logical_not",
+                "select_input",
+            }
+
+            pdmodel = getModelOp(model_file + ".pdmodel")
+            self.assertTrue(
+                len(right_pdmodel.difference(pdmodel)) == 0,
+                "The if op is pruned by mistake.",
+            )
         root_path.cleanup()
 
test/dygraph_to_static/test_container.py (0 additions & 4 deletions)
@@ -23,7 +23,6 @@
 )
 
 import paddle
-from paddle.framework import use_pir_api
 
 
 class BufferLayers(paddle.nn.Layer):
@@ -98,9 +97,6 @@ def _run(self, to_static):
         net = paddle.jit.to_static(net)
         x = paddle.rand([16, 10], 'float32')
         out = net(x)
-        # TODO(pir-save-load): Fix this after we support save/load in PIR
-        if use_pir_api():
-            return out
         if to_static:
             load_out = self._test_load(net, x)
             np.testing.assert_allclose(
test/dygraph_to_static/test_pylayer.py (8 additions & 2 deletions)
@@ -746,7 +746,6 @@ def setUp(self):
     def tearDown(self):
         self.temp_dir.cleanup()
 
-    @to_legacy_ir_test
     def train_and_save_model(self, model_path=None):
         layer = SimpleNet_1(784, 20)
         example_inputs, layer, _ = train(layer)
@@ -767,7 +766,14 @@ def test_save_load(self):
         loaded_layer = paddle.jit.load(self.model_path)
         self.load_and_inference(train_layer, loaded_layer)
 
-    @to_legacy_ir_test
+    @to_pir_test
+    def test_pir_save_load(self):
+        # train and save model
+        train_layer = self.train_and_save_model()
+        # load model
+        loaded_layer = paddle.jit.load(self.model_path)
+        self.load_and_inference(train_layer, loaded_layer)
+
     def load_and_inference(self, train_layer, infer_layer):
         train_layer.eval()
         infer_layer.eval()