Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions doc/design/if_else_op.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,4 @@
IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponds to a true element in `cond`.

```python
import paddle as pd

x = var()
y = var()
cond = var()

b = pd.create_ifop(inputs=[x], output_num=1)
with b.true_block():
x = b.inputs(0)
z = operator.add(x, y)
b.set_output(0, operator.softmax(z))

out = b(cond)
```

If we want the output still has N instances, we can use IfElseOp with a default value, whose minibatch size must be N:
IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has N instances. If cond[i] == True, input instance input[i] will go through true_block() and generate output[i]; otherwise it will produce output from false_bloack().

```python
import paddle as pd
Expand All @@ -39,7 +21,7 @@ with b.false_block():
out = b(cond)
```

If only true_block is set in an IfElseOp, we can have a default value for false as:
If only true_block is set in an IfElseOp, a special case is that we can have a default value for false as:
```python
import paddle as pd

Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace framework {
template <typename T>
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(
holder_->size(), numel() * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
Expand Down
6 changes: 3 additions & 3 deletions paddle/framework/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ TEST(Tensor, DataAssert) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"holder_ should not be null\nTenosr holds no memory. Call "
"holder_ should not be null\nTensor holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
Expand Down Expand Up @@ -112,7 +112,7 @@ TEST(Tensor, ShareDataWith) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"holder_ should not be null\nTenosr holds no memory. Call "
"holder_ should not be null\nTensor holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
Expand Down Expand Up @@ -274,4 +274,4 @@ TEST(Tensor, ReshapeToMatrix) {
Tensor res = ReshapeToMatrix<int>(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
}
4 changes: 3 additions & 1 deletion paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ endfunction()
add_subdirectory(math)

set(DEPS_OPS
recurrent_op)
recurrent_op
cond_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)

list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
Expand Down
218 changes: 218 additions & 0 deletions paddle/operators/cond_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/cond_op.h"

#include <cstring>
#include <sstream>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/scatter.h"

namespace paddle {
namespace operators {

using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;

void CondOp::CreateScope(const Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE(sub_scopes_var != nullptr, "");
auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>();
auto& sub_scope = scope.NewScope();
sub_scopes->push_back(&sub_scope);
}

void CondOp::CreateIndexTensor(const Scope& scope) const {
auto index_tensors_var = scope.FindVar("IndexTensors");
PADDLE_ENFORCE(index_tensors_var != nullptr, "");
auto& index_tensors =
*index_tensors_var->GetMutable<std::vector<LoDTensor>>();
index_tensors.push_back(LoDTensor());
}

void CondOp::InferShape(const Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var);
auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>();

for (int i = 0; i < 2; ++i) {
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
// branch
CreateScope(scope);

// Create two tensors for true and false indices
// index_tensors[0] for the true branch and index_tensors[1] for the false
// branch
CreateIndexTensor(scope);

PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty");
for (auto& input : Inputs("Xs")) {
// Create a new tensor in sub-scope for input-type tensor
Variable* v = sub_scopes[i]->NewVar(input);
LoDTensor* sub_input = v->GetMutable<LoDTensor>();
sub_input->Resize(scope.FindVar(input)->GetMutable<LoDTensor>()->dims());
}

for (auto& output : (*sub_net_op_[i]).Outputs()) {
for (auto& var_name : output.second) {
sub_scopes[i]->NewVar(var_name);
}
}

// each net calls InferShape
sub_net_op_[i]->InferShape(*sub_scopes[i]);
}

for (auto& output : Outputs("Outs")) {
LoDTensor* tensor_t_out =
sub_scopes[0]->FindVar(output)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
LoDTensor* tensor_f_out =
sub_scopes[1]->FindVar(output)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL");

auto* tensor_out_var = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found");
LoDTensor* tensor_out = tensor_out_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out,
"True output tensor should not be NULL");

// check output size should be same
PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(),
"Outputs not of the same shape");
tensor_out->Resize(tensor_t_out->dims());
// tensor_out->mutable_data<float>(tensor_out->dims(),
// platform::CPUPlace());
tensor_out->mutable_data<float>(platform::CPUPlace());
}
}

void CondOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto* sub_scopes_var = scope.FindVar("SubScopes");
auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>();
auto* index_tensors_var = scope.FindVar("IndexTensors");
auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>();

std::string cond_name = Input("Cond");
Variable* cond_var = scope.FindVar(cond_name);
PADDLE_ENFORCE_NOT_NULL(cond_var);
const LoDTensor* cond = cond_var->GetMutable<LoDTensor>();

// Step 1: get the true/false index at runtime
// index_[0]: vector<int>, contains all index for cond[i] == true
// index_[1]: vector<int>, contains all index for cond[i] == false
for (int i = 0; i < 2; ++i) index_[i].clear();

const int* cond_data = cond->data<int>();
for (int i = 0; i < cond->dims()[0]; ++i) {
if (cond_data[i])
index_[0].push_back(i);
else
index_[1].push_back(i);
}

// put index_[0] and index_[1] into two tensors:
// index_tensor_[0] and index_tensor_[1]
DDim dim = paddle::framework::make_ddim({0});
for (int i = 0; i < 2; ++i) {
dim[0] = index_[i].size();
int* tmp_ptr =
index_tensors[i].mutable_data<int>(dim, platform::CPUPlace());
index_tensors[i].Resize(dim);
memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int));
}

// Step 2: collect data by calling gather
for (int i = 0; i < 2; ++i) {
// i= 0/i for True and False branches respectively
for (auto& input : Inputs("Xs")) {
// find Tensor
Variable* v = scope.FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();

v = sub_scopes[i]->FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_child = v->GetMutable<LoDTensor>();

// Resize child
DDim dim = tensor_child->dims();
dim[0] = index_[i].size();
tensor_child->Resize(dim);
tensor_child->mutable_data<float>(dim, platform::CPUPlace());

Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
tensor_child);
Copy link
Contributor

@qingqing01 qingqing01 Sep 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The recurrent operator also uses explicit type float in the implementation. CondOp has the same problem. Maybe the RecurrentOp and CondOp should like:

template <T>
class CondOp : public framework::OperatorBase {
}

Then specialized the class.

}
}

// Step 3: run
for (int i = 0; i < 2; ++i) {
sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
}

// Step 4: merge output results
for (int i = 0; i < 2; ++i) {
// i= 0/i for True and False branches respectively
for (auto& output : Outputs("Outs")) {
// find Tensor
Variable* v = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();

v = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_child = v->GetMutable<LoDTensor>();

ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
tensor_parent);
}
}
}

class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
public:
CondOpProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Cond", "The condition, which is a bool vector");
AddInput("Xs", "Inputs of Subnets").AsDuplicable();
AddOutput("Outs", "Outputs of Cond_Op after merge").AsDuplicable();

AddOutput("SubScopes", "sub scopes for true and false branches");
AddOutput("IndexTensors", "Index Tensors contains indices for true/false");

AddComment(R"DOC(
Sample dependent Cond Operator:
Given Cond[i] as a 1/0 vector to indicate true/false
The equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_t[i], if Cond[i] == false
)DOC");
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp,
paddle::operators::CondOpProtoAndCheckerMaker);
91 changes: 91 additions & 0 deletions paddle/operators/cond_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>
#include "glog/logging.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/net_op.h"

namespace paddle {
namespace operators {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some doc here.

/*
* @brief CondOp is a dynamic if-else Operator
*
* It has a input tensor named cond indicating which netop each instance will
* run.
*
* if cond == 1, it will run true_net, which is a NetOp.
*
* if cond == 0, it will run false_net, which is another NetOp.
*/
class CondOp : public framework::OperatorBase {
public:
CondOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
index_.resize(2);
sub_net_op_.resize(2);
}

CondOp(const CondOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}

void CreateScope(const framework::Scope& scope) const;

void CreateIndexTensor(const framework::Scope& scope) const;

/*
* InferShape must be called before Run.
*/
void InferShape(const framework::Scope& scope) const override;

/*
* Set True Block
*/
void set_truenet(std::unique_ptr<OperatorBase>&& net) {
sub_net_op_[0] = std::move(net);
}

/*
* Set False Block
*/
void set_falsenet(std::unique_ptr<OperatorBase>&& net) {
sub_net_op_[1] = std::move(net);
}

void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;

private:
// sub_net_op_[0]: subnet_t
// sub_net_op_[1]: subnet_f
std::vector<std::unique_ptr<framework::OperatorBase>> sub_net_op_;

// index_[0]: True_index;
// index_[1]: False_index;
mutable std::vector<std::vector<int>> index_;
};

} // namespace operators
} // namespace paddle
Loading