-
Notifications
You must be signed in to change notification settings - Fork 5.9k
refactor rnn infershape #4553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor rnn infershape #4553
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,21 +30,23 @@ using LoDTensor = framework::LoDTensor; | |
|
|
||
| void RecurrentAlgorithm::Run(const Scope& scope, | ||
| const platform::DeviceContext& dev_ctx) const { | ||
| auto step_scopes = GetStepScopes(scope); | ||
| rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, | ||
| false /*infer_shape_mode*/); | ||
| InitMemories(step_scopes[0], false /*infer_shape_mode*/); | ||
|
|
||
| for (size_t step_id = 0; step_id < seq_len_; step_id++) { | ||
| // create output alias variables | ||
| if (step_id > 0) { | ||
| rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, | ||
| false /*infer_shape_mode*/); | ||
| auto* input0 = scope.FindVar(arg_->inlinks[0]); | ||
| PADDLE_ENFORCE_NOT_NULL(input0); | ||
| seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0]; | ||
| PADDLE_ENFORCE_GT(seq_len_, 0); | ||
|
|
||
| CreateScopes(scope); | ||
| auto& step_scopes = GetStepScopes(scope); | ||
| rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); | ||
| InitMemories(step_scopes[0]); | ||
|
|
||
| for (size_t i = 0; i < seq_len_; i++) { | ||
| if (i > 0) { | ||
| rnn::LinkMemories(step_scopes, arg_->memories, i, -1); | ||
| } | ||
| (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); | ||
| (*stepnet_)->Run(*step_scopes[i], dev_ctx); | ||
| } | ||
| rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, | ||
| false /*infer_shape_mode*/); | ||
| rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); | ||
| } | ||
|
|
||
| void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { | ||
|
|
@@ -82,21 +84,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { | |
| } | ||
| } | ||
|
|
||
| void RecurrentAlgorithm::InitMemories(Scope* step_scope, | ||
| bool infer_shape_mode) const { | ||
| void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { | ||
| for (auto& attr : arg_->memories) { | ||
| auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<LoDTensor>(); | ||
| PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, | ||
| "memory [%s]'s boot variable [%s] not exists", attr.var, | ||
| attr.boot_var); | ||
| auto* boot_mem = | ||
| step_scope->FindVar(attr.boot_var)->GetMutable<LoDTensor>(); | ||
| if (infer_shape_mode) { | ||
| pre_mem->Resize(boot_mem->dims()); | ||
| PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2); | ||
| } else { | ||
| pre_mem->ShareDataWith<float>(*boot_mem); | ||
| } | ||
| pre_mem->Resize(boot_mem->dims()); | ||
| PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2); | ||
| pre_mem->ShareDataWith<float>(*boot_mem); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -146,23 +144,22 @@ class RecurrentAlgorithmProtoAndCheckerMaker | |
|
|
||
| void RecurrentGradientAlgorithm::Run( | ||
| const Scope& scope, const platform::DeviceContext& dev_ctx) const { | ||
| seq_len_ = | ||
| scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0]; | ||
|
||
| auto step_scopes = GetStepScopes(scope); | ||
| rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, | ||
| false /*infer_shape_mode*/); | ||
| rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); | ||
| for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { | ||
| if (static_cast<size_t>(step_id) != seq_len_ - 1) { | ||
| rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, | ||
| false /*infer_shape_mode*/); | ||
| rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); | ||
| } | ||
| (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); | ||
| } | ||
| LinkBootMemoryGradients(step_scopes[0], false); | ||
| rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, | ||
| false /*infer_shape_mode*/); | ||
| rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); | ||
| LinkBootMemoryGradients(step_scopes[0]); | ||
| } | ||
|
|
||
| void RecurrentGradientAlgorithm::LinkBootMemoryGradients( | ||
| Scope* step_scope, bool infer_shape_mode) const { | ||
| Scope* step_scope) const { | ||
| for (auto& attr : arg_->memories) { | ||
| PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr, | ||
| "memory variable [%s] does not exists", attr.var); | ||
|
|
@@ -171,11 +168,8 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients( | |
| auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable<LoDTensor>(); | ||
| auto* boot_mem_grad = | ||
| step_scope->NewVar(attr.boot_var)->GetMutable<LoDTensor>(); | ||
| if (infer_shape_mode) { | ||
| boot_mem_grad->Resize(mem_grad->dims()); | ||
| } else { | ||
| boot_mem_grad->ShareDataWith<float>(*mem_grad); | ||
| } | ||
| boot_mem_grad->Resize(mem_grad->dims()); | ||
| boot_mem_grad->ShareDataWith<float>(*mem_grad); | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,7 @@ using LoDTensor = framework::LoDTensor; | |
|
|
||
| void SegmentInputs(const std::vector<Scope*>& step_scopes, | ||
| const std::vector<std::string>& inlinks, | ||
| const size_t seq_len, bool infer_shape_mode) { | ||
| const size_t seq_len) { | ||
| PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); | ||
| for (size_t i = 0; i < inlinks.size(); ++i) { | ||
| // global inputs | ||
|
|
@@ -41,51 +41,45 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes, | |
| for (size_t j = 0; j < seq_len; j++) { | ||
| Tensor* step_input = | ||
| step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>(); | ||
| if (!infer_shape_mode) { | ||
| // The input of operators of each step is Tensor here. | ||
| // Maybe need to modify Slice function. | ||
| *step_input = input->Slice<float>(j, j + 1); | ||
| } | ||
| // The input of operators of each step is Tensor here. | ||
| // Maybe need to modify Slice function. | ||
| *step_input = input->Slice<float>(j, j + 1); | ||
| step_input->Resize(step_dims); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void ConcatOutputs(const std::vector<Scope*>& step_scopes, | ||
| const std::vector<std::string>& outlinks, | ||
| const size_t seq_len, bool infer_shape_mode) { | ||
| const size_t seq_len) { | ||
| for (size_t i = 0; i < outlinks.size(); i++) { | ||
| auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]); | ||
| PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.", | ||
| outlinks[i]); | ||
| LoDTensor* output = output_var->GetMutable<LoDTensor>(); | ||
|
|
||
| if (infer_shape_mode) { | ||
| auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]); | ||
| PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]); | ||
| f::DDim step_dims = | ||
| step_scope_var->template GetMutable<LoDTensor>()->dims(); | ||
| std::vector<int64_t> dims_vec = vectorize(step_dims); | ||
| dims_vec.insert(dims_vec.begin(), seq_len); | ||
| output->Resize(f::make_ddim(dims_vec)); | ||
| } else { | ||
| output->mutable_data<float>(platform::CPUPlace()); | ||
| for (size_t j = 0; j < seq_len; j++) { | ||
| LoDTensor* step_output = | ||
| step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>(); | ||
| // TODO(luotao02) data type and platform::DeviceContext() should set | ||
| // correctly | ||
| (output->Slice<float>(j, j + 1)) | ||
| .CopyFrom<float>(*step_output, platform::CPUPlace()); | ||
| } | ||
| auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]); | ||
|
||
| PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]); | ||
| f::DDim step_dims = | ||
| step_scope_var->template GetMutable<LoDTensor>()->dims(); | ||
| std::vector<int64_t> dims_vec = vectorize(step_dims); | ||
| dims_vec.insert(dims_vec.begin(), seq_len); | ||
| output->Resize(f::make_ddim(dims_vec)); | ||
| output->mutable_data<float>(platform::CPUPlace()); | ||
| for (size_t j = 0; j < seq_len; j++) { | ||
| LoDTensor* step_output = | ||
| step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>(); | ||
| // TODO(luotao02) data type and platform::DeviceContext() should set | ||
| // correctly | ||
| (output->Slice<float>(j, j + 1)) | ||
| .CopyFrom<float>(*step_output, platform::CPUPlace()); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void LinkMemories(const std::vector<Scope*>& scopes, | ||
| const std::vector<rnn::MemoryAttr>& memories, | ||
| const size_t step_id, const int offset, | ||
| bool infer_shape_mode) { | ||
| const size_t step_id, const int offset) { | ||
| PADDLE_ENFORCE_LT(step_id, scopes.size(), | ||
| "step [%d] is out of range of step scopes' size [%d]", | ||
| step_id, scopes.size()); | ||
|
|
@@ -100,11 +94,8 @@ void LinkMemories(const std::vector<Scope*>& scopes, | |
| for (auto& attr : memories) { | ||
| auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>(); | ||
| auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>(); | ||
| if (infer_shape_mode) { | ||
| mem->Resize(linked_mem->dims()); | ||
| } else { | ||
| mem->ShareDataWith<float>(*linked_mem); | ||
| } | ||
| mem->Resize(linked_mem->dims()); | ||
| mem->ShareDataWith<float>(*linked_mem); | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i -> step_id
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done