-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Cond_op with dynamic if-else checked-in #4088
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
7683e35
cond op
zchen0211 6ab5580
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 c3d684e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 adfef24
tensor element size support
zchen0211 f345b51
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 2c8e795
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 69fb975
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 d8921e9
Fix CI test
reyoung aa90ef9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 b8e75c1
cond op
zchen0211 c7db6e8
cond op passed
zchen0211 f6dee08
new changes
zchen0211 b2d9c91
Merge pull request #1 from reyoung/czy_elemwise
zchen0211 2c8cbb8
if_else_op.md
zchen0211 299dcb6
merge with new change
zchen0211 39d79e6
modified codes
zchen0211 35cc956
merging tensor modify
zchen0211 c557402
cond_op modify
zchen0211 98c3572
remove empty line
zchen0211 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,218 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/operators/cond_op.h" | ||
|
|
||
| #include <cstring> | ||
| #include <sstream> | ||
|
|
||
| #include "paddle/framework/op_registry.h" | ||
| #include "paddle/operators/gather.h" | ||
| #include "paddle/operators/net_op.h" | ||
| #include "paddle/operators/scatter.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using Scope = framework::Scope; | ||
| using Variable = framework::Variable; | ||
| using Tensor = framework::Tensor; | ||
| using LoDTensor = framework::LoDTensor; | ||
| using DDim = framework::DDim; | ||
|
|
||
| void CondOp::CreateScope(const Scope& scope) const { | ||
| auto sub_scopes_var = scope.FindVar("SubScopes"); | ||
| PADDLE_ENFORCE(sub_scopes_var != nullptr, ""); | ||
| auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>(); | ||
| auto& sub_scope = scope.NewScope(); | ||
| sub_scopes->push_back(&sub_scope); | ||
| } | ||
|
|
||
| void CondOp::CreateIndexTensor(const Scope& scope) const { | ||
| auto index_tensors_var = scope.FindVar("IndexTensors"); | ||
| PADDLE_ENFORCE(index_tensors_var != nullptr, ""); | ||
| auto& index_tensors = | ||
| *index_tensors_var->GetMutable<std::vector<LoDTensor>>(); | ||
| index_tensors.push_back(LoDTensor()); | ||
| } | ||
|
|
||
| void CondOp::InferShape(const Scope& scope) const { | ||
| auto sub_scopes_var = scope.FindVar("SubScopes"); | ||
| PADDLE_ENFORCE_NOT_NULL(sub_scopes_var); | ||
| auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>(); | ||
|
|
||
| for (int i = 0; i < 2; ++i) { | ||
| // Create two sub scopes for true and false branches | ||
| // sub_scopes[0] for the true branch and sub_scopes[1] for the false | ||
| // branch | ||
| CreateScope(scope); | ||
|
|
||
| // Create two tensors for true and false indices | ||
| // index_tensors[0] for the true branch and index_tensors[1] for the false | ||
| // branch | ||
| CreateIndexTensor(scope); | ||
|
|
||
| PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty"); | ||
| for (auto& input : Inputs("Xs")) { | ||
| // Create a new tensor in sub-scope for input-type tensor | ||
| Variable* v = sub_scopes[i]->NewVar(input); | ||
| LoDTensor* sub_input = v->GetMutable<LoDTensor>(); | ||
| sub_input->Resize(scope.FindVar(input)->GetMutable<LoDTensor>()->dims()); | ||
| } | ||
|
|
||
| for (auto& output : (*sub_net_op_[i]).Outputs()) { | ||
| for (auto& var_name : output.second) { | ||
| sub_scopes[i]->NewVar(var_name); | ||
| } | ||
| } | ||
|
|
||
| // each net calls InferShape | ||
| sub_net_op_[i]->InferShape(*sub_scopes[i]); | ||
| } | ||
|
|
||
| for (auto& output : Outputs("Outs")) { | ||
| LoDTensor* tensor_t_out = | ||
| sub_scopes[0]->FindVar(output)->GetMutable<LoDTensor>(); | ||
| PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL"); | ||
| LoDTensor* tensor_f_out = | ||
| sub_scopes[1]->FindVar(output)->GetMutable<LoDTensor>(); | ||
| PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL"); | ||
|
|
||
| auto* tensor_out_var = scope.FindVar(output); | ||
| PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found"); | ||
| LoDTensor* tensor_out = tensor_out_var->GetMutable<LoDTensor>(); | ||
| PADDLE_ENFORCE_NOT_NULL(tensor_t_out, | ||
| "True output tensor should not be NULL"); | ||
|
|
||
| // check output size should be same | ||
| PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(), | ||
| "Outputs not of the same shape"); | ||
| tensor_out->Resize(tensor_t_out->dims()); | ||
| // tensor_out->mutable_data<float>(tensor_out->dims(), | ||
| // platform::CPUPlace()); | ||
| tensor_out->mutable_data<float>(platform::CPUPlace()); | ||
| } | ||
| } | ||
|
|
||
| void CondOp::Run(const Scope& scope, | ||
| const platform::DeviceContext& dev_ctx) const { | ||
| auto* sub_scopes_var = scope.FindVar("SubScopes"); | ||
| auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>(); | ||
| auto* index_tensors_var = scope.FindVar("IndexTensors"); | ||
| auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>(); | ||
|
|
||
| std::string cond_name = Input("Cond"); | ||
| Variable* cond_var = scope.FindVar(cond_name); | ||
| PADDLE_ENFORCE_NOT_NULL(cond_var); | ||
| const LoDTensor* cond = cond_var->GetMutable<LoDTensor>(); | ||
|
|
||
| // Step 1: get the true/false index at runtime | ||
| // index_[0]: vector<int>, contains all index for cond[i] == true | ||
| // index_[1]: vector<int>, contains all index for cond[i] == false | ||
| for (int i = 0; i < 2; ++i) index_[i].clear(); | ||
|
|
||
| const int* cond_data = cond->data<int>(); | ||
| for (int i = 0; i < cond->dims()[0]; ++i) { | ||
| if (cond_data[i]) | ||
| index_[0].push_back(i); | ||
| else | ||
| index_[1].push_back(i); | ||
| } | ||
|
|
||
| // put index_[0] and index_[1] into two tensors: | ||
| // index_tensor_[0] and index_tensor_[1] | ||
| DDim dim = paddle::framework::make_ddim({0}); | ||
| for (int i = 0; i < 2; ++i) { | ||
| dim[0] = index_[i].size(); | ||
| int* tmp_ptr = | ||
| index_tensors[i].mutable_data<int>(dim, platform::CPUPlace()); | ||
| index_tensors[i].Resize(dim); | ||
| memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int)); | ||
| } | ||
|
|
||
| // Step 2: collect data by calling gather | ||
| for (int i = 0; i < 2; ++i) { | ||
| // i= 0/i for True and False branches respectively | ||
| for (auto& input : Inputs("Xs")) { | ||
| // find Tensor | ||
| Variable* v = scope.FindVar(input); | ||
| PADDLE_ENFORCE_NOT_NULL(v); | ||
| LoDTensor* tensor_parent = v->GetMutable<LoDTensor>(); | ||
|
|
||
| v = sub_scopes[i]->FindVar(input); | ||
| PADDLE_ENFORCE_NOT_NULL(v); | ||
| LoDTensor* tensor_child = v->GetMutable<LoDTensor>(); | ||
|
|
||
| // Resize child | ||
| DDim dim = tensor_child->dims(); | ||
| dim[0] = index_[i].size(); | ||
| tensor_child->Resize(dim); | ||
| tensor_child->mutable_data<float>(dim, platform::CPUPlace()); | ||
|
|
||
| Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], | ||
| tensor_child); | ||
| } | ||
| } | ||
|
|
||
| // Step 3: run | ||
| for (int i = 0; i < 2; ++i) { | ||
| sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx); | ||
| } | ||
|
|
||
| // Step 4: merge output results | ||
| for (int i = 0; i < 2; ++i) { | ||
| // i= 0/i for True and False branches respectively | ||
| for (auto& output : Outputs("Outs")) { | ||
| // find Tensor | ||
| Variable* v = scope.FindVar(output); | ||
| PADDLE_ENFORCE_NOT_NULL(v); | ||
| LoDTensor* tensor_parent = v->GetMutable<LoDTensor>(); | ||
|
|
||
| v = sub_scopes[i]->FindVar(output); | ||
| PADDLE_ENFORCE_NOT_NULL(v); | ||
| LoDTensor* tensor_child = v->GetMutable<LoDTensor>(); | ||
|
|
||
| ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], | ||
| tensor_parent); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| CondOpProtoAndCheckerMaker(framework::OpProto* proto, | ||
| framework::OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput("Cond", "The condition, which is a bool vector"); | ||
| AddInput("Xs", "Inputs of Subnets").AsDuplicable(); | ||
| AddOutput("Outs", "Outputs of Cond_Op after merge").AsDuplicable(); | ||
|
|
||
| AddOutput("SubScopes", "sub scopes for true and false branches"); | ||
| AddOutput("IndexTensors", "Index Tensors contains indices for true/false"); | ||
|
|
||
| AddComment(R"DOC( | ||
| Sample dependent Cond Operator: | ||
| Given Cond[i] as a 1/0 vector to indicate true/false | ||
| The equation is: | ||
| Out[i] = subnet_t[i], if Cond[i] == true | ||
| Out[i] = subnet_t[i], if Cond[i] == false | ||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
// Register cond; no gradient operator is provided for it here.
REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp,
                             paddle::operators::CondOpProtoAndCheckerMaker);
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
| #include <vector> | ||
| #include "glog/logging.h" | ||
| #include "paddle/framework/ddim.h" | ||
| #include "paddle/framework/eigen.h" | ||
| #include "paddle/framework/operator.h" | ||
| #include "paddle/framework/tensor.h" | ||
| #include "paddle/operators/net_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some doc here. |
||
| /* | ||
| * @brief CondOp is a dynamic if-else Operator | ||
| * | ||
| * It has a input tensor named cond indicating which netop each instance will | ||
| * run. | ||
| * | ||
| * if cond == 1, it will run true_net, which is a NetOp. | ||
| * | ||
| * if cond == 0, it will run false_net, which is another NetOp. | ||
| */ | ||
class CondOp : public framework::OperatorBase {
 public:
  CondOp(const std::string& type, const framework::VariableNameMap& inputs,
         const framework::VariableNameMap& outputs,
         const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {
    // Slot 0 serves the true branch, slot 1 the false branch.
    index_.resize(2);
    sub_net_op_.resize(2);
  }

  CondOp(const CondOp& o)
      : framework::OperatorBase(
            static_cast<const framework::OperatorBase&>(o)) {
    // TODO(yuyang18): Implement copy ctor well.
    PADDLE_THROW("Not implemented");
  }

  // Creates one child scope under `scope` and records it in the
  // "SubScopes" variable.
  void CreateScope(const framework::Scope& scope) const;

  // Appends one empty index tensor to the "IndexTensors" variable.
  void CreateIndexTensor(const framework::Scope& scope) const;

  /*
   * InferShape must be called before Run.
   */
  void InferShape(const framework::Scope& scope) const override;

  /*
   * Set True Block
   */
  void set_truenet(std::unique_ptr<OperatorBase>&& net) {
    sub_net_op_[0] = std::move(net);
  }

  /*
   * Set False Block
   */
  void set_falsenet(std::unique_ptr<OperatorBase>&& net) {
    sub_net_op_[1] = std::move(net);
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override;

 private:
  // sub_net_op_[0]: subnet_t
  // sub_net_op_[1]: subnet_f
  std::vector<std::unique_ptr<framework::OperatorBase>> sub_net_op_;

  // index_[0]: True_index;
  // index_[1]: False_index;
  // mutable: Run() is const but rebuilds this split on every invocation.
  mutable std::vector<std::vector<int>> index_;
};
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The recurrent operator also uses explicit type float in the implementation.
CondOp has the same problem. Maybe the RecurrentOp and CondOp should be templated on the element type, and then the class specialized for each supported data type.