Skip to content

Commit 4c7a9a4

Browse files
authored
Merge pull request #4088 from zchen0211/develop
Cond_op with dynamic if-else checked-in
2 parents 13d0005 + 98c3572 commit 4c7a9a4

9 files changed

Lines changed: 479 additions & 25 deletions

File tree

doc/design/if_else_op.md

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,4 @@
1-
IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponds to a true element in `cond`.
2-
3-
```python
4-
import paddle as pd
5-
6-
x = var()
7-
y = var()
8-
cond = var()
9-
10-
b = pd.create_ifop(inputs=[x], output_num=1)
11-
with b.true_block():
12-
x = b.inputs(0)
13-
z = operator.add(x, y)
14-
b.set_output(0, operator.softmax(z))
15-
16-
out = b(cond)
17-
```
18-
19-
If we want the output still has N instances, we can use IfElseOp with a default value, whose minibatch size must be N:
1+
IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has N instances. If cond[i] == True, input instance input[i] will go through true_block() and generate output[i]; otherwise it will produce output from false_block().
202

213
```python
224
import paddle as pd
@@ -39,7 +21,7 @@ with b.false_block():
3921
out = b(cond)
4022
```
4123

42-
If only true_block is set in an IfElseOp, we can have a default value for false as:
24+
As a special case, if only true_block is set in an IfElseOp, we can provide a default value for the false branch as follows:
4325
```python
4426
import paddle as pd
4527

paddle/framework/tensor_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace framework {
2222
template <typename T>
2323
inline void Tensor::check_memory_size() const {
2424
PADDLE_ENFORCE_NOT_NULL(
25-
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
25+
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
2626
PADDLE_ENFORCE_GE(
2727
holder_->size(), numel() * sizeof(T) + offset_,
2828
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "

paddle/framework/tensor_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ TEST(Tensor, DataAssert) {
3636
} catch (paddle::platform::EnforceNotMet err) {
3737
caught = true;
3838
std::string msg =
39-
"holder_ should not be null\nTenosr holds no memory. Call "
39+
"holder_ should not be null\nTensor holds no memory. Call "
4040
"Tensor::mutable_data first.";
4141
const char* what = err.what();
4242
for (size_t i = 0; i < msg.length(); ++i) {
@@ -112,7 +112,7 @@ TEST(Tensor, ShareDataWith) {
112112
} catch (paddle::platform::EnforceNotMet err) {
113113
caught = true;
114114
std::string msg =
115-
"holder_ should not be null\nTenosr holds no memory. Call "
115+
"holder_ should not be null\nTensor holds no memory. Call "
116116
"Tensor::mutable_data first.";
117117
const char* what = err.what();
118118
for (size_t i = 0; i < msg.length(); ++i) {
@@ -274,4 +274,4 @@ TEST(Tensor, ReshapeToMatrix) {
274274
Tensor res = ReshapeToMatrix<int>(src, 2);
275275
ASSERT_EQ(res.dims()[0], 2 * 3);
276276
ASSERT_EQ(res.dims()[1], 4 * 9);
277-
}
277+
}

paddle/operators/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,11 @@ endfunction()
8080
add_subdirectory(math)
8181

8282
set(DEPS_OPS
83-
recurrent_op)
83+
recurrent_op
84+
cond_op)
8485
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
8586
DEPS framework_proto tensor net_op)
87+
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
8688

8789
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
8890
foreach(src ${GENERAL_OPS})

paddle/operators/cond_op.cc

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/cond_op.h"
16+
17+
#include <cstring>
18+
#include <sstream>
19+
20+
#include "paddle/framework/op_registry.h"
21+
#include "paddle/operators/gather.h"
22+
#include "paddle/operators/net_op.h"
23+
#include "paddle/operators/scatter.h"
24+
25+
namespace paddle {
26+
namespace operators {
27+
28+
// Short local names for the framework types used throughout this file.
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
33+
34+
void CondOp::CreateScope(const Scope& scope) const {
35+
auto sub_scopes_var = scope.FindVar("SubScopes");
36+
PADDLE_ENFORCE(sub_scopes_var != nullptr, "");
37+
auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>();
38+
auto& sub_scope = scope.NewScope();
39+
sub_scopes->push_back(&sub_scope);
40+
}
41+
42+
void CondOp::CreateIndexTensor(const Scope& scope) const {
43+
auto index_tensors_var = scope.FindVar("IndexTensors");
44+
PADDLE_ENFORCE(index_tensors_var != nullptr, "");
45+
auto& index_tensors =
46+
*index_tensors_var->GetMutable<std::vector<LoDTensor>>();
47+
index_tensors.push_back(LoDTensor());
48+
}
49+
50+
void CondOp::InferShape(const Scope& scope) const {
51+
auto sub_scopes_var = scope.FindVar("SubScopes");
52+
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var);
53+
auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>();
54+
55+
for (int i = 0; i < 2; ++i) {
56+
// Create two sub scopes for true and false branches
57+
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
58+
// branch
59+
CreateScope(scope);
60+
61+
// Create two tensors for true and false indices
62+
// index_tensors[0] for the true branch and index_tensors[1] for the false
63+
// branch
64+
CreateIndexTensor(scope);
65+
66+
PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty");
67+
for (auto& input : Inputs("Xs")) {
68+
// Create a new tensor in sub-scope for input-type tensor
69+
Variable* v = sub_scopes[i]->NewVar(input);
70+
LoDTensor* sub_input = v->GetMutable<LoDTensor>();
71+
sub_input->Resize(scope.FindVar(input)->GetMutable<LoDTensor>()->dims());
72+
}
73+
74+
for (auto& output : (*sub_net_op_[i]).Outputs()) {
75+
for (auto& var_name : output.second) {
76+
sub_scopes[i]->NewVar(var_name);
77+
}
78+
}
79+
80+
// each net calls InferShape
81+
sub_net_op_[i]->InferShape(*sub_scopes[i]);
82+
}
83+
84+
for (auto& output : Outputs("Outs")) {
85+
LoDTensor* tensor_t_out =
86+
sub_scopes[0]->FindVar(output)->GetMutable<LoDTensor>();
87+
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
88+
LoDTensor* tensor_f_out =
89+
sub_scopes[1]->FindVar(output)->GetMutable<LoDTensor>();
90+
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL");
91+
92+
auto* tensor_out_var = scope.FindVar(output);
93+
PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found");
94+
LoDTensor* tensor_out = tensor_out_var->GetMutable<LoDTensor>();
95+
PADDLE_ENFORCE_NOT_NULL(tensor_t_out,
96+
"True output tensor should not be NULL");
97+
98+
// check output size should be same
99+
PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(),
100+
"Outputs not of the same shape");
101+
tensor_out->Resize(tensor_t_out->dims());
102+
// tensor_out->mutable_data<float>(tensor_out->dims(),
103+
// platform::CPUPlace());
104+
tensor_out->mutable_data<float>(platform::CPUPlace());
105+
}
106+
}
107+
108+
void CondOp::Run(const Scope& scope,
109+
const platform::DeviceContext& dev_ctx) const {
110+
auto* sub_scopes_var = scope.FindVar("SubScopes");
111+
auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>();
112+
auto* index_tensors_var = scope.FindVar("IndexTensors");
113+
auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>();
114+
115+
std::string cond_name = Input("Cond");
116+
Variable* cond_var = scope.FindVar(cond_name);
117+
PADDLE_ENFORCE_NOT_NULL(cond_var);
118+
const LoDTensor* cond = cond_var->GetMutable<LoDTensor>();
119+
120+
// Step 1: get the true/false index at runtime
121+
// index_[0]: vector<int>, contains all index for cond[i] == true
122+
// index_[1]: vector<int>, contains all index for cond[i] == false
123+
for (int i = 0; i < 2; ++i) index_[i].clear();
124+
125+
const int* cond_data = cond->data<int>();
126+
for (int i = 0; i < cond->dims()[0]; ++i) {
127+
if (cond_data[i])
128+
index_[0].push_back(i);
129+
else
130+
index_[1].push_back(i);
131+
}
132+
133+
// put index_[0] and index_[1] into two tensors:
134+
// index_tensor_[0] and index_tensor_[1]
135+
DDim dim = paddle::framework::make_ddim({0});
136+
for (int i = 0; i < 2; ++i) {
137+
dim[0] = index_[i].size();
138+
int* tmp_ptr =
139+
index_tensors[i].mutable_data<int>(dim, platform::CPUPlace());
140+
index_tensors[i].Resize(dim);
141+
memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int));
142+
}
143+
144+
// Step 2: collect data by calling gather
145+
for (int i = 0; i < 2; ++i) {
146+
// i= 0/i for True and False branches respectively
147+
for (auto& input : Inputs("Xs")) {
148+
// find Tensor
149+
Variable* v = scope.FindVar(input);
150+
PADDLE_ENFORCE_NOT_NULL(v);
151+
LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
152+
153+
v = sub_scopes[i]->FindVar(input);
154+
PADDLE_ENFORCE_NOT_NULL(v);
155+
LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
156+
157+
// Resize child
158+
DDim dim = tensor_child->dims();
159+
dim[0] = index_[i].size();
160+
tensor_child->Resize(dim);
161+
tensor_child->mutable_data<float>(dim, platform::CPUPlace());
162+
163+
Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
164+
tensor_child);
165+
}
166+
}
167+
168+
// Step 3: run
169+
for (int i = 0; i < 2; ++i) {
170+
sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
171+
}
172+
173+
// Step 4: merge output results
174+
for (int i = 0; i < 2; ++i) {
175+
// i= 0/i for True and False branches respectively
176+
for (auto& output : Outputs("Outs")) {
177+
// find Tensor
178+
Variable* v = scope.FindVar(output);
179+
PADDLE_ENFORCE_NOT_NULL(v);
180+
LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
181+
182+
v = sub_scopes[i]->FindVar(output);
183+
PADDLE_ENFORCE_NOT_NULL(v);
184+
LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
185+
186+
ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
187+
tensor_parent);
188+
}
189+
}
190+
}
191+
192+
/*
 * Declares the inputs, outputs and description of the cond operator:
 *   inputs  - "Cond" and the duplicable "Xs";
 *   outputs - the duplicable "Outs", plus "SubScopes" and "IndexTensors",
 *             which CondOp looks up in the scope to keep per-branch state.
 */
class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CondOpProtoAndCheckerMaker(framework::OpProto* proto,
                             framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Cond", "The condition, which is a bool vector");
    AddInput("Xs", "Inputs of Subnets").AsDuplicable();
    AddOutput("Outs", "Outputs of Cond_Op after merge").AsDuplicable();

    AddOutput("SubScopes", "sub scopes for true and false branches");
    AddOutput("IndexTensors", "Index Tensors contains indices for true/false");

    // The false branch takes its result from subnet_f, not subnet_t.
    AddComment(R"DOC(
Sample dependent Cond Operator:
Given Cond[i] as a 1/0 vector to indicate true/false
The equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_f[i], if Cond[i] == false
)DOC");
  }
};
213+
214+
} // namespace operators
215+
} // namespace paddle
216+
217+
// Register CondOp under the operator type name `cond`, together with its
// proto/checker maker; the macro name indicates no gradient op is registered.
REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp,
                             paddle::operators::CondOpProtoAndCheckerMaker);

paddle/operators/cond_op.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <vector>
17+
#include "glog/logging.h"
18+
#include "paddle/framework/ddim.h"
19+
#include "paddle/framework/eigen.h"
20+
#include "paddle/framework/operator.h"
21+
#include "paddle/framework/tensor.h"
22+
#include "paddle/operators/net_op.h"
23+
24+
namespace paddle {
25+
namespace operators {
26+
27+
/*
28+
* @brief CondOp is a dynamic if-else Operator
29+
*
30+
* It has a input tensor named cond indicating which netop each instance will
31+
* run.
32+
*
33+
* if cond == 1, it will run true_net, which is a NetOp.
34+
*
35+
* if cond == 0, it will run false_net, which is another NetOp.
36+
*/
37+
class CondOp : public framework::OperatorBase {
38+
public:
39+
CondOp(const std::string& type, const framework::VariableNameMap& inputs,
40+
const framework::VariableNameMap& outputs,
41+
const framework::AttributeMap& attrs)
42+
: OperatorBase(type, inputs, outputs, attrs) {
43+
index_.resize(2);
44+
sub_net_op_.resize(2);
45+
}
46+
47+
CondOp(const CondOp& o)
48+
: framework::OperatorBase(
49+
static_cast<const framework::OperatorBase&>(o)) {
50+
// TODO(yuyang18): Implement copy ctor well.
51+
PADDLE_THROW("Not implemented");
52+
}
53+
54+
void CreateScope(const framework::Scope& scope) const;
55+
56+
void CreateIndexTensor(const framework::Scope& scope) const;
57+
58+
/*
59+
* InferShape must be called before Run.
60+
*/
61+
void InferShape(const framework::Scope& scope) const override;
62+
63+
/*
64+
* Set True Block
65+
*/
66+
void set_truenet(std::unique_ptr<OperatorBase>&& net) {
67+
sub_net_op_[0] = std::move(net);
68+
}
69+
70+
/*
71+
* Set False Block
72+
*/
73+
void set_falsenet(std::unique_ptr<OperatorBase>&& net) {
74+
sub_net_op_[1] = std::move(net);
75+
}
76+
77+
void Run(const framework::Scope& scope,
78+
const platform::DeviceContext& dev_ctx) const override;
79+
80+
private:
81+
// sub_net_op_[0]: subnet_t
82+
// sub_net_op_[1]: subnet_f
83+
std::vector<std::unique_ptr<framework::OperatorBase>> sub_net_op_;
84+
85+
// index_[0]: True_index;
86+
// index_[1]: False_index;
87+
mutable std::vector<std::vector<int>> index_;
88+
};
89+
90+
} // namespace operators
91+
} // namespace paddle

0 commit comments

Comments
 (0)