Skip to content

Conversation

@DongBaiYue
Copy link
Contributor

PR Category

CINN

PR Types

Improvements

Description

使用target.arch.Match替换DefaultNVGPUTarget分支。

pcard-79890

@paddle-bot
Copy link

paddle-bot bot commented May 24, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label May 24, 2024
};
target.arch.Match(
[&](common::NVGPUArch) {
if (!FLAGS_cinn_enable_map_expr && !FLAGS_cinn_new_group_scheduler) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以尝试提炼这些代码到单独的闭包里,比如叫NvGpuArchCompute。这样缩进更好看些。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

sch->Bind(rb_loops.back(), "threadIdx.x");
sch->SetBuffer(rf_block, "local");
},
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在函数体里显式写注释:// Do nothing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下同

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

bytes,
cudaMemcpyDeviceToDevice,
static_cast<cudaStream_t>(stream));
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些语义没有无脑地维持,虽然实质上差不多。
本函数的重构应该这样弄:

return input_target.arch.Match(
  [&](...) {
    ...
    return true;
  }
)

每个case提供一个返回值。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以的。不过如果原有代码只在部分case下有返回值,会难以处理。

[&](common::X86Arch) {
std::copy(tensor->data<T>(), tensor->data<T>() + size, data.begin());
},
[&](std::variant<common::UnknownArch, common::ARMArch>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

【只是建议】可以不可以不用std::variant做case?总感觉那样的代码层次更少更清爽。如果觉得有些地方会导致代码冗余,我们总是可以借助lambda提炼公共代码。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉可读性差不多,std::variant更强调几种case执行共性操作

@tc20042008 tc20042008 merged commit 5afb0b5 into PaddlePaddle:develop May 29, 2024
@DongBaiYue DongBaiYue deleted the modifyTargetIfElse branch July 24, 2024 07:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants