
[NPU] add reduce_max#34179

Merged
qili93 merged 5 commits into PaddlePaddle:develop from windstamp:npu_dev_20210714
Aug 2, 2021

Conversation

Contributor

@windstamp windstamp commented Jul 15, 2021

PR types

New features

PR changes

OPs

Describe

add reduce_max


@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

auto* out = ctx.Output<Tensor>("Out");
auto dims = ctx.Attr<std::vector<int>>("dim");
bool keep_dim = ctx.Attr<bool>("keep_dim");
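The attributes read above (`dim` and `keep_dim`) drive the reduction. A minimal NumPy sketch of the intended reduce_max semantics (a hypothetical helper for illustration, not the actual NPU kernel):

```python
import numpy as np

def reduce_max(x, dim, keep_dim=False):
    # Reduce over the axes listed in `dim`; negative axes wrap to positive.
    axes = tuple(d % x.ndim for d in dim)
    return x.max(axis=axes, keepdims=keep_dim)
```

With `keep_dim=True` the reduced axes are retained with size 1, matching the `keep_dim` attribute's meaning.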

Contributor

A PADDLE_ENFORCE error message needs to be added for the cases where reduce_all or out_dtype is present. Alternatively, add NPU operator support for them.

Contributor Author

done, thanks.

Added support for the attr reduce_all, and added PADDLE_ENFORCE error messages requiring the attrs in_dtype and out_dtype to be their default values.
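When reduce_all is set, the kernel is expected to collapse every axis regardless of `dim`. A hedged NumPy sketch of that dispatch (function name is illustrative, not Paddle API):

```python
import numpy as np

def reduce_max_with_all(x, dim, keep_dim=False, reduce_all=False):
    # reduce_all overrides `dim` and reduces over every axis.
    axes = None if reduce_all else tuple(d % x.ndim for d in dim)
    return x.max(axis=axes, keepdims=keep_dim)
```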

Contributor

See the note on in_dtype in reduce_op.h: it cannot be set by users, so it does not need to be checked. For out_dtype, you could try adding a cast op to perform the data type conversion.
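The suggested cast approach (reduce first, then convert to out_dtype) could look like the following NumPy sketch; the `astype` call stands in for Paddle's cast op, and the helper name is hypothetical:

```python
import numpy as np

def reduce_max_then_cast(x, dim, out_dtype=None):
    out = x.max(axis=tuple(dim))
    # Emulate appending a cast op when out_dtype is requested.
    return out.astype(out_dtype) if out_dtype is not None else out
```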

Contributor Author

done, thanks.

self.outputs = {
    'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
}
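The expected output above is computed with NumPy's reduction over a tuple of axes. A standalone example of the same expression (`x` and `attrs` here are illustrative, mirroring the unit test's `self.inputs` and `self.attrs`):

```python
import numpy as np

attrs = {'dim': [0]}  # illustrative attrs, as in the unit test
x = np.array([[1.0, 4.0], [3.0, 2.0]])
out = x.max(axis=tuple(attrs['dim']))  # column-wise max
```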

Contributor

You can refer to test_reduce_max_op_xpu.py to add unit tests for reduce_all and out_dtype, provided the C++ side supports the reduce_all input.

Contributor Author

done, thanks.

Added a unit test for the attr reduce_all.

Contributor

@qili93 qili93 left a comment

LGTM

Contributor

@zhiqiu zhiqiu left a comment

LGTM for shareDataWith

Contributor

@zhangting2020 zhangting2020 left a comment

LGTM for skip_check_grad_ci

@qili93 qili93 merged commit de53f2b into PaddlePaddle:develop Aug 2, 2021
@windstamp windstamp deleted the npu_dev_20210714 branch August 2, 2021 08:12
