repair npu matmulv2_grad and comm_init_hccl #33719
Thanks for your contribution!
NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
            {{"adj_x1", true}, {"adj_x2", false}});
runner_dy.Run(stream);
if ((x->dims().size() == 3) && (dout->dims().size() == 3) &&
When x has 3 dims and y has 2 dims, is BatchMatMul also unusable in the forward pass?
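The forward pass works. The dimension check is here because the output is 2-D while both inputs are 3-D, so they need to be converted first.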
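To illustrate the conversion mentioned above, here is a minimal CPU sketch (not the NPU kernel from this PR). It assumes x has shape [B, M, K], y has shape [K, N], and dout has shape [B, M, N]. A batched matmul with `adj_x1 = true` would produce a [B, K, N] result, while dy must be [K, N]. Flattening x to [B*M, K] and dout to [B*M, N] and doing a single transposed matmul gives the same values as summing the batched result over B. The function names `GradYBatched` and `GradYFlattened` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major buffers. x: [B, M, K], dout: [B, M, N], result dy: [K, N].

// Reference: compute x[b]^T * dout[b] per batch (the [B, K, N] result a
// batched matmul with a transposed first operand would give), then reduce
// over the batch axis to get the [K, N] gradient.
std::vector<float> GradYBatched(const std::vector<float>& x,
                                const std::vector<float>& dout,
                                std::size_t B, std::size_t M,
                                std::size_t K, std::size_t N) {
  std::vector<float> dy(K * N, 0.0f);
  for (std::size_t b = 0; b < B; ++b) {
    for (std::size_t k = 0; k < K; ++k) {
      for (std::size_t n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (std::size_t m = 0; m < M; ++m) {
          acc += x[(b * M + m) * K + k] * dout[(b * M + m) * N + n];
        }
        dy[k * N + n] += acc;
      }
    }
  }
  return dy;
}

// Flattened: view x as [B*M, K] and dout as [B*M, N] (no data movement for
// row-major storage), then do one 2-D matmul with the first operand transposed.
std::vector<float> GradYFlattened(const std::vector<float>& x,
                                  const std::vector<float>& dout,
                                  std::size_t B, std::size_t M,
                                  std::size_t K, std::size_t N) {
  const std::size_t rows = B * M;
  std::vector<float> dy(K * N, 0.0f);
  for (std::size_t k = 0; k < K; ++k) {
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (std::size_t r = 0; r < rows; ++r) {
        acc += x[r * K + k] * dout[r * N + n];
      }
      dy[k * N + n] = acc;
    }
  }
  return dy;
}
```

Both functions return the same [K, N] buffer; the flattened form needs only one 2-D matmul, which is why reshaping the 3-D inputs first is a natural fit when the output is 2-D.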
pangyoki
left a comment
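With fp16, the matmul v2 op may receive a 3-D input multiplied by a 2-D input. The BatchMatMul NPU op does not support this case for fp32.
At present the fp32 data type is not used with 3-D times 2-D inputs, so fp32 support has not been added for now.
Handling of this case for fp32 needs to be added later.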
wanghuancoder
left a comment
LGTM for unittest.skipIf
// Build comm
float* buff;
int32_t size = 20;
Why is this 20?
It is only used for initialization.
for (int32_t idx = 0; idx < size; idx++) {
  input[idx] = 1.0;
}
aclrtMalloc(reinterpret_cast<void**>(&buff), size * sizeof(float),
Calls like this need to be guaranteed to succeed, right? It should be wrapped in ACLCHECK.
Noted. This will be addressed in the next PR.
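As a rough illustration of the wrapping suggested above, the sketch below shows one common shape for a status-checking macro. It is not the project's ACLCHECK macro and does not call the real ACL runtime: `FakeAclError`, `FakeDeviceMalloc`, `ACL_CHECK_SKETCH`, and `AllocInitBuffer` are all hypothetical stand-ins, with host `malloc` used in place of a device allocation so the example is self-contained.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a runtime error code; 0 means success.
using FakeAclError = int;
constexpr FakeAclError kFakeAclSuccess = 0;

// Hypothetical stand-in for a device allocation call. It uses host malloc
// and reports an error for a null output pointer or a zero-byte request.
inline FakeAclError FakeDeviceMalloc(void** ptr, std::size_t size) {
  if (ptr == nullptr || size == 0) return 1;
  *ptr = std::malloc(size);
  return (*ptr != nullptr) ? kFakeAclSuccess : 2;
}

// Evaluate the call once and throw if it did not report success, so a
// failed allocation cannot be silently ignored by the caller.
#define ACL_CHECK_SKETCH(expr)                                        \
  do {                                                                \
    const FakeAclError err_ = (expr);                                 \
    if (err_ != kFakeAclSuccess) {                                    \
      throw std::runtime_error(std::string(#expr) +                   \
                               " failed with code " +                 \
                               std::to_string(err_));                 \
    }                                                                 \
  } while (0)

// Allocate a buffer of `count` floats, checking the allocation result.
inline float* AllocInitBuffer(std::size_t count) {
  float* buff = nullptr;
  ACL_CHECK_SKETCH(FakeDeviceMalloc(reinterpret_cast<void**>(&buff),
                                    count * sizeof(float)));
  return buff;
}
```

With this shape, `AllocInitBuffer(20)` returns a usable pointer, while a failing allocation raises an exception that names the failing expression instead of leaving `buff` uninitialized.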
PR types
Bug fixes
PR changes
OPs
Describe
1. Fix the NPU matmulv2_grad op to support the case where two 3-D inputs produce a 2-D gradient (3*3->2), and add a unit test.
2. Fix the NPU comm_init_hccl op by sending fake data to establish the connection.
Precision of matmulv2_grad on NPU and GPU in fp16 over 5 epochs:
npu:

gpu:
