Commit 7404342

Support falling back to sharding by one mesh dim when a tensor dim is sharded by more than one (#74396)

* support falling back to sharding by one mesh dim when sharded by more than one
* add test case
1 parent ee517b2 commit 7404342

3 files changed: +35 −12 lines

paddle/phi/core/distributed/auto_parallel/dist_attr.cc

Lines changed: 25 additions & 9 deletions
@@ -560,15 +560,31 @@ TensorDistAttr::DimMapProxy::operator const std::vector<int64_t>&() const {
 void TensorDistAttr::DimMapProxy::sync_1d_map() const {
   dims_mapping_1d.resize(dims_mapping_2d->size());
   for (size_t i = 0; i < dims_mapping_2d->size(); ++i) {
-    PADDLE_ENFORCE_LE(dims_mapping_2d->at(i).size(),
-                      1,
-                      common::errors::InvalidArgument(
-                          "There are %d mesh dim sharded on tensor dim %d,"
-                          "you should call \"multi_dims_mapping()\"",
-                          dims_mapping_2d->at(i).size(),
-                          i));
-    dims_mapping_1d[i] =
-        (*dims_mapping_2d)[i].empty() ? -1 : (*dims_mapping_2d)[i][0];
+    size_t num_mesh_dim = dims_mapping_2d->at(i).size();
+    if (num_mesh_dim <= 1) {
+      dims_mapping_1d[i] =
+          (*dims_mapping_2d)[i].empty() ? -1 : (*dims_mapping_2d)[i][0];
+      continue;
+    }
+
+    int64_t max_mesh_dim = (*dims_mapping_2d)[i][0];
+    int64_t max_mesh_dim_size = process_mesh.shape()[max_mesh_dim];
+
+    for (size_t j = 1; j < num_mesh_dim; ++j) {
+      int64_t cur_mesh_dim = (*dims_mapping_2d)[i][j];
+      int64_t cur_mesh_dim_size = process_mesh.shape()[cur_mesh_dim];
+
+      if (cur_mesh_dim_size > max_mesh_dim_size) {
+        max_mesh_dim = cur_mesh_dim;
+        max_mesh_dim_size = cur_mesh_dim_size;
+      }
+    }
+
+    LOG(WARNING) << "There are " << num_mesh_dim << " mesh dims sharded on "
+                 << "tensor dim " << i << ". Falling back to sharding by "
+                 << "mesh dim " << max_mesh_dim << ".";
+
+    dims_mapping_1d[i] = max_mesh_dim;
   }
 }
 
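In plain terms, sync_1d_map() used to reject a tensor dim sharded over more than one mesh dim; it now falls back to the single mesh dim with the largest size (the first one wins on ties, per the strict ">") and logs a warning. A minimal standalone Python sketch of that rule — the function and variable names here are invented for illustration and are not Paddle API:

def fallback_to_one_mesh_dim(mesh_shape, dims_mapping_2d):
    """Collapse each tensor dim's list of mapped mesh dims to a single one."""
    dims_mapping_1d = []
    for mapped in dims_mapping_2d:
        if len(mapped) <= 1:
            # No mapped mesh dim means replicated (-1); exactly one is kept.
            dims_mapping_1d.append(mapped[0] if mapped else -1)
        else:
            # Fallback: keep the mapped mesh dim with the largest size.
            # Python's max keeps the first maximum, matching the C++ loop.
            dims_mapping_1d.append(max(mapped, key=lambda d: mesh_shape[d]))
    return dims_mapping_1d

# Tensor dim 0 is co-sharded on mesh dims 0 and 1 of a [2, 4] mesh; mesh dim 1
# is larger, so the 1d mapping falls back to it. Tensor dim 1 stays replicated.
assert fallback_to_one_mesh_dim([2, 4], [[0, 1], []]) == [1, -1]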

paddle/phi/core/distributed/auto_parallel/dist_attr.h

Lines changed: 5 additions & 3 deletions
@@ -231,8 +231,9 @@ class TEST_API TensorDistAttr {
   // delete it after all 1d vector dims_mapping_ have been upgraded to 2d.
   class DimMapProxy final {
    public:
-    DimMapProxy(std::vector<std::vector<int64_t>>* dims_mapping_2d)
-        : dims_mapping_2d(dims_mapping_2d) {}
+    DimMapProxy(std::vector<std::vector<int64_t>>* dims_mapping_2d,
+                const ProcessMesh& process_mesh)
+        : dims_mapping_2d(dims_mapping_2d), process_mesh(process_mesh) {}
 
     DimMapProxy& operator=(
         const std::vector<std::vector<int64_t>>& dims_mapping);
@@ -248,6 +249,7 @@ class TEST_API TensorDistAttr {
     void sync_2d_map();
     mutable std::vector<int64_t> dims_mapping_1d;
     std::vector<std::vector<int64_t>>* dims_mapping_2d;
+    const ProcessMesh& process_mesh;
   };
 
   static std::vector<std::string> fields_;
@@ -266,7 +268,7 @@ class TEST_API TensorDistAttr {
 
   std::vector<std::vector<int64_t>> dims_mapping_;
   // for short time, backward compatible for existing spmd rules.
-  DimMapProxy dims_mapping_proxy{&dims_mapping_};
+  DimMapProxy dims_mapping_proxy{&dims_mapping_, process_mesh_};
 };
 
 inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
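The header change is just the plumbing for that rule: DimMapProxy previously saw only the 2d mapping, but the fallback has to compare mesh dim sizes, so the proxy now also carries a const ProcessMesh& from its owning TensorDistAttr. A toy Python analog (illustrative only, reusing fallback_to_one_mesh_dim from the sketch above):

class DimMapProxyToy:
    """Toy stand-in for DimMapProxy: shares the 2d mapping, keeps the mesh."""

    def __init__(self, dims_mapping_2d, mesh_shape):
        self.dims_mapping_2d = dims_mapping_2d  # shared with the owner, not copied
        self.mesh_shape = mesh_shape            # consulted only for the fallback

    def as_1d(self):
        # Mirrors sync_1d_map(): one mesh dim (or -1) per tensor dim.
        return fallback_to_one_mesh_dim(self.mesh_shape, self.dims_mapping_2d)

Holding a reference rather than a copy fits the proxy's stated role as a short-lived compatibility shim over dims_mapping_.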

test/auto_parallel/co_shard.py

Lines changed: 5 additions & 0 deletions
@@ -189,6 +189,11 @@ def run_test_case_4(self):
             str(relu_out.placements[1]), "Shard(dim=0, shard_order=1)"
         )
 
+        # test fallback to shard by one mesh dim.
+        add_out = paddle.add(relu_out, relu_out)
+        np.testing.assert_equal(str(add_out.placements[0]), "Shard(dim=0)")
+        np.testing.assert_equal(str(add_out.placements[1]), "Replicate()")
+
     def run_test_case_main(self):
         self.basic_interface_case()
         self.run_test_case_0()
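The new assertions line up with the fallback rule. Assuming the mesh in this test is square (e.g. shape [2, 2]; the setup is outside this excerpt), the two mesh dims co-sharding tensor dim 0 tie in size, so the first mapped mesh dim wins: tensor dim 0 stays sharded on mesh dim 0 and mesh dim 1 carries nothing, i.e. placements [Shard(dim=0), Replicate()]. In terms of the sketch above:

# Hypothetical square mesh: equal sizes tie, so the first mapped mesh dim
# (0) is kept for tensor dim 0, and nothing remains mapped to mesh dim 1.
assert fallback_to_one_mesh_dim([2, 2], [[0, 1], []]) == [0, -1]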
