Skip to content

[CherryPick] fix compare ops when broadcast #33086

Merged
XiaoguangHu01 merged 2 commits intoPaddlePaddle:release/2.1from
wawltor:cp_fix_compare_ops_broadcast
Jun 4, 2021
Merged

[CherryPick] fix compare ops when broadcast #33086
XiaoguangHu01 merged 2 commits intoPaddlePaddle:release/2.1from
wawltor:cp_fix_compare_ops_broadcast

Conversation

@wawltor
Copy link
Contributor

@wawltor wawltor commented May 24, 2021

PR types

Bug fixes

PR changes

OPs

Describe

Cp fix compare ops broadcast,cherry-pick from the PR #32941

def sequence_mask(x_len, max_len=None, dtype='float32'):    
       max_len = max_len or x_len.max()   
       x_len = paddle.unsqueeze(x_len, -1)   
       row_vector = paddle.arange(max_len)   
       mask = row_vector < x_len    
       mask = paddle.cast(mask, dtype)    
       return mask

现象:上面是代码复现的BUG代码,当时比较类API进行broadcast过程中,如果第一个输入的shape小于第二个shape的维度,
在一定条件下会触发一个bug,例如 进行 < 比较的时候,出现了相同的值则会比较错误。

原因:注册Kernel的时候函数注册错误

解决:注册正确的比较类函数

@paddle-bot-old
Copy link

paddle-bot-old bot commented May 24, 2021

✅ This PR's description meets the template requirements!
Please wait for other CI results.

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Jun 4, 2021
@PaddlePaddle PaddlePaddle unlocked this conversation Jun 4, 2021
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@XiaoguangHu01 XiaoguangHu01 merged commit c42ccf1 into PaddlePaddle:release/2.1 Jun 4, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants