-
Notifications
You must be signed in to change notification settings - Fork 6k
Support memory eager deletion on recurrent OP #17710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
e51e7ba
8a6db0e
e9da08a
ba5ea07
90e1519
38ba259
dd9d940
10746cf
e98f566
8f7373d
ce52e80
446bc9f
eb8efa5
afb8c3e
cd1c2eb
8a54181
fc3dd8e
c4fb071
ba43599
bff0367
5e9d853
4c9537e
fb905a2
36df9b1
5dd6470
57a9fca
0b89ada
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h" | ||
|
|
||
| #include <unordered_map> | ||
| #include <vector> | ||
|
|
||
| #include "paddle/fluid/framework/details/computation_op_handle.h" | ||
| #include "paddle/fluid/framework/details/multi_devices_helper.h" | ||
| #include "paddle/fluid/framework/ir/graph_helper.h" | ||
| #include "paddle/fluid/string/string_helper.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace ir { | ||
|
|
||
| using paddle::operators::OpVariant; | ||
| using paddle::operators::OpVariantSet; | ||
| using paddle::operators::OpAndGradOpPair; | ||
|
|
||
| void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const { | ||
| // Find all recurrent_op and recurrent_grad_op in graph | ||
| // Note the graph only contains ops and block 0 | ||
| std::unordered_map<size_t, OpAndGradOpPair> target_ops = | ||
| DeviceIdToRecurrentAndRecurrentGradOp(*graph); | ||
|
|
||
| for (auto &entry : target_ops) { | ||
| // Prepare safe eager deletion on different devices because the garbage | ||
| // collection may be different across devices | ||
| OpAndGradOpPair &op_pair = entry.second; | ||
| PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair); | ||
| } | ||
| } | ||
|
|
||
| // Returns a std::unordered_map mapping from the device id to recurrent op and | ||
| // grad op pair | ||
| std::unordered_map<size_t, OpAndGradOpPair> | ||
| RecurrentOpEagerDeletionPass::DeviceIdToRecurrentAndRecurrentGradOp( | ||
| const Graph &graph) const { | ||
| std::unordered_map<size_t, OpAndGradOpPair> ret; | ||
| std::vector<details::OpHandleBase *> all_ops = | ||
| FilterByNodeWrapper<details::OpHandleBase>(graph); | ||
|
|
||
| for (auto *op : all_ops) { | ||
| auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op); | ||
| if (compute_op == nullptr) continue; | ||
|
|
||
| if (compute_op->Name() == "recurrent") { | ||
| // GetScopeIdx() returns device/place id | ||
| ret[compute_op->GetScopeIdx()].first.emplace(compute_op->GetOp()); | ||
| } else if (compute_op->Name() == "recurrent_grad") { | ||
| // GetScopeIdx() returns device/place id | ||
| ret[compute_op->GetScopeIdx()].second.emplace(compute_op->GetOp()); | ||
| } | ||
| } | ||
| return ret; | ||
| } | ||
|
|
||
| } // namespace ir | ||
| } // namespace framework | ||
| } // namespace paddle | ||
|
|
||
// Register the pass so it can be retrieved by name from the pass registry.
REGISTER_PASS(recurrent_op_eager_deletion_pass,
              paddle::framework::ir::RecurrentOpEagerDeletionPass);
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #pragma once | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Recommend to remove this header file.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reply to you offline. |
||
|
|
||
| #include <unordered_map> | ||
|
|
||
| #include "paddle/fluid/framework/details/computation_op_handle.h" | ||
| #include "paddle/fluid/framework/details/multi_devices_helper.h" | ||
| #include "paddle/fluid/framework/ir/graph_helper.h" | ||
| #include "paddle/fluid/operators/controlflow/op_variant.h" | ||
| #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace ir { | ||
|
|
||
// Pass that sets the skip-eager-deletion variable lists for recurrent ops,
// so variables still needed across time steps are not garbage-collected
// too early when eager deletion is enabled.
class RecurrentOpEagerDeletionPass : public Pass {
 protected:
  // Finds all recurrent / recurrent_grad ops in the graph and prepares
  // safe eager deletion for each device's op pair.
  void ApplyImpl(Graph *graph) const override;

 private:
  // Returns a std::unordered_map mapping from the device id to recurrent op and
  // grad op pair
  std::unordered_map<size_t, paddle::operators::OpAndGradOpPair>
  DeviceIdToRecurrentAndRecurrentGradOp(const Graph &graph) const;
};
|
|
||
| } // namespace ir | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
include(operators)
register_operators(DEPS naive_executor)
# op_variant wraps OpDesc / OperatorBase behind one API; both control-flow
# helper libraries (recurrent and while) build on top of it.
cc_library(op_variant SRCS op_variant.cc DEPS operator proto_desc)
cc_library(recurrent_op_helper SRCS recurrent_op_helper.cc DEPS operator op_variant)
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator op_variant)

