Skip to content

Commit 966a743

Browse files
committed
fix reshard dist_attr: keep process_mesh_ and placements_ synchronized with dist_attr_ when constructing a DistTensor and when ReshardFunction::SetDistProps updates a tensor's dist_attr
1 parent a08580e commit 966a743

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

paddle/phi/core/distributed/auto_parallel/dist_tensor.cc

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -117,6 +117,9 @@ DistTensor::DistTensor() : value_(std::make_shared<DenseTensor>()) {}
117117
DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
118118
const TensorDistAttr& dist_attr)
119119
: global_dims_(global_value->dims()), dist_attr_(dist_attr) {
120+
process_mesh_ = dist_attr_.process_mesh();
121+
placements_ = ToPlacements(dist_attr);
122+
120123
// If the current rank doesn't in process_mesh, we should create an
121124
// uninitialized tensor only with tensor_meta.
122125
if (IsCurRankInMesh(dist_attr.process_mesh())) {

paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,8 @@ void ReshardFunction::SetDistProps(DistTensor* tensor,
5252

5353
tensor->global_dims_ = dims;
5454
tensor->dist_attr_ = dist_attr;
55+
tensor->process_mesh_ = dist_attr.process_mesh();
56+
tensor->placements_ = ToPlacements(dist_attr);
5557
}
5658

5759
void ReshardFunction::SetDistProps(DistTensor* tensor,
@@ -64,6 +66,8 @@ void ReshardFunction::SetDistProps(DistTensor* tensor,
6466
str_join(vectorize(tensor->dims()))));
6567

6668
tensor->dist_attr_ = dist_attr;
69+
tensor->process_mesh_ = dist_attr.process_mesh();
70+
tensor->placements_ = ToPlacements(dist_attr);
6771
}
6872

6973
DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) {

0 commit comments

Comments
 (0)