Skip to content

Commit ecd685a

Browse files
authored
[AutoParallel] Add take_along_axis spmd rules (#72063)
- add take_along_axis spmd rule
- update cmakelists
- update rule
- update rules and test
- fix test
1 parent 5090b58 commit ecd685a

8 files changed

Lines changed: 612 additions & 0 deletions

File tree

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,4 +805,9 @@ PD_REGISTER_SPMD_RULE(
805805
fused_gemm_epilogue,
806806
PD_INFER_SPMD(phi::distributed::FusedGemmEpilogueInferSpmdBase));
807807

808+
// take_along_axis
809+
PD_REGISTER_SPMD_RULE(
810+
take_along_axis,
811+
PD_INFER_SPMD(phi::distributed::TakeAlongAxisInferSpmd),
812+
PD_INFER_SPMD(phi::distributed::TakeAlongAxisGradInferSpmd));
808813
} // namespace phi::distributed

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ limitations under the License. */
8080
#include "paddle/phi/infermeta/spmd_rules/squared_l2_norm.h"
8181
#include "paddle/phi/infermeta/spmd_rules/squeeze.h"
8282
#include "paddle/phi/infermeta/spmd_rules/stack.h"
83+
#include "paddle/phi/infermeta/spmd_rules/take_along_axis.h"
8384
#include "paddle/phi/infermeta/spmd_rules/tile.h"
8485
#include "paddle/phi/infermeta/spmd_rules/topk.h"
8586
#include "paddle/phi/infermeta/spmd_rules/transpose.h"
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/infermeta/spmd_rules/take_along_axis.h"
16+
17+
#include "glog/logging.h"
18+
19+
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
20+
#include "paddle/phi/infermeta/spmd_rules/utils.h"
21+
22+
namespace phi::distributed {
23+
SpmdInfo TakeAlongAxisInferSpmd(const DistMetaTensor& x,
                                const DistMetaTensor& index,
                                int axis) {
  /*
    take_along_axis computation formula:

    out[i][j][k] = x[index[i][j][k]][j][k]  # if axis == 0
    out[i][j][k] = x[i][index[i][j][k]][k]  # if axis == 1
    out[i][j][k] = x[i][j][index[i][j][k]]  # if axis == 2
  */

  // Deduced spmd rule:
  // x: cannot be sharded on `axis` dim;
  // index: the `axis` dim could be either sharded or not, other dimension
  // should be the same as x;
  // out: same as index;
  // For non-`axis` dim, if the sizes of this dim in x and index are not
  // the same, this dim should not be sharded.

  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  EXTRACT_SHAPE_AND_DIST_ATTR(index);
  PADDLE_ENFORCE_EQ(x_ndim,
                    index_ndim,
                    common::errors::InvalidArgument(
                        "x and index must have the same number of dimensions "
                        "but received x_ndim [%d], index_ndim [%d]",
                        x_ndim,
                        index_ndim));
  // `axis` may be negative (counting from the last dim). Normalize it first:
  // std::string::replace takes an unsigned position, so a negative `axis`
  // would throw std::out_of_range, and `i != axis` would never mask the
  // axis dim.
  if (axis < 0) {
    axis += x_ndim;
  }
  PADDLE_ENFORCE_EQ(axis >= 0 && axis < x_ndim,
                    true,
                    common::errors::InvalidArgument(
                        "axis must be in range [-%d, %d), but received [%d]",
                        x_ndim,
                        x_ndim,
                        axis));

  // Step1: Build Einsum Notation
  // e.g. axis=1, x: a1c, index: abc, out: abc
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string index_axes = GetBroadcastAxes(index_ndim, index_ndim, alphabet);
  std::string x_axes = index_axes;
  // "1" marks a dim that must stay replicated.
  x_axes.replace(axis, 1, "1");
  for (int i = 0; i < index_ndim; ++i) {
    // Non-`axis` dims with mismatched sizes cannot share a sharding.
    if (i != axis && x_shape[i] != index_shape[i]) {
      x_axes.replace(i, 1, "1");
      index_axes.replace(i, 1, "1");
    }
  }
  std::string out_axes = index_axes;

  // Step2: Sharding Propagation
  // Step2.1: Merge input shardings
  std::vector<int64_t> x_dims_mapping(x_dims_mapping_src);
  // x must not be sharded along `axis` (each output element may gather from
  // anywhere along that dim).
  x_dims_mapping[axis] = -1;
  std::vector<int64_t> index_dims_mapping(index_dims_mapping_src);
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors(
          {{x_axes, x_dims_mapping}, {index_axes, index_dims_mapping}});

  // Step2.2: Infer output dims mapping
  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(x_axes, axis_to_dim_map));

  TensorDistAttr index_dist_attr_dst =
      CopyTensorDistAttrForOutput(index_dist_attr_src);
  index_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(index_axes, axis_to_dim_map));

  // out has the same shape (and hence sharding notation) as index.
  TensorDistAttr out_dist_attr =
      CopyTensorDistAttrForOutput(index_dist_attr_src);
  out_dist_attr.set_dims_mapping(
      GetDimsMappingForAxes(out_axes, axis_to_dim_map));

  VLOG(4) << "x_axes: " << x_axes << " index_axes: " << index_axes
          << " out_axes: " << out_axes;
  LOG_SPMD_INPUT(x);
  LOG_SPMD_INPUT(index);
  VLOG(4) << "out";
  VLOG(4) << "dist_attr: [" << out_dist_attr.to_string() << "]";
  return {{x_dist_attr_dst, index_dist_attr_dst}, {out_dist_attr}};
}
98+
99+
SpmdInfo TakeAlongAxisGradInferSpmd(const DistMetaTensor& x,
                                    const DistMetaTensor& index,
                                    const DistMetaTensor& out_grad,
                                    int axis) {
  // Reverse rule of TakeAlongAxisInferSpmd: shardings are propagated from
  // out_grad back to x / index, and x_grad follows x's deduced sharding.
  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  EXTRACT_SHAPE_AND_DIST_ATTR(index);
  EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);
  // Keep the same precondition as the forward rule.
  PADDLE_ENFORCE_EQ(x_ndim,
                    index_ndim,
                    common::errors::InvalidArgument(
                        "x and index must have the same number of dimensions "
                        "but received x_ndim [%d], index_ndim [%d]",
                        x_ndim,
                        index_ndim));
  // Normalize a negative `axis`; std::string::replace takes an unsigned
  // position, so a negative value would throw std::out_of_range.
  if (axis < 0) {
    axis += x_ndim;
  }
  PADDLE_ENFORCE_EQ(axis >= 0 && axis < x_ndim,
                    true,
                    common::errors::InvalidArgument(
                        "axis must be in range [-%d, %d), but received [%d]",
                        x_ndim,
                        x_ndim,
                        axis));

  // Step1: Build Einsum Notation
  // e.g. axis=1, out_grad: abc -> x: a1c, index: abc, x_grad: a1c
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string out_grad_axes =
      GetBroadcastAxes(out_grad_ndim, out_grad_ndim, alphabet);
  std::string index_axes = out_grad_axes;
  std::string x_axes = index_axes;
  // "1" marks a dim that must stay replicated.
  x_axes.replace(axis, 1, "1");
  for (int i = 0; i < index_ndim; ++i) {
    // Non-`axis` dims with mismatched sizes cannot share a sharding.
    if (i != axis && x_shape[i] != index_shape[i]) {
      x_axes.replace(i, 1, "1");
      index_axes.replace(i, 1, "1");
      out_grad_axes.replace(i, 1, "1");
    }
  }
  std::string x_grad_axes = x_axes;
  // NOTE(review): when out_grad is sharded along `axis`, x_grad is a
  // scatter-add and would presumably carry a partial state on that mesh dim;
  // this rule keeps x_grad replicated along `axis` instead — confirm against
  // the take_along_axis_grad kernel's expectations.

  // Step2: Sharding Propagation
  // Step2.1: Merge input shardings (out_grad is the only propagation source).
  std::vector<int64_t> out_grad_dims_mapping(out_grad_dims_mapping_src);
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{out_grad_axes, out_grad_dims_mapping}});

  // Step2.2: Infer each tensor's dims mapping from the merged map.
  // Derive every mapping from its own axes string; index_axes and
  // out_grad_axes are currently identical, but deriving out_grad from
  // index's mapping would silently break if that ever changes.
  auto index_dist_attr_dst = CopyTensorDistAttrForOutput(index_dist_attr_src);
  index_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(index_axes, axis_to_dim_map));

  auto out_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(out_grad_dist_attr_src);
  out_grad_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(out_grad_axes, axis_to_dim_map));

  auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(x_axes, axis_to_dim_map));

  auto x_grad_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_grad_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(x_grad_axes, axis_to_dim_map));

  VLOG(4) << "out_grad";
  VLOG(4) << "dist_attr: [" << out_grad_dist_attr_dst.to_string() << "]";
  VLOG(4) << "index";
  VLOG(4) << "dist_attr: [" << index_dist_attr_dst.to_string() << "]";
  VLOG(4) << "x";
  VLOG(4) << "dist_attr: [" << x_dist_attr_dst.to_string() << "]";
  VLOG(4) << "x_grad";
  VLOG(4) << "dist_attr: [" << x_grad_dist_attr_dst.to_string() << "]";

  return {{x_dist_attr_dst, index_dist_attr_dst, out_grad_dist_attr_dst},
          {x_grad_dist_attr_dst}};
}
160+
} // namespace phi::distributed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

