From 69f92336e1059119d998a43755fbf5d5f9cd8d91 Mon Sep 17 00:00:00 2001 From: liangjianzhong Date: Wed, 20 Mar 2024 20:53:43 +0800 Subject: [PATCH 1/2] dist interface --- .../pir/dialect/distributed/ir/dist_type.cc | 10 ++++ .../pir/dialect/distributed/ir/dist_type.h | 8 +++- test/cpp/pir/distributed/dist_dialect_test.cc | 48 +++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc index 3f0e8968012876..7ee5ed5d3c3fda 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" #include "paddle/fluid/pir/dialect/distributed/ir/type_storage.h" +#include "paddle/pir/include/core/ir_context.h" namespace paddle { namespace dialect { @@ -57,6 +58,15 @@ common::DDim InferLocalDDim(const common::DDim& global_ddim, return local_ddim; } +auto DistDenseTensorType::local_type() const -> Type { + return pir::DenseTensorType::get(pir::IrContext::Instance(), + dtype(), + local_ddim(), + data_layout(), + lod(), + offset()); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_type.h b/paddle/fluid/pir/dialect/distributed/ir/dist_type.h index c8964a516af764..5d58cf99043338 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_type.h +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_type.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h" #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/type.h" @@ -29,9 +30,11 @@ class DistDenseTensorType : public pir::Type::TypeBase { + pir::WrapTypeInterface, + DistTypeInterface> { public: using Base::Base; + using LoD = pir::DenseTensorTypeStorage::LoD; pir::DenseTensorType dense_tensor_type() const; TensorDistAttribute tensor_dist_attr() const; @@ -39,8 +42,11 @@ class DistDenseTensorType const common::DDim& local_ddim() const; Type dtype() const { return dense_tensor_type().dtype(); } DataLayout data_layout() const { return dense_tensor_type().data_layout(); } + const LoD& lod() const { return dense_tensor_type().lod(); } + size_t offset() const { return dense_tensor_type().offset(); } Type prim_type() { return dense_tensor_type(); } + Type local_type() const; ProcessMeshAttribute process_mesh_attr() const { return tensor_dist_attr().process_mesh_attr(); diff --git a/test/cpp/pir/distributed/dist_dialect_test.cc b/test/cpp/pir/distributed/dist_dialect_test.cc index a273a0e83ff1cc..4a0e477b09ae39 100644 --- a/test/cpp/pir/distributed/dist_dialect_test.cc +++ b/test/cpp/pir/distributed/dist_dialect_test.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" #include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h" #include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h" #include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" #include "paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h" @@ -167,6 +168,53 @@ TEST(dist_dense_tensor_type_test, warp_type_interface) { dense_tensor_type); } +TEST(dist_dense_tensor_type_test, dist_interface) { + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + std::vector mesh_shape = {2, 3}; + std::vector process_ids = {0, 1, 2, 3, 4, 5}; + std::vector dim_names = {"x", "y"}; + phi::distributed::ProcessMesh process_mesh( + mesh_shape, process_ids, dim_names); + auto mesh_attr = ProcessMeshAttribute::get(ctx, process_mesh); + + std::vector dims_mapping = {0, -1}; + paddle::flat_hash_map partial_status{ + {1, phi::ReduceType::kRedSum}}; + // construct a TensorDistAttribute. + auto tensor_dist_attr = + TensorDistAttribute::get(ctx, mesh_attr, dims_mapping, partial_status); + + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + common::DDim dims = {4, 8}; + common::DDim local_dims = {2, 8}; + common::DataLayout data_layout = common::DataLayout::NCHW; + pir::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + pir::DenseTensorType dense_tensor_type = pir::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset); + + pir::Type dist_densor_type = + DistDenseTensorType::get(ctx, dense_tensor_type, tensor_dist_attr); + + EXPECT_TRUE(dist_densor_type.isa()); + EXPECT_EQ(dist_densor_type.dyn_cast(), + dense_tensor_type); + + // test local cast + auto local_dense_tensor_type = dist_densor_type.dyn_cast() + .local_type() + .dyn_cast(); + EXPECT_TRUE(local_dense_tensor_type.isa()); + EXPECT_FALSE(local_dense_tensor_type.isa()); + EXPECT_EQ(local_dense_tensor_type.dtype().isa(), true); + EXPECT_EQ(local_dense_tensor_type.dims(), local_dims); + EXPECT_EQ(local_dense_tensor_type.data_layout(), data_layout); + EXPECT_EQ(local_dense_tensor_type.lod(), lod); + EXPECT_EQ(local_dense_tensor_type.offset(), offset); +} + TEST(operation_dist_attr_test, base) { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); From 5de20b4c93b2498d352c515ca12ff58fe6f0a959 Mon Sep 17 00:00:00 2001 From: liangjianzhong Date: Wed, 20 Mar 2024 20:57:08 +0800 Subject: [PATCH 2/2] interface --- .../dialect/distributed/ir/dist_interface.cc | 19 +++++++ .../dialect/distributed/ir/dist_interface.h | 53 +++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 paddle/fluid/pir/dialect/distributed/ir/dist_interface.cc create mode 100644 paddle/fluid/pir/dialect/distributed/ir/dist_interface.h diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_interface.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_interface.cc new file mode 100644 index 00000000000000..17e5caa6a22dbb --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_interface.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h" + +namespace paddle::dialect {} // namespace paddle::dialect + +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DistTypeInterface) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_interface.h b/paddle/fluid/pir/dialect/distributed/ir/dist_interface.h new file mode 100644 index 00000000000000..dfbb4c1ce47680 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_interface.h @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/pir/include/core/cast_utils.h" +#include "paddle/pir/include/core/dll_decl.h" +#include "paddle/pir/include/core/type.h" + +namespace paddle { +namespace dialect { + +class IR_API DistTypeInterface + : public pir::TypeInterfaceBase { + public: + struct Concept { + /// Defined these methods with the interface. + explicit Concept(pir::Type (*local_type)(pir::Type)) + : local_type(local_type) {} + pir::Type (*local_type)(pir::Type); + }; + + template + struct Model : public Concept { + static Type local_type(Type type) { + return pir::cast(type).local_type(); + } + Model() : Concept(local_type) {} + }; + + DistTypeInterface(pir::Type type, Concept *impl) + : pir::TypeInterfaceBase(type), impl_(impl) {} + + pir::Type local_type() { return impl_->local_type(*this); } + + private: + Concept *impl_; +}; + +} // namespace dialect +} // namespace paddle + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DistTypeInterface)