30 changes: 22 additions & 8 deletions paddle/fluid/pybind/eager_utils.cc
@@ -117,14 +117,28 @@ void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh) {
return;
}
if (x->is_dist_tensor()) {
PADDLE_ENFORCE_EQ(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(x->impl())
->process_mesh(),
*mesh,
platform::errors::InvalidArgument(
"Input %s has different mesh. However all inputs should "
"have the same mesh.",
x->name()));
auto dist_ptr =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(x->impl());
if (!dist_ptr->skip_check_mesh() && x->dims().size() > 0) {
    // NOTE(pkuzyc): In MoE expert parallelism, the meshes of the
    // inputs and outputs of different experts differ, so the mesh
    // check is skipped in the following two cases:
    // 1. The ``skip_check_mesh_`` flag is true. The MoE-related APIs
    // set this flag to indicate that tensors are allowed to have
    // different meshes.
    // 2. The tensor is a 0-D tensor. Specifically, in MoE expert
    // parallelism the learning rate's mesh is the global mesh while
    // the expert weights' meshes are subsets of it; this is allowed,
    // so the mesh check is skipped for 0-D tensors.
      PADDLE_ENFORCE_EQ(
          dist_ptr->process_mesh(),
          *mesh,
          platform::errors::InvalidArgument(
              "Input %s has a different mesh. However, all inputs should "
              "have the same mesh.",
              x->name()));
}
return;
} else {
PADDLE_ENFORCE_EQ(
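Taken together, the new guard makes ConvertToDistTensor skip the mesh comparison when either the ``skip_check_mesh_`` flag is set or the tensor is 0-D, and otherwise enforce mesh equality as before. A minimal self-contained sketch of that control flow (stand-in types, not Paddle's real DistTensor/ProcessMesh API):

```cpp
// Sketch of the new mesh-check control flow with stand-in types.
// FakeDistTensor models only the fields the check needs; mesh_id
// stands in for a ProcessMesh.
#include <stdexcept>
#include <string>

struct FakeDistTensor {
  std::string name;
  int mesh_id;           // stand-in for the tensor's ProcessMesh
  bool skip_check_mesh;  // set by the MoE-related APIs
  int ndim;              // 0 for 0-D tensors such as a learning rate
};

void CheckMesh(const FakeDistTensor& x, int expected_mesh_id) {
  // Case 1: the MoE APIs set skip_check_mesh so per-expert tensors
  // may live on different sub-meshes.
  // Case 2: 0-D tensors (e.g. a global learning rate) may carry a
  // mesh that differs from the expert weights' sub-mesh.
  if (x.skip_check_mesh || x.ndim == 0) return;
  if (x.mesh_id != expected_mesh_id) {
    throw std::invalid_argument(
        "Input " + x.name +
        " has a different mesh. However, all inputs should have the "
        "same mesh.");
  }
}

int main() {
  FakeDistTensor lr{"learning_rate", /*mesh_id=*/0,
                    /*skip_check_mesh=*/false, /*ndim=*/0};
  CheckMesh(lr, /*expected_mesh_id=*/1);  // skipped: 0-D tensor
  return 0;
}
```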
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/tensor.cc
@@ -1099,6 +1099,8 @@ void BindTensor(pybind11::module &m) { // NOLINT
}
return self;
})
.def("_unsafe_set_skip_check_mesh",
&DistTensor::unsafe_set_skip_check_mesh)
.def("_clear", &DistTensor::clear);
#endif

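The binding itself is a standard pybind11 member-function def. For reference, a minimal self-contained sketch of the same pattern on a toy class (not Paddle's DistTensor):

```cpp
// Minimal pybind11 sketch mirroring the binding above, on a toy class.
// Build as a Python extension module.
#include <pybind11/pybind11.h>

class ToyDistTensor {
 public:
  void unsafe_set_skip_check_mesh(bool skip) { skip_check_mesh_ = skip; }
  bool skip_check_mesh() const { return skip_check_mesh_; }

 private:
  bool skip_check_mesh_ = false;
};

PYBIND11_MODULE(toy_dist, m) {
  pybind11::class_<ToyDistTensor>(m, "ToyDistTensor")
      .def(pybind11::init<>())
      // The leading underscore marks the method as internal/unsafe,
      // matching the convention of _unsafe_set_skip_check_mesh above.
      .def("_unsafe_set_skip_check_mesh",
           &ToyDistTensor::unsafe_set_skip_check_mesh)
      .def("_skip_check_mesh", &ToyDistTensor::skip_check_mesh);
}
```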
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -3007,6 +3007,7 @@
output : Tensor[] {axis<0 ? input.dims()[input.dims().size()+axis]:input.dims()[axis]}
infer_meta :
func : UnbindInferMeta
spmd_rule : UnbindInferSpmdDynamic
kernel :
func : unbind
backward : unbind_grad
5 changes: 5 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -49,6 +49,7 @@ TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
std::swap(this->annotated_, tmp.annotated_);
std::swap(this->partial_status_, tmp.partial_status_);
std::swap(this->skip_check_mesh_, tmp.skip_check_mesh_);
return *this;
}

@@ -60,6 +61,7 @@ void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
set_partial_status(dist_attr.partial_status());
skip_check_mesh_ = dist_attr.skip_check_mesh();
}

void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
@@ -289,6 +291,7 @@ std::string TensorDistAttr::to_string() const {
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "chunk_id: " + std::to_string(chunk_id_) + ", ";
dist_str += "skip_check_mesh: " + std::to_string(skip_check_mesh_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "], ";
dist_str += "partial: " + partial_status_string() + ".}";
Expand Down Expand Up @@ -445,5 +448,7 @@ bool TensorDistAttr::is_partial(int64_t mesh_axis) const {
}
}

void TensorDistAttr::set_skip_check_mesh(bool skip) { skip_check_mesh_ = skip; }

} // namespace distributed
} // namespace phi
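Both the copy-and-swap assignment operator and copy_from must be updated whenever a field is added; a field left out of the swap list would silently keep the destination's old value. A minimal sketch of the idiom with the new flag included (illustrative stand-in class, not TensorDistAttr):

```cpp
// Copy-and-swap sketch: every member, including the new
// skip_check_mesh_ flag, must appear in the swap list, or operator=
// would silently keep the destination's old value for that member.
#include <utility>

class ToyDistAttr {
 public:
  ToyDistAttr& operator=(const ToyDistAttr& other) {
    ToyDistAttr tmp(other);  // copy, then swap state into *this
    std::swap(this->mesh_id_, tmp.mesh_id_);
    std::swap(this->skip_check_mesh_, tmp.skip_check_mesh_);  // new field
    return *this;  // tmp's destructor releases the old state
  }

 private:
  int mesh_id_ = 0;
  bool skip_check_mesh_ = false;
};
```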
6 changes: 6 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.h
@@ -197,6 +197,10 @@ class TEST_API TensorDistAttr {
// if mesh_axis is not -1, check only on specific axis.
bool is_partial(int64_t mesh_axis = -1) const;

void set_skip_check_mesh(bool skip);

bool skip_check_mesh() const { return skip_check_mesh_; }

private:
static std::vector<std::string> fields_;
ProcessMesh process_mesh_;
@@ -209,6 +213,8 @@
  // iteration operations (copy and comparison) occur more frequently than
  // random element access. <key: dim on mesh, value: reduce type>
paddle::flat_hash_map<int64_t, ReduceType> partial_status_;
// The flag indicates whether to skip checking the process mesh.
bool skip_check_mesh_ = false;
};

inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
6 changes: 6 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -304,6 +304,12 @@ void* DistTensor::AllocateFrom(Allocator* allocator,
return nullptr;
}

void DistTensor::unsafe_set_skip_check_mesh(bool skip) {
  VLOG(6) << "You are setting the dist attr of an initialized DistTensor. "
             "Make sure you are aware of where its dist attr is changed.";
dist_attr_.set_skip_check_mesh(skip);
}

void DistTensor::clear() {
if (value_) {
value_->clear();
9 changes: 9 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.h
@@ -178,6 +178,9 @@ class DistTensor final
size_t requested_size = 0,
bool fake_alloc = false) override;

  /// \brief Set the flag indicating whether to skip checking the process mesh.
  /// \note Currently only used by the MoE APIs, which receive inputs with
  /// different process meshes and output dist tensors with the global
  /// process mesh.
  /// \return void
void unsafe_set_skip_check_mesh(bool skip);

bool skip_check_mesh() const { return dist_attr_.skip_check_mesh(); }

void clear();

private:
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/spmd_rules/reduction.h
@@ -27,6 +27,11 @@ SpmdInfo ReductionInferSpmd(const DistMetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim);

SpmdInfo ReductionInferSpmdBase(const DistMetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
int reduce_type);

// This infer spmd function is only used in dynamic mode because it uses
// IntArray as a parameter. The IntArray may contain a vector of tensors,
// which is not supported in static mode. So we separate these two and
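The new ReductionInferSpmdBase overload threads an explicit reduce type through the rule; presumably the existing ReductionInferSpmd can then delegate to it with a default. A hedged sketch of that delegation pattern (assumed shape, stand-in types, not the actual implementation):

```cpp
// Assumed delegation: the default entry point forwards to the Base
// variant with an explicit reduce type. Stand-in types only; the real
// signatures live in reduction.h and use phi's DistMetaTensor.
#include <cstdint>
#include <vector>

struct DistMetaTensor {};  // stand-in
struct SpmdInfo {};        // stand-in
enum class ReduceType : int { kSum = 0, kMax, kMin };

SpmdInfo ReductionInferSpmdBase(const DistMetaTensor& x,
                                const std::vector<std::int64_t>& axis,
                                bool keep_dim,
                                int reduce_type) {
  // ...infer the output dims_mapping and mark partial status according
  // to reduce_type on the reduced mesh axes (elided)...
  (void)x; (void)axis; (void)keep_dim; (void)reduce_type;
  return SpmdInfo{};
}

SpmdInfo ReductionInferSpmd(const DistMetaTensor& x,
                            const std::vector<std::int64_t>& axis,
                            bool keep_dim) {
  // Assumption: sum is the default reduce type.
  return ReductionInferSpmdBase(x, axis, keep_dim,
                                static_cast<int>(ReduceType::kSum));
}
```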
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.cc
@@ -445,6 +445,11 @@ PD_REGISTER_SPMD_RULE(swiglu,
// TODO(pkuzyc): add multiary elementwise rule

// reduction rule
PD_REGISTER_SPMD_RULE(
reduce_base,
PD_INFER_SPMD(phi::distributed::ReductionInferSpmdBase),
PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse));

PD_REGISTER_SPMD_RULE(
all,
PD_INFER_SPMD(phi::distributed::ReductionInferSpmd),
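PD_REGISTER_SPMD_RULE pairs a forward and a reverse inference function under a rule name at load time. Conceptually it amounts to a static registration into a name-keyed registry; a rough sketch of that pattern (stand-in types and names, not Paddle's actual macro expansion):

```cpp
// Rough sketch of a name-keyed SPMD rule registry; the actual
// PD_REGISTER_SPMD_RULE macro is more involved.
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

struct SpmdInfo {};                        // stand-in for phi's SpmdInfo
using SpmdFn = std::function<SpmdInfo()>;  // erased rule signature (simplified)
using RulePair = std::pair<SpmdFn, SpmdFn>;  // {forward, reverse}

std::unordered_map<std::string, RulePair>& SpmdRuleRegistry() {
  static std::unordered_map<std::string, RulePair> registry;
  return registry;
}

bool RegisterSpmdRule(const std::string& name, SpmdFn fwd, SpmdFn rev) {
  return SpmdRuleRegistry()
      .emplace(name, RulePair{std::move(fwd), std::move(rev)})
      .second;
}

// Static initializer mimicking registration at load time.
static const bool kReduceBaseRegistered = RegisterSpmdRule(
    "reduce_base",
    [] { return SpmdInfo{}; },   // would wrap ReductionInferSpmdBase
    [] { return SpmdInfo{}; });  // would wrap ReductionInferSpmdReverse
```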