file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "paddle/fluid/operators/controlflow/op_variant.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| struct InputsVisitor | ||
| : public boost::static_visitor<const framework::VariableNameMap *> { | ||
| template <typename OpType> | ||
| const framework::VariableNameMap *operator()(const OpType *op) const { | ||
| return &(op->Inputs()); | ||
| } | ||
| }; | ||
|
|
||
| struct OutputsVisitor | ||
| : public boost::static_visitor<const framework::VariableNameMap *> { | ||
| template <typename OpType> | ||
| const framework::VariableNameMap *operator()(const OpType *op) const { | ||
| return &(op->Outputs()); | ||
| } | ||
| }; | ||
|
|
||
| struct AttributeMapVisitor | ||
| : public boost::static_visitor<const framework::AttributeMap *> { | ||
| const framework::AttributeMap *operator()(const framework::OpDesc *op) const { | ||
| return &(op->GetAttrMap()); | ||
| } | ||
|
|
||
| const framework::AttributeMap *operator()( | ||
| const framework::OperatorBase *op) const { | ||
| return &(op->Attrs()); | ||
| } | ||
| }; | ||
|
|
||
| struct RawPointerVisitor : public boost::static_visitor<const void *> { | ||
| template <typename OpType> | ||
| const void *operator()(const OpType *op) const { | ||
| return op; | ||
| } | ||
| }; | ||
|
|
||
// All four accessors dispatch on the active alternative of the underlying
// boost::variant (OpDesc* or OperatorBase*) via boost::apply_visitor.

const framework::VariableNameMap &OpVariant::Inputs() const {
  return *boost::apply_visitor(InputsVisitor(), op_);
}

const framework::VariableNameMap &OpVariant::Outputs() const {
  return *boost::apply_visitor(OutputsVisitor(), op_);
}

const framework::AttributeMap &OpVariant::Attrs() const {
  return *boost::apply_visitor(AttributeMapVisitor(), op_);
}

// Identity pointer of the wrapped op; backs operator== and OpVariant::Hasher.
const void *OpVariant::RawPointer() const {
  return boost::apply_visitor(RawPointerVisitor(), op_);
}
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <string> | ||
|
|
||
| #include "paddle/fluid/framework/operator.h" | ||
| #include "paddle/fluid/framework/program_desc.h" | ||
| #include "paddle/fluid/platform/variant.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
// OpVariant is a wrapper class of OpDesc and OperatorBase pointer
// So that API would be the same.
class OpVariant {
 public:
  // Implicit conversions are intentional (NOLINT) so that either pointer
  // type can be passed wherever an OpVariant is expected.
  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT

  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT

  // Accessors dispatch to the wrapped OpDesc or OperatorBase.
  const framework::VariableNameMap &Inputs() const;

  const framework::VariableNameMap &Outputs() const;

  const framework::AttributeMap &Attrs() const;

  // Address of the wrapped object; serves as this variant's identity.
  const void *RawPointer() const;

  // Typed attribute lookup. Enforces that the attribute exists; boost::get
  // will throw on an attribute-type mismatch.
  template <typename AttrType>
  const AttrType &Attr(const std::string &name) const {
    auto &attrs = Attrs();
    auto it = attrs.find(name);
    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
    return boost::get<AttrType>(it->second);
  }

  // Two OpVariants compare equal iff they wrap the same underlying object.
  bool operator==(const OpVariant &other) const {
    return RawPointer() == other.RawPointer();
  }

  // Index of the active alternative in the underlying boost::variant.
  int which() const { return static_cast<int>(op_.which()); }

  // Hash functor for unordered containers keyed by OpVariant (hashes the
  // identity pointer, consistent with operator==).
  struct Hasher {
    size_t operator()(const OpVariant &op) const {
      return reinterpret_cast<size_t>(op.RawPointer());
    }
  };

 private:
  const boost::variant<const framework::OperatorBase *,
                       const framework::OpDesc *>
      op_;
};
|
|
||
| } // namespace operators | ||
| } // namespace paddle |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add dependency of
recurrent_op_eager_deletion_pass. What's more important, you should applyrecurrent_op_eager_deletion_passinsideeager_deletion_pass. See here. Please add unittests usingParallelExecutorto verify it.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.