Skip to content

Comments

Cond_op with dynamic if-else checked-in#4088

Merged
zchen0211 merged 19 commits intoPaddlePaddle:developfrom
zchen0211:develop
Sep 14, 2017
Merged

Cond_op with dynamic if-else checked-in#4088
zchen0211 merged 19 commits intoPaddlePaddle:developfrom
zchen0211:develop

Conversation

@zchen0211
Copy link
Contributor

Implemented the dynamic condition (if/else) op. The result matches my python results.
TODO: Will implement the backward part soon.

inline T* mutable_data(DDim dims, platform::Place place);

/*! Size of a single element in data() */
inline size_t element_size() const { return holder_->element_size(); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think there is no need to add this function. A function like Variable::IsType() is more useful: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L41

tensor_child->mutable_data<float>(dim, platform::CPUPlace());

Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
tensor_child);
Copy link
Contributor

@qingqing01 qingqing01 Sep 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The recurrent operator also uses explicit type float in the implementation. CondOp has the same problem. Maybe the RecurrentOp and CondOp should like:

template <T>
class CondOp : public framework::OperatorBase {
}

Then specialized the class.


for (auto& output : Outputs("Outs")) {
Tensor* tensor_t_out = sub_scopes[0]->FindVar(output)->GetMutable<Tensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should be NULL");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not be NULL

Tensor* tensor_t_out = sub_scopes[0]->FindVar(output)->GetMutable<Tensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should be NULL");
Tensor* tensor_f_out = sub_scopes[1]->FindVar(output)->GetMutable<Tensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "True output should be NULL");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as top


void CondOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto sub_scopes = scope.FindVar("SubScopes")->Get<std::vector<Scope*>>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enforce scope.FindVar("SubScopes") is not null first
then Get.

const platform::DeviceContext& dev_ctx) const {
auto sub_scopes = scope.FindVar("SubScopes")->Get<std::vector<Scope*>>();
auto index_tensors =
scope.FindVar("IndexTensors")->Get<std::vector<Tensor>>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

}

// Step 3: run
for (int i = 0; i < 2; ++i) sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add { }
every cond operator should be wrapped with {}


AddComment(R"DOC(
Sample dependent Cond Operator:
The equation is: Out[i] = subnet_t[i], if Cond[i] == true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need some indent here

*/
void InferShape(const framework::Scope& scope) const override;

// Set True Block
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment format should be unified.

use

/*
 * some comment
 */

}

// Set False Block
void set_falsenet(std::unique_ptr<OperatorBase> net) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::unique_ptr &&net)

global reference


namespace paddle {
namespace operators {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some doc here.

*
* if cond == 0, it will run false_net, which is another NetOp.
*/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no empty line here

@zchen0211 zchen0211 merged commit 4c7a9a4 into PaddlePaddle:develop Sep 14, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants