[AutoParallel] Add take_along_axis spmd rules#72063
[AutoParallel] Add take_along_axis spmd rules#72063pkuzyc merged 6 commits intoPaddlePaddle:developfrom
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
take_along_axis spmd rule7b76fbc to
14e7b85
Compare
| // Step2: Sharding Propagation | ||
| // Step2.1: Merge input shardings | ||
| std::vector<int64_t> x_dims_mapping(x_dims_mapping_src); | ||
| if (x_dims_mapping[axis] != -1) x_dims_mapping[axis] = -1; |
There was a problem hiding this comment.
不用判断,直接赋值就行。不能切的维度在 einsum notation 里也可以设置成 '1',可以看下根据 notation 推导的逻辑,有处理 '1' 的情况
| if (x_dims_mapping[axis] != -1) x_dims_mapping[axis] = -1; | ||
|
|
||
| std::vector<int64_t> index_dims_mapping(index_dims_mapping_src); | ||
| for (int i = 0; i < index_ndim; ++i) { |
There was a problem hiding this comment.
index 和 x 的 dims_mapping 是否也可以用 einsum notation 推出来。现在这样似乎没有考虑 index 的切分状态对 x 的影响?
There was a problem hiding this comment.
哦我这样写确实只是在根据 x 的 dims_mapping 设置 index 的 dims_mapping,那应该这里的
if (index_dims_mapping[i] != x_dims_mapping[i])
index_dims_mapping[i] = x_dims_mapping[i];需要去掉。
x 和 index 的 einsum notation 的关系应该是:除了 axis 维的其他维都是一一对应的,x 的 axis 维可以设成 ‘1’。这样后面用 ShardingMergeForTensors 来帮助处理一些设置上的冲突情况?
There was a problem hiding this comment.
ShardingMergeForTensors 会处理同一个 notation 标记有两个切分状态的情况,如果有切和不切两种情况,会优先用切
| index_ndim)); | ||
|
|
||
| // Step1: Build Einsum Notation | ||
| // e.g. axis=1, x: azc, index: abc, out: abc |
There was a problem hiding this comment.
这里 out 的 notation 设置是不是有问题?下面这种情况:
x: [0, -1, -1]
index:[0, -1, -1]
axis=0
按现在的设置会推出 out 是 [0,-1,-1],这结果不对
There was a problem hiding this comment.
这里我想错了,x 是 [-1, -1, -1],index 和 out 是 [0, -1, -1] 是对的
| EXTRACT_SHAPE_AND_DIST_ATTR(out_grad); | ||
|
|
||
| // Step1: Build Einsum Notation | ||
| // e.g. axis=1, out_grad: abc -> x: azc, index: abc, x_grad: azc |
16a65c7 to
fb3dcf3
Compare
| std::string x_axes = index_axes; | ||
| x_axes.replace(axis, 1, "1"); | ||
| for (int i = 0; i < index_ndim; ++i) { | ||
| if (i != axis && x_shape[i] != index_shape[i]) { |
There was a problem hiding this comment.
判断 shape 不一样是处理广播的情况?如果只有 shape=1 会广播的话用 shape=1 当条件清楚一些
fb3dcf3 to
9ecf373
Compare
976aaae to
6f14409
Compare
|
Sorry to inform you that 6f14409's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
88e0446 to
45f4445
Compare
45f4445 to
355eea6
Compare
355eea6 to
f099f98
Compare
|
/re-run all-failed |


PR Category
Auto Parallel
PR Types
New features
Description
为
take_along_axis增加 spmd rule