Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 28 additions & 34 deletions paddle/operators/recurrent_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,23 @@ using LoDTensor = framework::LoDTensor;

void RecurrentAlgorithm::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
InitMemories(step_scopes[0], false /*infer_shape_mode*/);

for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// create output alias variables
if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
false /*infer_shape_mode*/);
auto* input0 = scope.FindVar(arg_->inlinks[0]);
PADDLE_ENFORCE_NOT_NULL(input0);
seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0];
PADDLE_ENFORCE_GT(seq_len_, 0);

CreateScopes(scope);
auto& step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
InitMemories(step_scopes[0]);

for (size_t i = 0; i < seq_len_; i++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i -> stepid

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[i], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
}

void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
Expand Down Expand Up @@ -82,21 +84,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
}
}

void RecurrentAlgorithm::InitMemories(Scope* step_scope,
bool infer_shape_mode) const {
void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
for (auto& attr : arg_->memories) {
auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<LoDTensor>();
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"memory [%s]'s boot variable [%s] not exists", attr.var,
attr.boot_var);
auto* boot_mem =
step_scope->FindVar(attr.boot_var)->GetMutable<LoDTensor>();
if (infer_shape_mode) {
pre_mem->Resize(boot_mem->dims());
PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
} else {
pre_mem->ShareDataWith<float>(*boot_mem);
}
pre_mem->Resize(boot_mem->dims());
PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
pre_mem->ShareDataWith<float>(*boot_mem);
}
}

Expand Down Expand Up @@ -146,23 +144,22 @@ class RecurrentAlgorithmProtoAndCheckerMaker

void RecurrentGradientAlgorithm::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {
seq_len_ =
scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enforce not null scope.FindVar

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
false /*infer_shape_mode*/);
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
LinkBootMemoryGradients(step_scopes[0]);
}

void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
Scope* step_scope, bool infer_shape_mode) const {
Scope* step_scope) const {
for (auto& attr : arg_->memories) {
PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
"memory variable [%s] does not exists", attr.var);
Expand All @@ -171,11 +168,8 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable<LoDTensor>();
auto* boot_mem_grad =
step_scope->NewVar(attr.boot_var)->GetMutable<LoDTensor>();
if (infer_shape_mode) {
boot_mem_grad->Resize(mem_grad->dims());
} else {
boot_mem_grad->ShareDataWith<float>(*mem_grad);
}
boot_mem_grad->Resize(mem_grad->dims());
boot_mem_grad->ShareDataWith<float>(*mem_grad);
}
}

Expand Down
6 changes: 3 additions & 3 deletions paddle/operators/recurrent_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class RecurrentAlgorithm {
->GetMutable<std::vector<framework::Scope*>>();
}

void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
void InitMemories(framework::Scope* step_scopes) const;

private:
std::unique_ptr<framework::OperatorBase>* stepnet_;
Expand Down Expand Up @@ -86,8 +86,7 @@ class RecurrentGradientAlgorithm {
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;

void LinkBootMemoryGradients(framework::Scope* step_scopes,
bool infer_shape_mode) const;
void LinkBootMemoryGradients(framework::Scope* step_scopes) const;

protected:
inline const std::vector<framework::Scope*>& GetStepScopes(
Expand Down Expand Up @@ -123,6 +122,7 @@ class RecurrentOp : public framework::OperatorBase {
void set_stepnet(std::unique_ptr<OperatorBase> net) {
stepnet_ = std::move(net);
}

const OperatorBase& stepnet() const { return *stepnet_; }

static const rnn::ArgumentName kArgName;
Expand Down
55 changes: 23 additions & 32 deletions paddle/operators/rnn/recurrent_op_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using LoDTensor = framework::LoDTensor;

void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& inlinks,
const size_t seq_len, bool infer_shape_mode) {
const size_t seq_len) {
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
for (size_t i = 0; i < inlinks.size(); ++i) {
// global inputs
Expand All @@ -41,51 +41,45 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input =
step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>();
if (!infer_shape_mode) {
// The input of operators of each step is Tensor here.
// Maybe need to modify Slice function.
*step_input = input->Slice<float>(j, j + 1);
}
// The input of operators of each step is Tensor here.
// Maybe need to modify Slice function.
*step_input = input->Slice<float>(j, j + 1);
step_input->Resize(step_dims);
}
}
}

void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& outlinks,
const size_t seq_len, bool infer_shape_mode) {
const size_t seq_len) {
for (size_t i = 0; i < outlinks.size(); i++) {
auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
outlinks[i]);
LoDTensor* output = output_var->GetMutable<LoDTensor>();

if (infer_shape_mode) {
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
f::DDim step_dims =
step_scope_var->template GetMutable<LoDTensor>()->dims();
std::vector<int64_t> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(f::make_ddim(dims_vec));
} else {
output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
LoDTensor* step_output =
step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>();
// TODO(luotao02) data type and platform::DeviceContext() should set
// correctly
(output->Slice<float>(j, j + 1))
.CopyFrom<float>(*step_output, platform::CPUPlace());
}
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto*

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
f::DDim step_dims =
step_scope_var->template GetMutable<LoDTensor>()->dims();
std::vector<int64_t> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(f::make_ddim(dims_vec));
output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
LoDTensor* step_output =
step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>();
// TODO(luotao02) data type and platform::DeviceContext() should set
// correctly
(output->Slice<float>(j, j + 1))
.CopyFrom<float>(*step_output, platform::CPUPlace());
}
}
}

void LinkMemories(const std::vector<Scope*>& scopes,
const std::vector<rnn::MemoryAttr>& memories,
const size_t step_id, const int offset,
bool infer_shape_mode) {
const size_t step_id, const int offset) {
PADDLE_ENFORCE_LT(step_id, scopes.size(),
"step [%d] is out of range of step scopes' size [%d]",
step_id, scopes.size());
Expand All @@ -100,11 +94,8 @@ void LinkMemories(const std::vector<Scope*>& scopes,
for (auto& attr : memories) {
auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
if (infer_shape_mode) {
mem->Resize(linked_mem->dims());
} else {
mem->ShareDataWith<float>(*linked_mem);
}
mem->Resize(linked_mem->dims());
mem->ShareDataWith<float>(*linked_mem);
}
}

Expand Down
6 changes: 3 additions & 3 deletions paddle/operators/rnn/recurrent_op_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,18 @@ struct ArgumentName {
*/
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& inlinks,
const size_t seq_len, bool infer_shape_mode);
const size_t seq_len);

/**
* Process outputs of step nets and merge to variables.
*/
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& outlinks,
const size_t seq_len, bool infer_shape_mode);
const size_t seq_len);

void LinkMemories(const std::vector<Scope*>& step_scopes,
const std::vector<MemoryAttr>& memories, const size_t step_id,
const int offset, bool infer_shape_mode);
const int offset);

void InitArgument(const ArgumentName& name, Argument* arg,
const framework::OperatorBase& op, bool is_grad = false);
Expand Down
5 changes: 3 additions & 2 deletions paddle/operators/sum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ class SumOp : public framework::OperatorWithKernel {

protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
auto x_dims = ctx->GetInputsDim("X");
PADDLE_ENFORCE(!x_dims.empty(), "Input(X) of SumOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SumOp should not be null.");

auto in_dim = x_dims[0];
size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");

auto in_dim = x_dims[0];
for (size_t i = 1; i < N; i++) {
auto dim = x_dims[i];
PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
Expand Down
29 changes: 15 additions & 14 deletions python/paddle/v2/framework/tests/test_recurrent_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ class PySimpleRNN(object):
'''

def __init__(self, input_dim=30, batch_size=50, weight_dim=15, sent_len=11):
self.x = np.random.normal(size=(sent_len, batch_size, input_dim))
self.W = np.random.normal(size=(input_dim, input_dim))
self.U = np.random.normal(size=(input_dim, input_dim))
self.h_boot = np.random.normal(size=(batch_size, input_dim))
self.x = np.random.normal(size=(sent_len, batch_size,
input_dim)).astype("float32")
self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32")
self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32")
self.h_boot = np.random.normal(size=(batch_size,
input_dim)).astype("float32")

# memories
self.mems = [
np.zeros(shape=(batch_size, input_dim)) for i in range(sent_len)
np.zeros(shape=(batch_size, input_dim)).astype("float32")
for i in range(sent_len)
]

def forward(self):
Expand All @@ -36,7 +39,7 @@ def segment_inputs(self):
return [self.x[i] for i in range(self.x.shape[0])]

def concat_outputs(self):
return np.array(self.mems)
return np.array(self.mems).astype("float32")

def step(self, step_id, x):
'''
Expand All @@ -47,8 +50,8 @@ def step(self, step_id, x):
pre_mem = self.mems[step_id - 1]
else:
pre_mem = self.h_boot
xW = np.matmul(x, self.W)
hU = np.matmul(pre_mem, self.U)
xW = np.matmul(x, self.W).astype("float32")
hU = np.matmul(pre_mem, self.U).astype("float32")

sum = xW + hU
self.mems[step_id] = py_sigmoid(sum)
Expand Down Expand Up @@ -102,7 +105,8 @@ def forward(self):
self.create_step_net()
ctx = core.DeviceContext.create(core.CPUPlace())
self.rnnop.run(self.scope, ctx)
return np.array(self.scope.find_var("h@mem").get_tensor())
return np.array(self.scope.find_var("h@mem").get_tensor()).astype(
"float32")

def create_global_variables(self):
# create inlink
Expand Down Expand Up @@ -142,7 +146,7 @@ def create_step_net(self):
stepnet = core.Net.create()
x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("add", X="Wx", Y="Uh", Out="sum")
sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
sig_op = Operator("sigmoid", X="sum", Y="h@mem")

for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
Expand Down Expand Up @@ -179,7 +183,7 @@ def create_forward_op(self):
stepnet = core.Net.create()
x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("add", X="Wx", Y="Uh", Out="sum")
sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
sig_op = Operator("sigmoid", X="sum", Y="h@alias")

for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
Expand All @@ -197,7 +201,4 @@ def test_grad(self):


if __name__ == '__main__':
exit(
0
) # FIXME(yuyang18): InferShape has been removed, this unittest may error
unittest.main()