Skip to content

Conversation

@zhangbo9674
Copy link
Contributor

@zhangbo9674 zhangbo9674 commented Dec 5, 2023

PR types

New features

PR changes

Others

Description

为了保证cond 算子翻译的成功率,本 PR对翻译策略进行优化,采用的原则是:

  • 遇到 conditional_block_op,直接翻译成一个对应的 if_op,该 if_op 将只有 true 分支,没有 false 分支

本PR主要内容包括:

  • 实现新的翻译策略;
  • cond_instruction 重命名为 if_instruction
  • if_op 支持仅有 true block(有2个 region,false region 的 block 为空),适配上下游组件,包括:if op verify 检查、lower 对 false 分支的处理、if_instruction 中对 false 分支的处理
  • 支持 select_input 算子,同时适配下游内容:翻译、lower、执行(该算子暂时不提供对外组网 api)

Pcard-67164

@paddle-bot
Copy link

paddle-bot bot commented Dec 5, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

phi::errors::PreconditionNotMet("The size %d of true_region must be 1.",
(*this)->region(0).size()));
auto &true_last_op = (*this)->region(0).front().back();
PADDLE_ENFORCE_EQ(true,
Copy link
Contributor

@winter-wang winter-wang Dec 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这儿应该考虑一下if没有返回值的情况。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, tks~

winter-wang
winter-wang previously approved these changes Dec 8, 2023
Copy link
Contributor

@winter-wang winter-wang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, commet 单独提PR fix

private:
void copy_tensor(const phi::DenseTensor &lod_tensor,
phi::DenseTensor *out) const {
if (!lod_tensor.IsInitialized()) return;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么直接return了,而不是Throw Error?


static std::vector<int64_t> ParseCompatibleShapes(
const std::vector<int64_t>& dim1, const std::vector<int64_t>& dim2) {
IR_ENFORCE(dim1.size() == dim2.size(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是不是可以用PADDLE_ENFORCE?另外想了解下,IR_ENFORCE 抛出的异常显示的python callstack 与框架目前的栈有什么差异,会被pybind层Exception正确映射么?如果能的话,那这里就无所谓了

auto op_info = this->LoopkUpOpInfo(ctx, op_desc);

std::vector<pir::Value> op_inputs = {};
auto Mask_name = op_desc.Input("Mask")[0];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto Mask_name = op_desc.Input("Mask")[0];
auto mask_name = op_desc.Input("Mask")[0];

这里命名首字母不需要大写?

op_desc.Type(),
Mask_name);
op_inputs.push_back(param_map->at(Mask_name).value);
for (auto in_name : Input_name) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto in_name : Input_name) {
for (auto& in_name : Input_name) {


OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types;
auto Out_name = op_desc.Output("Out")[0];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto Out_name = op_desc.Output("Out")[0];
auto& Out_name = op_desc.Output("Out")[0];

array_op.out().set_type(type);
return array_op.operation();
}
return nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里直接Throw NotImplement 会不会更好一些?这样下游不用检查这个函数的返回是否有效

for (auto input_name : input_names) {
auto cond_op_cond = op->Input("Cond")[0];
auto& cond_op_inputs = op->Input("Input");
for (auto input_name : cond_op_inputs) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto input_name : cond_op_inputs) {
for (auto& input_name : cond_op_inputs) {

tensor1.dtype(),
tensor2.dtype());
IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(),
"The 1st input data_layout %s should be equal to 2ed input "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"The 1st input data_layout %s should be equal to 2ed input "
"The 1st input data_layout %s should be equal to 2nd input "

Copy link
Contributor

@winter-wang winter-wang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhangbo9674 zhangbo9674 merged commit e5fbff4 into PaddlePaddle:develop Dec 11, 2023
Comment on lines 130 to +133
false_branch_inter_ =
new PirInterpreter(place,
{},
&false_branch_block,
&if_op.false_block(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议改用make_unique

Comment on lines +62 to +64
PirInterpreter* true_branch_inter_ = nullptr;

PirInterpreter* false_branch_inter_;
PirInterpreter* false_branch_inter_ = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

换用unique_ptr表明所有权更好吧

pir::Block* block) override {
VLOG(10) << "[op select_input] start transcribing";
auto op_info = this->LoopkUpOpInfo(ctx, op_desc);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议在这里添加一句

Suggested change
this->InsertSliceOperationForInput(ctx, param_map, op_desc, input_infos, block);

保证输入的类型正确

Comment on lines +411 to +412
pir::Value mask() { return operand_source(0); }
pir::OpResult out() { return result(0); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

少了input的相关接口

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants