
[AutoParallel] Add take_along_axis spmd rules #72063

Merged
pkuzyc merged 6 commits into PaddlePaddle:develop from NKNaN:take_along_axis-spmd
Jun 30, 2025

Conversation

@NKNaN (Contributor) commented Apr 3, 2025

PR Category

Auto Parallel

PR Types

New features

Description

Add spmd rule for take_along_axis

@paddle-bot bot commented Apr 3, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 3, 2025
@NKNaN NKNaN changed the title Add take_along_axis spmd rule [AutoParallel] Add take_along_axis spmd rules Apr 7, 2025
@NKNaN NKNaN force-pushed the take_along_axis-spmd branch from 7b76fbc to 14e7b85 Compare April 7, 2025 12:50
// Step2: Sharding Propagation
// Step2.1: Merge input shardings
std::vector<int64_t> x_dims_mapping(x_dims_mapping_src);
if (x_dims_mapping[axis] != -1) x_dims_mapping[axis] = -1;
Contributor:

No need for the check; just assign directly. A dimension that cannot be sharded can also be marked '1' in the einsum notation. Take a look at the notation-based inference logic; it already handles the '1' case.

Contributor Author:

OK.

if (x_dims_mapping[axis] != -1) x_dims_mapping[axis] = -1;

std::vector<int64_t> index_dims_mapping(index_dims_mapping_src);
for (int i = 0; i < index_ndim; ++i) {
Contributor:

Can the dims_mapping of index and x also be inferred from the einsum notation? As written, this does not seem to account for how index's sharding state affects x.

Contributor Author:

Right, as written this only sets index's dims_mapping from x's dims_mapping, so this part:

if (index_dims_mapping[i] != x_dims_mapping[i])
    index_dims_mapping[i] = x_dims_mapping[i];

should be removed.
The einsum-notation relation between x and index should be: every dimension except the axis dimension corresponds one-to-one, and x's axis dimension can be set to '1'. Then ShardingMergeForTensors can be used afterward to resolve conflicting settings?

Contributor:

ShardingMergeForTensors handles the case where the same notation label carries two sharding states; if one state is sharded and the other is not, the sharded one takes priority.

Contributor Author:

OK, I'll revise it.

index_ndim));

// Step1: Build Einsum Notation
// e.g. axis=1, x: azc, index: abc, out: abc
Contributor:

Is the out notation set incorrectly here? Consider the following case:
x: [0, -1, -1]
index: [0, -1, -1]
axis=0
With the current setting this infers out as [0, -1, -1], which is wrong.

Contributor Author:

out's sharding state should match index, because in the element-wise formula out's subscripts correspond one-to-one with index's:
out[i][j][k] = x[index[i][j][k]][j][k]  # if axis == 0

x: [0, -1, -1]
index: [0, -1, -1]
axis=0
In this case the rule infers x as [-1, -1, -1], and index and out as [0, -1, -1], equivalent to the example in the attached figure (omitted here).

Contributor:

I was mistaken here; x as [-1, -1, -1] with index and out as [0, -1, -1] is correct.

EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);

// Step1: Build Einsum Notation
// e.g. axis=1, out_grad: abc -> x: azc, index: abc, x_grad: azc
Contributor:

Update this accordingly to match the forward rule.

@NKNaN NKNaN force-pushed the take_along_axis-spmd branch from 16a65c7 to fb3dcf3 Compare April 16, 2025 12:57
std::string x_axes = index_axes;
x_axes.replace(axis, 1, "1");
for (int i = 0; i < index_ndim; ++i) {
if (i != axis && x_shape[i] != index_shape[i]) {
Contributor:

Is the shape-mismatch check here handling broadcasting? If broadcasting only happens when shape == 1, using shape == 1 as the condition would be clearer.

Contributor Author:

This handles more than broadcasting: any non-axis dimension where the shapes of x and index differ cannot be sharded at all, as in the example in the attached figure (omitted here).

@NKNaN NKNaN force-pushed the take_along_axis-spmd branch from fb3dcf3 to 9ecf373 Compare April 29, 2025 13:44
@NKNaN NKNaN force-pushed the take_along_axis-spmd branch from 976aaae to 6f14409 Compare May 7, 2025 02:53
@paddle-ci-bot bot commented May 18, 2025

Sorry to inform you that 6f14409's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

@NKNaN NKNaN force-pushed the take_along_axis-spmd branch 2 times, most recently from 88e0446 to 45f4445 Compare June 12, 2025 01:40
@NKNaN NKNaN force-pushed the take_along_axis-spmd branch from 45f4445 to 355eea6 Compare June 24, 2025 01:16
@NKNaN NKNaN force-pushed the take_along_axis-spmd branch from 355eea6 to f099f98 Compare June 27, 2025 07:09
@NKNaN (Contributor Author) commented Jun 27, 2025

/re-run all-failed

@pkuzyc (Contributor) left a comment:

LGTM

@pkuzyc pkuzyc merged commit ecd685a into PaddlePaddle:develop Jun 30, 2025
92 of 101 checks passed
Labels

contributor External developers

3 participants