@@ -20,7 +20,10 @@
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h"
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/kernels/concat_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/reduce_scatter_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {
@@ -43,51 +46,132 @@ bool PToSReshardFunction::IsSuitable(const DistTensor& in,
return true;
}

void PToSReshardFunction::Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
VLOG(3) << "Call " << Name();
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
auto dtype = in.dtype();
const auto& logical_ddim = in.dims();

int out_split_axis =
GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first;

void ReshardPToSWithPadding(DeviceContext* dev_ctx,
int64_t split_axis,
const std::vector<int64_t>& process_ids,
const DenseTensor& in,
int64_t padding_nums,
DenseTensor* out) {
DenseTensor in_reduce_scatter;
std::vector<int> axis;
if (out_split_axis != 0) {
const auto& logical_ddim = in.dims();
auto dtype = in.dtype();

if (split_axis != 0) {
for (size_t i = 0; i < common::vectorize(logical_ddim).size(); ++i) {
axis.emplace_back(i);
}
std::swap(axis[0], axis[out_split_axis]);
RESHARD_FUNCTOR(
dev_ctx, Transpose, dtype, in.value(), axis, &in_reduce_scatter);
std::swap(axis[0], axis[split_axis]);
RESHARD_FUNCTOR(dev_ctx, Transpose, dtype, in, axis, &in_reduce_scatter);
} else {
in_reduce_scatter.ShareDataWith(in.value());
in_reduce_scatter.ShareDataWith(in);
}

DenseTensor out_reduce_scatter;
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
ReduceScatter,
dtype,
in_process_ids,
process_ids,
in_reduce_scatter,
static_cast<int64_t>(in_process_ids.size()),
static_cast<int64_t>(process_ids.size()),
&out_reduce_scatter);

if (out_split_axis != 0) {
DenseTensor out_result;
if (split_axis != 0) {
RESHARD_FUNCTOR(
dev_ctx, Transpose, dtype, out_reduce_scatter, axis, &out_result);
} else {
out_result.ShareDataNoCheckWith(out_reduce_scatter);
}

int64_t cur_global_rank = GetCurGlobalRank();
if (cur_global_rank == process_ids.back() && padding_nums != 0) {
std::vector<DenseTensor> tmp_out_vec;
IntArray tmp_sections(std::vector<int64_t>{
out_result.dims()[split_axis] - padding_nums, padding_nums});
RESHARD_FUNCTOR(dev_ctx,
Transpose,
Split,
dtype,
out_reduce_scatter,
axis,
GetMutableTensor(out));
out_result,
tmp_sections,
split_axis,
&tmp_out_vec);
// TODO(liyurui): Since we cannot distinguish a local tensor with, e.g.,
// [0, 10] shape from an uninitialized tensor, we use a tricky solution here:
// give such a local tensor a small allocation to make it different from an
// uninitialized tensor in the pipeline strategy.
if (tmp_out_vec[0].dims()[split_axis] == 0) {
tmp_out_vec[0].mutable_data(tmp_out_vec[0].place(), 4);
}
out->ShareDataNoCheckWith(tmp_out_vec[0]);
} else {
out->ShareDataNoCheckWith(out_result);
}
}

void PToSReshardFunction::Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
VLOG(3) << "Call " << Name();
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();

int out_split_axis =
GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first;
int64_t num_of_process = in_process_mesh.size();
int64_t num_of_padding = in.dims()[out_split_axis] % num_of_process;
bool is_balanced_split = (num_of_padding == 0);

if (is_balanced_split) {
VLOG(3) << "Balanced reshard from partial to shard";
ReshardPToSWithPadding(dev_ctx,
out_split_axis,
in_process_ids,
in.value(),
/*padding_nums*/ 0,
GetMutableTensor(out));
} else {
SetValue(out, out_reduce_scatter);
VLOG(3) << "Unbalanced reshard from partial to shard";
int64_t avg_size_on_split_axis =
(in.dims()[out_split_axis] + num_of_process - 1) / num_of_process;
int64_t padding_nums =
avg_size_on_split_axis * num_of_process - in.dims()[out_split_axis];

DDim concat_local_shape = in.local_dims();
concat_local_shape[out_split_axis] = padding_nums;
IntArray concat_local_shape_int_array(concat_local_shape.Get(),
concat_local_shape.size());
auto dtype = in.dtype();

DenseTensor concat_local_tensor;
RESHARD_FUNCTOR(dev_ctx,
Full,
dtype,
concat_local_shape_int_array,
0,
&concat_local_tensor);

DenseTensor in_local_tensor = in.value();
std::vector<const DenseTensor*> concat_input_vec = {&in_local_tensor,
&concat_local_tensor};

DenseTensor concat_result;
RESHARD_FUNCTOR(dev_ctx,
Concat,
dtype,
concat_input_vec,
out_split_axis,
&concat_result);

ReshardPToSWithPadding(dev_ctx,
out_split_axis,
in_process_ids,
concat_result,
padding_nums,
GetMutableTensor(out));
}

SetDistProps(out, in.dims(), out_dist_attr);
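Note on the unbalanced partial-to-shard branch above: when the size of the split axis is not divisible by the process count, the local (partial) tensor is first zero-padded up to the next multiple of the process count, reduce-scattered, and the last rank then trims the padding off its chunk. The zero padding is neutral assuming the partial state reduces by sum. A minimal sketch of the padding arithmetic, in plain Python with illustrative numbers (not Paddle API):

```python
def p_to_s_padding(dim_on_split_axis: int, num_of_process: int):
    # Ceil division: every rank owns the same chunk size after padding.
    avg_size = (dim_on_split_axis + num_of_process - 1) // num_of_process
    padding_nums = avg_size * num_of_process - dim_on_split_axis
    return avg_size, padding_nums

# 10 rows over 4 processes: pad 10 -> 12, reduce-scatter 3 rows to each rank,
# then the last rank splits [3 - 2, 2] and keeps only the 1-row piece.
assert p_to_s_padding(10, 4) == (3, 2)
```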
@@ -147,10 +147,12 @@ std::map<int, int64_t> GetSplitAxisWithDimsMapping(
}

std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
std::vector<int64_t> result(num_of_pieces, total_nums / num_of_pieces);
int64_t remain_nums = total_nums % num_of_pieces;
for (int64_t i = 0; i < remain_nums; ++i) {
result[i] += 1;
bool has_remainder = (total_nums % num_of_pieces != 0);
std::vector<int64_t> result(num_of_pieces,
(total_nums + num_of_pieces - 1) / num_of_pieces);
if (has_remainder) {
int64_t& last_value = result.back();
last_value = last_value - (last_value * num_of_pieces - total_nums);
}
return result;
}
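The rewritten BalancedSplit gives every piece the ceiling size and pushes the entire shortfall onto the last piece (the previous version spread the remainder as +1 over the leading pieces). This matches the padding scheme above, where only the last rank holds a short chunk. A plain-Python equivalent, for illustration only:

```python
def balanced_split(total_nums: int, num_of_pieces: int):
    piece = (total_nums + num_of_pieces - 1) // num_of_pieces  # ceil division
    result = [piece] * num_of_pieces
    if total_nums % num_of_pieces != 0:
        # The last piece absorbs the whole shortfall.
        result[-1] = piece - (piece * num_of_pieces - total_nums)
    return result

assert balanced_split(10, 4) == [3, 3, 3, 1]  # previously [3, 3, 2, 2]
assert balanced_split(12, 4) == [3, 3, 3, 3]  # the even case is unchanged
```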
@@ -35,7 +35,7 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx,
int64_t split_axis,
const std::vector<int64_t>& process_ids,
const DenseTensor& in,
int64_t num_of_padding,
int64_t padding_nums,
DenseTensor* out) {
int64_t num_of_process = process_ids.size();
auto dtype = in.dtype();
@@ -46,7 +46,7 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx,
RESHARD_FUNCTOR_WITH_COMM(
dev_ctx, AllGather, dtype, process_ids, in, num_of_process, out);

if (split_axis != 0 || num_of_padding != 0) {
if (split_axis != 0 || padding_nums != 0) {
IntArray sections(std::vector<int64_t>(num_of_process, in.dims()[0]));

std::vector<DenseTensor> split_out_vec;
@@ -58,20 +58,18 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx,
/*split_axis*/ 0,
&split_out_vec);

if (num_of_padding != 0) {
for (int64_t i = num_of_padding; i < num_of_process; ++i) {
std::vector<DenseTensor> tmp_out_vec;
IntArray tmp_sections(
std::vector<int64_t>{in.dims()[split_axis] - 1, 1});
RESHARD_FUNCTOR(dev_ctx,
Split,
dtype,
split_out_vec[i],
tmp_sections,
split_axis,
&tmp_out_vec);
split_out_vec[i] = tmp_out_vec[0];
}
if (padding_nums != 0) {
std::vector<DenseTensor> tmp_out_vec;
IntArray tmp_sections(std::vector<int64_t>{
in.dims()[split_axis] - padding_nums, padding_nums});
RESHARD_FUNCTOR(dev_ctx,
Split,
dtype,
split_out_vec[num_of_process - 1],
tmp_sections,
split_axis,
&tmp_out_vec);
split_out_vec[num_of_process - 1] = tmp_out_vec[0];
}

// Concat the result after split on correct axis.
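With this sizing convention, only the last of the gathered chunks carries padding, so the trim becomes a single Split on that chunk rather than the old per-chunk loop. The sections passed to that Split, sketched in plain Python with illustrative numbers:

```python
# e.g. 10 rows over 4 ranks: each padded local chunk holds 3 rows, 2 of them padding
local_padded_size, padding_nums = 3, 2
tmp_sections = [local_padded_size - padding_nums, padding_nums]
assert tmp_sections == [1, 2]  # keep the 1-row piece, drop the padded rows
```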
@@ -124,15 +122,19 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
split_axis,
in_process_ids,
in.value(),
num_of_padding,
/*padding_nums*/ 0,
GetMutableTensor(out));
} else {
VLOG(3) << "Unbalanced reshard from shard to replicated";
bool need_padding =
(in.dims()[split_axis] / num_of_process == in.local_dims()[split_axis]);
int64_t avg_size_on_split_axis =
(in.dims()[split_axis] + num_of_process - 1) / num_of_process;
int64_t padding_nums =
avg_size_on_split_axis * num_of_process - in.dims()[split_axis];
bool need_padding = (in.local_dims()[split_axis] != avg_size_on_split_axis);

if (need_padding) {
DDim concat_local_shape = in.local_dims();
concat_local_shape[split_axis] = 1;
concat_local_shape[split_axis] = padding_nums;
IntArray concat_local_shape_int_array(concat_local_shape.Get(),
concat_local_shape.size());
auto dtype = in.dtype();
Expand All @@ -156,14 +158,14 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
split_axis,
in_process_ids,
concat_result,
num_of_padding,
padding_nums,
GetMutableTensor(out));
} else {
ReshardSToRWithPadding(dev_ctx,
split_axis,
in_process_ids,
in.value(),
num_of_padding,
padding_nums,
GetMutableTensor(out));
}
}
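In the unbalanced shard-to-replicated path, a rank needs to pad its local chunk before AllGather exactly when that chunk is smaller than the ceiling size, which under BalancedSplit above is only the last rank. A quick check of the decision, in plain Python with illustrative numbers:

```python
import math

def need_padding(local_size: int, dim_on_split_axis: int, num_of_process: int):
    avg_size = math.ceil(dim_on_split_axis / num_of_process)
    return local_size != avg_size

# 10 rows over 4 ranks are split [3, 3, 3, 1]:
assert not need_padding(3, 10, 4)  # ranks 0..2 already hold avg-sized chunks
assert need_padding(1, 10, 4)      # the last rank pads 1 -> 3 before AllGather
```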
@@ -173,24 +175,13 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
bool SToRReshardFunctionCrossMesh::IsSuitable(
const DistTensor& in, const TensorDistAttr& out_dist_attr) {
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr.dims_mapping();

RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard());
RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated());

const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();

int64_t cur_global_rank = GetCurGlobalRank();
if (in_process_mesh.contains(cur_global_rank)) {
int split_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
int64_t num_of_process = in_process_mesh.size();
RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast<int>(split_axis)] *
num_of_process ==
in.dims()[static_cast<int>(split_axis)]);
}

RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() ==
18 changes: 9 additions & 9 deletions test/auto_parallel/CMakeLists.txt
@@ -104,19 +104,19 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_reshard_s_to_s MODULES test_reshard_s_to_s)
set_tests_properties(test_reshard_s_to_s
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s)
set_tests_properties(test_reshard_r_to_s
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 320)
py_test_modules(test_reshard_p_to_r MODULES test_reshard_p_to_r)
set_tests_properties(test_reshard_p_to_r
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 160)
py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
set_tests_properties(test_reshard_s_to_r
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 150)
if(NOT WITH_COVERAGE)
py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s)
set_tests_properties(test_reshard_r_to_s
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 160)
py_test_modules(test_reshard_p_to_r MODULES test_reshard_p_to_r)
set_tests_properties(test_reshard_p_to_r
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
py_test_modules(test_pipeline_scheduler MODULES test_pipeline_scheduler)
set_tests_properties(test_pipeline_scheduler
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 400)
py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
set_tests_properties(test_reshard_s_to_r
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 150)
endif()
py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
set_tests_properties(test_reshard_r_to_p
2 changes: 0 additions & 2 deletions test/auto_parallel/reshard_p_to_r.py
@@ -41,8 +41,6 @@ def run_test_case(self):

input_tensor = dist.shard_tensor(a, self._mesh, [dist.Partial()])
out = dist.reshard(input_tensor, self._mesh, [dist.Replicate()])
print(input_tensor)
print(out)

assert np.equal(out.shape, input_tensor.shape).all()
np.testing.assert_equal(out._local_value().numpy(), a.numpy())
27 changes: 22 additions & 5 deletions test/auto_parallel/reshard_p_to_s.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os

import numpy as np
@@ -42,19 +43,35 @@ def reshard_same_mesh(self):
input_tensor = dist.shard_tensor(value, self._mesh, [dist.Partial()])

out_shape = list(self._shape)
out_shape[self._shard] = out_shape[self._shard] // 2
split_value_of_front = math.ceil(
out_shape[self._shard] / self._mesh.shape[0]
)
split_value_of_last = (
split_value_of_front
- split_value_of_front * self._mesh.shape[0]
+ out_shape[self._shard]
)

split_sections = [split_value_of_front] * self._mesh.shape[0]

split_sections[len(split_sections) - 1] = split_value_of_last

if dist.get_rank() == self._mesh.process_ids[self._mesh.shape[0] - 1]:
out_shape[self._shard] = split_value_of_last
else:
out_shape[self._shard] = split_value_of_front

out_expected_local_tensor_list = paddle.split(
value, num_or_sections=self._mesh.shape[0], axis=self._shard
value, num_or_sections=split_sections, axis=self._shard
)

out = dist.reshard(input_tensor, self._mesh, [dist.Shard(self._shard)])

np.testing.assert_equal(
out._local_value().numpy(),
out_expected_local_tensor_list[0].numpy()
if dist.get_rank() == 0
else out_expected_local_tensor_list[1].numpy(),
out_expected_local_tensor_list[dist.get_rank()].numpy(),
)
np.testing.assert_equal(out.numpy(), value.numpy())

assert np.equal(out.shape, input_tensor.shape).all()
assert np.equal(out._local_shape, out_shape).all()
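The expected local shapes in this test mirror the C++ sizing: ceiling-size chunks on every rank except the last, which takes the remainder. Checking the two formulas with illustrative numbers:

```python
import math

shape_on_axis, world_size = 10, 4
front = math.ceil(shape_on_axis / world_size)        # ranks 0 .. world_size - 2
last = front - front * world_size + shape_on_axis    # the last rank
assert [front] * (world_size - 1) + [last] == [3, 3, 3, 1]
```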