-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Cond_op with dynamic if-else checked-in #4088
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
7683e35
6ab5580
c3d684e
adfef24
f345b51
2c8e795
69fb975
d8921e9
aa90ef9
b8e75c1
c7db6e8
f6dee08
b2d9c91
2c8cbb8
299dcb6
39d79e6
35cc956
c557402
98c3572
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,218 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/operators/cond_op.h" | ||
|
|
||
| #include <cstring> | ||
| #include <sstream> | ||
|
|
||
| #include "paddle/framework/op_registry.h" | ||
| #include "paddle/operators/gather.h" | ||
| #include "paddle/operators/net_op.h" | ||
| #include "paddle/operators/scatter.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using Scope = framework::Scope; | ||
| using Variable = framework::Variable; | ||
| using Tensor = framework::Tensor; | ||
| using LoDTensor = framework::LoDTensor; | ||
| using DDim = framework::DDim; | ||
|
|
||
| void CondOp::CreateScope(const Scope& scope) const { | ||
| auto sub_scopes_var = scope.FindVar("SubScopes"); | ||
| PADDLE_ENFORCE(sub_scopes_var != nullptr, ""); | ||
| auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>(); | ||
| auto& sub_scope = scope.NewScope(); | ||
| sub_scopes->push_back(&sub_scope); | ||
| } | ||
|
|
||
| void CondOp::CreateIndexTensor(const Scope& scope) const { | ||
| auto index_tensors_var = scope.FindVar("IndexTensors"); | ||
| PADDLE_ENFORCE(index_tensors_var != nullptr, ""); | ||
| auto& index_tensors = | ||
| *index_tensors_var->GetMutable<std::vector<LoDTensor>>(); | ||
| index_tensors.push_back(LoDTensor()); | ||
| } | ||
|
|
||
| void CondOp::InferShape(const Scope& scope) const { | ||
| auto sub_scopes_var = scope.FindVar("SubScopes"); | ||
| PADDLE_ENFORCE_NOT_NULL(sub_scopes_var); | ||
| auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>(); | ||
|
|
||
| for (int i = 0; i < 2; ++i) { | ||
| // Create two sub scopes for true and false branches | ||
| // sub_scopes[0] for the true branch and sub_scopes[1] for the false | ||
| // branch | ||
| CreateScope(scope); | ||
|
|
||
| // Create two tensors for true and false indices | ||
| // index_tensors[0] for the true branch and index_tensors[1] for the false | ||
| // branch | ||
| CreateIndexTensor(scope); | ||
|
|
||
| PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty"); | ||
| for (auto& input : Inputs("Xs")) { | ||
| // Create a new tensor in sub-scope for input-type tensor | ||
| Variable* v = sub_scopes[i]->NewVar(input); | ||
| LoDTensor* sub_input = v->GetMutable<LoDTensor>(); | ||
| sub_input->Resize(scope.FindVar(input)->GetMutable<LoDTensor>()->dims()); | ||
| } | ||
|
|
||
| for (auto& output : (*sub_net_op_[i]).Outputs()) { | ||
| for (auto& var_name : output.second) { | ||
| sub_scopes[i]->NewVar(var_name); | ||
| } | ||
| } | ||
|
|
||
| // each net calls InferShape | ||
| sub_net_op_[i]->InferShape(*sub_scopes[i]); | ||
| } | ||
|
|
||
| for (auto& output : Outputs("Outs")) { | ||
| LoDTensor* tensor_t_out = | ||
| sub_scopes[0]->FindVar(output)->GetMutable<LoDTensor>(); | ||
| PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL"); | ||
| LoDTensor* tensor_f_out = | ||
| sub_scopes[1]->FindVar(output)->GetMutable<LoDTensor>(); | ||
| PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL"); | ||
|
|
||
| auto* tensor_out_var = scope.FindVar(output); | ||
| PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found"); | ||
| LoDTensor* tensor_out = tensor_out_var->GetMutable<LoDTensor>(); | ||
| PADDLE_ENFORCE_NOT_NULL(tensor_t_out, | ||
| "True output tensor should not be NULL"); | ||
|
|
||
| // check output size should be same | ||
| PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(), | ||
| "Outputs not of the same shape"); | ||
| tensor_out->Resize(tensor_t_out->dims()); | ||
| // tensor_out->mutable_data<float>(tensor_out->dims(), | ||
| // platform::CPUPlace()); | ||
| tensor_out->mutable_data<float>(platform::CPUPlace()); | ||
| } | ||
| } | ||
|
|
||
| void CondOp::Run(const Scope& scope, | ||
| const platform::DeviceContext& dev_ctx) const { | ||
| auto* sub_scopes_var = scope.FindVar("SubScopes"); | ||
| auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>(); | ||
| auto* index_tensors_var = scope.FindVar("IndexTensors"); | ||
| auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>(); | ||
|
|
||
| std::string cond_name = Input("Cond"); | ||
| Variable* cond_var = scope.FindVar(cond_name); | ||
| PADDLE_ENFORCE_NOT_NULL(cond_var); | ||
| const LoDTensor* cond = cond_var->GetMutable<LoDTensor>(); | ||
|
|
||
| // Step 1: get the true/false index at runtime | ||
| // index_[0]: vector<int>, contains all index for cond[i] == true | ||
| // index_[1]: vector<int>, contains all index for cond[i] == false | ||
| for (int i = 0; i < 2; ++i) index_[i].clear(); | ||
|
|
||
| const int* cond_data = cond->data<int>(); | ||
| for (int i = 0; i < cond->dims()[0]; ++i) { | ||
| if (cond_data[i]) | ||
| index_[0].push_back(i); | ||
| else | ||
| index_[1].push_back(i); | ||
| } | ||
|
|
||
| // put index_[0] and index_[1] into two tensors: | ||
| // index_tensor_[0] and index_tensor_[1] | ||
| DDim dim = paddle::framework::make_ddim({0}); | ||
| for (int i = 0; i < 2; ++i) { | ||
| dim[0] = index_[i].size(); | ||
| int* tmp_ptr = | ||
| index_tensors[i].mutable_data<int>(dim, platform::CPUPlace()); | ||
| index_tensors[i].Resize(dim); | ||
| memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int)); | ||
| } | ||
|
|
||
| // Step 2: collect data by calling gather | ||
| for (int i = 0; i < 2; ++i) { | ||
| // i= 0/i for True and False branches respectively | ||
| for (auto& input : Inputs("Xs")) { | ||
| // find Tensor | ||
| Variable* v = scope.FindVar(input); | ||
| PADDLE_ENFORCE_NOT_NULL(v); | ||
| LoDTensor* tensor_parent = v->GetMutable<LoDTensor>(); | ||
|
|
||
| v = sub_scopes[i]->FindVar(input); | ||
| PADDLE_ENFORCE_NOT_NULL(v); | ||
| LoDTensor* tensor_child = v->GetMutable<LoDTensor>(); | ||
|
|
||
| // Resize child | ||
| DDim dim = tensor_child->dims(); | ||
| dim[0] = index_[i].size(); | ||
| tensor_child->Resize(dim); | ||
| tensor_child->mutable_data<float>(dim, platform::CPUPlace()); | ||
|
|
||
| Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], | ||
| tensor_child); | ||
| } | ||
| } | ||
|
|
||
| // Step 3: run | ||
| for (int i = 0; i < 2; ++i) { | ||
| sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx); | ||
| } | ||
|
|
||
| // Step 4: merge output results | ||
| for (int i = 0; i < 2; ++i) { | ||
| // i= 0/i for True and False branches respectively | ||
| for (auto& output : Outputs("Outs")) { | ||
| // find Tensor | ||
| Variable* v = scope.FindVar(output); | ||
| PADDLE_ENFORCE_NOT_NULL(v); | ||
| LoDTensor* tensor_parent = v->GetMutable<LoDTensor>(); | ||
|
|
||
| v = sub_scopes[i]->FindVar(output); | ||
| PADDLE_ENFORCE_NOT_NULL(v); | ||
| LoDTensor* tensor_child = v->GetMutable<LoDTensor>(); | ||
|
|
||
| ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], | ||
| tensor_parent); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| CondOpProtoAndCheckerMaker(framework::OpProto* proto, | ||
| framework::OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput("Cond", "The condition, which is a bool vector"); | ||
| AddInput("Xs", "Inputs of Subnets").AsDuplicable(); | ||
| AddOutput("Outs", "Outputs of Cond_Op after merge").AsDuplicable(); | ||
|
|
||
| AddOutput("SubScopes", "sub scopes for true and false branches"); | ||
| AddOutput("IndexTensors", "Index Tensors contains indices for true/false"); | ||
|
|
||
| AddComment(R"DOC( | ||
| Sample dependent Cond Operator: | ||
| Given Cond[i] as a 1/0 vector to indicate true/false | ||
| The equation is: | ||
| Out[i] = subnet_t[i], if Cond[i] == true | ||
| Out[i] = subnet_t[i], if Cond[i] == false | ||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp, | ||
| paddle::operators::CondOpProtoAndCheckerMaker); | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
| #include <vector> | ||
| #include "glog/logging.h" | ||
| #include "paddle/framework/ddim.h" | ||
| #include "paddle/framework/eigen.h" | ||
| #include "paddle/framework/operator.h" | ||
| #include "paddle/framework/tensor.h" | ||
| #include "paddle/operators/net_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some doc here. |
||
| /* | ||
| * @brief CondOp is a dynamic if-else Operator | ||
| * | ||
| * It has a input tensor named cond indicating which netop each instance will | ||
| * run. | ||
| * | ||
| * if cond == 1, it will run true_net, which is a NetOp. | ||
| * | ||
| * if cond == 0, it will run false_net, which is another NetOp. | ||
| */ | ||
|
|
||
|
||
| class CondOp : public framework::OperatorBase { | ||
| public: | ||
| CondOp(const std::string& type, const framework::VariableNameMap& inputs, | ||
| const framework::VariableNameMap& outputs, | ||
| const framework::AttributeMap& attrs) | ||
| : OperatorBase(type, inputs, outputs, attrs) { | ||
| index_.resize(2); | ||
| sub_net_op_.resize(2); | ||
| } | ||
|
|
||
| CondOp(const CondOp& o) | ||
| : framework::OperatorBase( | ||
| static_cast<const framework::OperatorBase&>(o)) { | ||
| // TODO(yuyang18): Implement copy ctor well. | ||
| PADDLE_THROW("Not implemented"); | ||
| } | ||
|
|
||
| void CreateScope(const framework::Scope& scope) const; | ||
|
|
||
| void CreateIndexTensor(const framework::Scope& scope) const; | ||
|
|
||
| /* | ||
| * InferShape must be called before Run. | ||
| */ | ||
| void InferShape(const framework::Scope& scope) const override; | ||
|
|
||
| // Set True Block | ||
|
||
| void set_truenet(std::unique_ptr<OperatorBase>&& net) { | ||
| sub_net_op_[0] = std::move(net); | ||
| } | ||
|
|
||
| // Set False Block | ||
| void set_falsenet(std::unique_ptr<OperatorBase>&& net) { | ||
| sub_net_op_[1] = std::move(net); | ||
| } | ||
|
|
||
| void Run(const framework::Scope& scope, | ||
| const platform::DeviceContext& dev_ctx) const override; | ||
|
|
||
| private: | ||
| // sub_net_op_[0]: subnet_t | ||
| // sub_net_op_[1]: subnet_f | ||
| std::vector<std::unique_ptr<framework::OperatorBase>> sub_net_op_; | ||
|
|
||
| // index_[0]: True_index; | ||
| // index_[1]: False_index; | ||
| mutable std::vector<std::vector<int>> index_; | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The recurrent operator also uses explicit type float in the implementation.
CondOphas the same problem. Maybe theRecurrentOpandCondOpshould like:Then specialized the class.