4 changes: 4 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -176,6 +176,10 @@ void TensorDistAttr::set_default_dynamic_dims(
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
}

void TensorDistAttr::set_default_dynamic_dims(int64_t tensor_shape_size) {
dynamic_dims_ = std::vector<bool>(tensor_shape_size, false);
}

void TensorDistAttr::mark_annotated(const std::string& name) {
auto result = std::find(std::begin(fields_), std::end(fields_), name);
if (result != std::end(fields_)) {
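A hypothetical usage sketch (not part of the diff) of the new overload: callers that know only the tensor's rank can now initialize the default dynamic dims without building a shape vector.

TensorDistAttr attr;
attr.set_default_dynamic_dims({8, 16, 32});  // existing overload: full shape
attr.set_default_dynamic_dims(3);            // new overload: rank only
// Either call leaves dynamic_dims_ as {false, false, false}.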
2 changes: 2 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.h
@@ -156,6 +156,8 @@ class TEST_API TensorDistAttr {

void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);

void set_default_dynamic_dims(int64_t tensor_shape_size);

const std::map<std::string, bool>& annotated() const { return annotated_; }

void set_annotated(const std::map<std::string, bool>& annotated);
256 changes: 255 additions & 1 deletion paddle/phi/infermeta/spmd_rules/dim_trans.cc
@@ -161,7 +161,7 @@ std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans> dim,
  // map from the index in shape to the index in new_shape
std::vector<int64_t> idx_map(shape.size(), -1);
for (int i = 0, n = static_cast<int>(shape.size()); i < n; ++i) {
-    if (shape[id] != 1) {
+    if (shape[i] != 1) {
idx_map[i] = static_cast<int64_t>(new_shape.size());
new_shape.emplace_back(shape[i]);
}
@@ -272,6 +272,139 @@ std::vector<std::shared_ptr<DimTrans>> GetDimTrans(
return ret_dim_trans;
}

std::vector<std::shared_ptr<DimTrans>> GetDimTransCoShard(
const std::shared_ptr<DimTrans> dim_trans,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& mesh_shape,
const std::vector<std::vector<int64_t>>& input_dims_mapping,
const std::set<int64_t>& sharded_input_dims,
std::vector<std::vector<bool>>* shardable,
std::set<int64_t>* seen_dims) {
DimTrans::Type type = dim_trans->type();
std::vector<std::shared_ptr<DimTrans>> ret_dim_trans;

if (type == DimTrans::Type::INPUTDIM) {
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(dim_trans);
int64_t dim = inputdim->input_dim();
seen_dims->insert(dim);

if (sharded_input_dims.count(dim) > 0) {
ret_dim_trans.push_back(dim_trans);
}
} else if (type == DimTrans::Type::FLATTEN) {
std::shared_ptr<Flatten> flatten =
std::dynamic_pointer_cast<Flatten>(dim_trans);
const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();

int64_t nmesh = (*shardable)[0].size(); // NOLINT
int64_t mesh_shape_prod = 1;

int last_shard_idx = -1;
int64_t first_shard_idx = -1;
int64_t first_sharded_shape = -1;

for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
std::shared_ptr<DimTrans> input = inputs[i];
if (input->type() != DimTrans::Type::INPUTDIM) {
break;
}
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(input);
if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
if (first_shard_idx == -1) {
first_shard_idx = i;
first_sharded_shape = input_shape[inputdim->input_dim()];
}
for (const auto& dim : input_dims_mapping[inputdim->input_dim()]) {
mesh_shape_prod *= mesh_shape[dim];
}
if (first_sharded_shape % mesh_shape_prod == 0) {
ret_dim_trans.push_back(inputdim);
} else {
break;
}
} else {
break;
}
last_shard_idx = i;
}
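    // Illustrative trace (assumed values): flattening input dims of shape
    // {8, 4} on a {2, 4} mesh with dims_mapping {{0}, {1}} accumulates
    // mesh_shape_prod = 2, then 8; since 8 % 2 == 0 and 8 % 8 == 0, both
    // dims are kept and the flattened dim is co-sharded on mesh dims 0 and 1.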

for (int i = last_shard_idx + 1, n = static_cast<int>(inputs.size()); i < n;
i++) {
std::shared_ptr<DimTrans> input = inputs[i];
if (input->type() == DimTrans::Type::INPUTDIM) {
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(input);
(*shardable)[inputdim->input_dim()].assign(nmesh, false);
}

GetDimTransCoShard(input,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims,
shardable,
seen_dims);
}
} else if (type == DimTrans::Type::SPLIT) {
std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
std::vector<std::shared_ptr<DimTrans>> dims =
GetDimTransCoShard(split->input(),
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims,
shardable,
seen_dims);
int64_t ret_size = split->local_split_shape_value();

if (split->split_id() == 0) {
int64_t mesh_shape_prod = 1;
int64_t first_shard_idx = -1;
int64_t first_sharded_shape = -1;
for (const auto& dim : dims) {
PADDLE_ENFORCE_EQ(dim->type(),
DimTrans::Type::INPUTDIM,
common::errors::InvalidArgument(
"The returned dim_trans must be INPUTDIM."));
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(dim);
int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
int64_t input_axis = inputdim->input_dim();

        // Check whether the sharded dim can be sharded on
        // each mesh dimension. The dimension should be
        // divisible by the size of the mesh dimension it is sharded on.
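        // e.g. (assumed values) with mesh_shape = {2, 3} and ret_size = 4,
        // the dim is shardable on mesh dim 0 (4 % 2 == 0) but not on
        // mesh dim 1 (4 % 3 != 0).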
for (int64_t imesh = 0; imesh < nmesh; imesh++) {
(*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
}

if (first_shard_idx == -1) {
first_shard_idx = input_axis;
first_sharded_shape = input_shape[input_axis];
}

if (sharded_input_dims.count(input_axis) > 0) {
for (const auto& dim : input_dims_mapping[input_axis]) {
mesh_shape_prod *= mesh_shape[dim];
}
if ((ret_size % mesh_shape_prod == 0) &&
(first_sharded_shape % mesh_shape_prod == 0)) {
ret_dim_trans.push_back(dim);
} else {
break;
}
} else {
break;
}
}
}
} else if (type == DimTrans::Type::SINGLETON) {
}
return ret_dim_trans;
}
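To make the core divisibility rule used throughout GetDimTransCoShard concrete, here is a small self-contained sketch (CoShardDivisible and all values are illustrative, not Paddle API): a dimension co-sharded across several mesh dimensions stays sharded only if its size is divisible by the product of those mesh sizes.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative helper, not Paddle API: returns true when a dim of size
// dim_size, mapped onto the mesh dims in dims_mapping, divides evenly.
bool CoShardDivisible(int64_t dim_size,
                      const std::vector<int64_t>& dims_mapping,
                      const std::vector<int64_t>& mesh_shape) {
  int64_t mesh_shape_prod = 1;
  for (int64_t mesh_dim : dims_mapping) {
    mesh_shape_prod *= mesh_shape[mesh_dim];
  }
  return dim_size % mesh_shape_prod == 0;
}

int main() {
  const std::vector<int64_t> mesh_shape = {2, 4};
  // A dim of size 8 co-sharded on both mesh dims (2 * 4 = 8): divisible.
  std::cout << CoShardDivisible(8, {0, 1}, mesh_shape) << "\n";  // prints 1
  // A dim of size 6 sharded on mesh dim 1 (size 4): 6 % 4 != 0.
  std::cout << CoShardDivisible(6, {1}, mesh_shape) << "\n";     // prints 0
  return 0;
}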

void GetUsedInputDim(const std::shared_ptr<DimTrans> dim_trans,
std::set<int64_t>* seen_dims) {
if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
@@ -311,6 +444,27 @@ InferFromDimTrans(const DistMetaTensor& input_spec,
return InferFromDimTrans(input_spec, input_shape, dim_trans);
}

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
const DistMetaTensor& input_spec,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
auto input_shape = phi::vectorize(input_spec.dims());
  // deal with reshape's xshape output in dynamic mode (dims carry a leading 0)
if (input_shape[0] == 0 &&
input_shape.size() !=
input_spec.dist_attr().multi_dims_mapping().size()) {
input_shape.erase(input_shape.begin());
}
PADDLE_ENFORCE_EQ(input_shape.size(),
input_spec.dist_attr().multi_dims_mapping().size(),
common::errors::InvalidArgument(
"The Tensor X's rank [%d] and X's "
"dims_mapping size [%d] are not matched.",
input_shape.size(),
input_spec.dist_attr().multi_dims_mapping().size()));
return InferFromDimTransCoShard(input_spec, input_shape, dim_trans);
}
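For intuition on the xshape branch above: in Paddle, reshape's auxiliary xshape output carries the input's shape prefixed with a 0, so in dynamic mode the wrapper strips that sentinel before comparing ranks. A minimal standalone sketch with assumed values:

#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> shape = {0, 4, 8};  // assumed xshape-style dims
  const size_t mapping_rank = 2;           // dims_mapping describes [4, 8]
  if (!shape.empty() && shape[0] == 0 && shape.size() != mapping_rank) {
    shape.erase(shape.begin());            // drop the leading sentinel 0
  }
  // shape is now {4, 8}, matching the dims_mapping rank.
  return 0;
}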

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTrans(const DistMetaTensor& input,
const std::vector<int64_t>& input_shape,
@@ -400,4 +554,104 @@ InferFromDimTrans(const DistMetaTensor& input,
return {new_input_dims_mapping, out_dims_mapping};
}

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
const DistMetaTensor& input,
const std::vector<int64_t>& input_shape,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
const std::vector<std::vector<int64_t>>& input_dims_mapping =
input.dist_attr().multi_dims_mapping();
const ProcessMesh& mesh = input.dist_attr().process_mesh();
const std::vector<int64_t>& mesh_shape = mesh.shape();

std::set<int64_t> sharded_input_dims;
for (int64_t i = 0, n = static_cast<int64_t>(input_dims_mapping.size());
i < n;
++i) {
if (std::any_of(input_dims_mapping[i].begin(),
input_dims_mapping[i].end(),
[](int64_t dim) { return dim > -1; })) {
sharded_input_dims.insert(i);
}
}
int64_t ndim = static_cast<int64_t>(input_shape.size());
int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
std::vector<std::vector<bool>> shardable(ndim,
std::vector<bool>(nmesh, true));

std::set<int64_t> seen_input_dims;
for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
GetUsedInputDim(trans, &seen_input_dims);
}

for (int64_t idim = 0; idim < ndim; idim++) {
bool seen = seen_input_dims.count(idim);
if (!seen) {
shardable[idim].assign(nmesh, seen);
}
}

  // build the maps between sharded input dimensions and output dimensions:
  // dim_map_src2tgt[src] = dst; dim_map_dst2src[dst] = {src, ...}.
std::vector<int64_t> dim_map_src2tgt(ndim, -1);
std::unordered_map<int, std::vector<int>> dim_map_dst2src;
for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
std::vector<std::shared_ptr<DimTrans>> dims =
GetDimTransCoShard(dim_trans[i],
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims,
&shardable,
&seen_input_dims);
for (auto dim : dims) {
if (dim->type() == DimTrans::Type::INPUTDIM) {
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(dim);
dim_map_src2tgt[inputdim->input_dim()] = i;
dim_map_dst2src[i].push_back(inputdim->input_dim());
}
}
}

std::vector<std::vector<int64_t>> out_dims_mapping(dim_trans.size());
std::vector<std::vector<int64_t>> new_input_dims_mapping(
input_dims_mapping.size());

// set output dims mapping with corresponding input dimensions.
  // if an input dimension is sharded on an unshardable mesh dimension
  // after splitting, we need to make it replicated.
for (int64_t i = 0; i < ndim; i++) {
const auto& mesh_dims = input_dims_mapping[i];
if (!std::all_of(mesh_dims.begin(),
mesh_dims.end(),
[](int64_t dim) { return dim >= 0; }) ||
dim_map_src2tgt[i] == -1) {
continue;
}

bool is_unshardable = false;
for (const auto& mesh_dim : mesh_dims) {
if (mesh_dim >= 0 && !shardable[i][mesh_dim]) {
is_unshardable = true;
break;
}
}
if (!is_unshardable) {
int dst_dim = dim_map_src2tgt[i];
const auto& src_dims = dim_map_dst2src[dst_dim];
auto min_dim_it = std::min_element(src_dims.begin(), src_dims.end());
int64_t min_dim = *min_dim_it;
out_dims_mapping[dst_dim].insert(
out_dims_mapping[dst_dim].end(), mesh_dims.begin(), mesh_dims.end());
new_input_dims_mapping[min_dim].insert(
new_input_dims_mapping[min_dim].end(),
mesh_dims.begin(),
mesh_dims.end());
}
}

return {new_input_dims_mapping, out_dims_mapping};
}

} // namespace phi::distributed
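As a worked trace of the merge step at the end of InferFromDimTransCoShard (assumed values, standalone rather than a Paddle call): reshaping [8, 4] -> [32] on a {2, 4} mesh with input dims_mapping {{0}, {1}} maps both source dims to output dim 0, so the output mapping collects mesh dims {0, 1} and the merged input mapping is folded onto the lowest source dim.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assumed scenario: reshape [8, 4] -> [32] on a {2, 4} mesh,
  // input_dims_mapping = {{0}, {1}}; both source dims feed output dim 0.
  const std::vector<std::vector<int64_t>> input_dims_mapping = {{0}, {1}};
  const std::vector<int64_t> dim_map_src2tgt = {0, 0};  // src dim -> dst dim
  const int64_t min_src = 0;  // lowest source dim of output dim 0

  std::vector<std::vector<int64_t>> out_dims_mapping(1);
  std::vector<std::vector<int64_t>> new_input_dims_mapping(2);
  for (size_t i = 0; i < input_dims_mapping.size(); ++i) {
    const auto& mesh_dims = input_dims_mapping[i];
    const int64_t dst = dim_map_src2tgt[i];
    out_dims_mapping[dst].insert(
        out_dims_mapping[dst].end(), mesh_dims.begin(), mesh_dims.end());
    new_input_dims_mapping[min_src].insert(
        new_input_dims_mapping[min_src].end(),
        mesh_dims.begin(),
        mesh_dims.end());
  }
  // out_dims_mapping       == {{0, 1}}     : output co-sharded on mesh 0, 1
  // new_input_dims_mapping == {{0, 1}, {}} : co-shard folded onto dim 0
  for (int64_t d : out_dims_mapping[0]) std::cout << d << ' ';  // prints: 0 1
  std::cout << '\n';
  return 0;
}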
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/spmd_rules/dim_trans.h
@@ -158,10 +158,21 @@ std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTrans(const DistMetaTensor& input_spec,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
const DistMetaTensor& input_spec,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTrans(const DistMetaTensor& input_spec,
const std::vector<int64_t>& input_shape,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
const DistMetaTensor& input_spec,
const std::vector<int64_t>& input_shape,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);

} // namespace distributed
} // namespace phi