Cond_op with dynamic if-else checked-in #4088
zchen0211 merged 19 commits into PaddlePaddle:develop
Conversation
Fix CI test
paddle/framework/tensor.h
Outdated
inline T* mutable_data(DDim dims, platform::Place place);

/*! Size of a single element in data() */
inline size_t element_size() const { return holder_->element_size(); }
Actually, I think there is no need to add this function. A function like Variable::IsType() is more useful: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L41
tensor_child->mutable_data<float>(dim, platform::CPUPlace());

Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
              tensor_child);
The recurrent operator also uses the explicit type float in its implementation, and CondOp has the same problem. Maybe RecurrentOp and CondOp should look like:

template <typename T>
class CondOp : public framework::OperatorBase {
};

and then be specialized for each data type.
paddle/operators/cond_op.cc
Outdated
for (auto& output : Outputs("Outs")) {
  Tensor* tensor_t_out = sub_scopes[0]->FindVar(output)->GetMutable<Tensor>();
  PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
paddle/operators/cond_op.cc
Outdated
Tensor* tensor_t_out = sub_scopes[0]->FindVar(output)->GetMutable<Tensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
Tensor* tensor_f_out = sub_scopes[1]->FindVar(output)->GetMutable<Tensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL");
paddle/operators/cond_op.cc
Outdated
void CondOp::Run(const Scope& scope,
                 const platform::DeviceContext& dev_ctx) const {
  auto sub_scopes = scope.FindVar("SubScopes")->Get<std::vector<Scope*>>();
Enforce that scope.FindVar("SubScopes") is not null first, then call Get.
paddle/operators/cond_op.cc
Outdated
const platform::DeviceContext& dev_ctx) const {
  auto sub_scopes = scope.FindVar("SubScopes")->Get<std::vector<Scope*>>();
  auto index_tensors =
      scope.FindVar("IndexTensors")->Get<std::vector<Tensor>>();
paddle/operators/cond_op.cc
Outdated
}

// Step 3: run
for (int i = 0; i < 2; ++i) sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
Add { }: every conditional or loop body should be wrapped in { }.
paddle/operators/cond_op.cc
Outdated
AddComment(R"DOC(
Sample dependent Cond Operator:
The equation is: Out[i] = subnet_t[i], if Cond[i] == true
paddle/operators/cond_op.h
Outdated
*/
void InferShape(const framework::Scope& scope) const override;

// Set True Block
The comment format should be unified. Use:

/*
 * some comment
 */
paddle/operators/cond_op.h
Outdated
}

// Set False Block
void set_falsenet(std::unique_ptr<OperatorBase> net) {
Take the parameter as std::unique_ptr<OperatorBase>&& net (an rvalue reference), so the transfer of ownership is explicit at the call site.
namespace paddle {
namespace operators {
paddle/operators/cond_op.h
Outdated
*
* if cond == 0, it will run false_net, which is another NetOp.
*/
Implemented the dynamic condition (if/else) op. The results match my Python reference implementation.
TODO: will implement the backward pass soon.