/// Infers the SPMD sharding attributes for the forward take_along_axis op.
/// Merges the shardings of `x` and `index` (keeping `x` unsharded on `axis`)
/// and returns dist attrs for {x, index} inputs and the {out} output.
SpmdInfo TakeAlongAxisInferSpmd(const DistMetaTensor& x,
                                const DistMetaTensor& index,
                                int axis);

/// Infers the SPMD sharding attributes for the take_along_axis gradient op.
/// Propagates `out_grad`'s sharding back to the inputs and returns dist attrs
/// for {x, index, out_grad} inputs and the {x_grad} output.
SpmdInfo TakeAlongAxisGradInferSpmd(const DistMetaTensor& x,
                                    const DistMetaTensor& index,
                                    const DistMetaTensor& out_grad,
                                    int axis);

}  // namespace distributed
}  // namespace phi

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3577,6 +3577,7 @@
35773577
infer_meta :
35783578
func : UnchangedInferMeta
35793579
param : [arr]
3580+
spmd_rule : TakeAlongAxisGradInferSpmd
35803581
kernel :
35813582
func : take_along_axis_grad
35823583
backward : take_along_axis_double_grad

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5229,6 +5229,7 @@
52295229
infer_meta :
52305230
func : TakeAlongAxisInferMeta
52315231
param : [arr, indices, axis]
5232+
spmd_rule : TakeAlongAxisInferSpmd
52325233
kernel :
52335234
func : take_along_axis
52345235
data_type : arr

test/auto_parallel/spmd_rules/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ if(WITH_DISTRIBUTE)
6161
py_test_modules(test_fused_gemm_epilogue_rule MODULES
6262
test_fused_gemm_epilogue_rule)
6363
py_test_modules(test_gelu_rule MODULES test_gelu_rule)
64+
py_test_modules(test_take_along_axis_rule MODULES test_take_along_axis_rule)
6465
endif()
6566
# End of unittests WITH single card WITHOUT timeout
6667

0 commit comments

Comments
 (0)