Skip to content

Commit 90e62ce

Browse files
authored
[DistDialect] Dist Interface (#62895)
* dist interface * interface
1 parent c937d8d commit 90e62ce

5 files changed

Lines changed: 137 additions & 1 deletion

File tree

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h"
16+
17+
namespace paddle::dialect {} // namespace paddle::dialect
18+
19+
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DistTypeInterface)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#pragma once
15+
16+
#include "paddle/pir/include/core/cast_utils.h"
17+
#include "paddle/pir/include/core/dll_decl.h"
18+
#include "paddle/pir/include/core/type.h"
19+
20+
namespace paddle {
21+
namespace dialect {
22+
23+
class IR_API DistTypeInterface
24+
: public pir::TypeInterfaceBase<DistTypeInterface> {
25+
public:
26+
struct Concept {
27+
/// Defined these methods with the interface.
28+
explicit Concept(pir::Type (*local_type)(pir::Type))
29+
: local_type(local_type) {}
30+
pir::Type (*local_type)(pir::Type);
31+
};
32+
33+
template <class ConcreteType>
34+
struct Model : public Concept {
35+
static Type local_type(Type type) {
36+
return pir::cast<ConcreteType>(type).local_type();
37+
}
38+
Model() : Concept(local_type) {}
39+
};
40+
41+
DistTypeInterface(pir::Type type, Concept *impl)
42+
: pir::TypeInterfaceBase<DistTypeInterface>(type), impl_(impl) {}
43+
44+
pir::Type local_type() { return impl_->local_type(*this); }
45+
46+
private:
47+
Concept *impl_;
48+
};
49+
50+
} // namespace dialect
51+
} // namespace paddle
52+
53+
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DistTypeInterface)

paddle/fluid/pir/dialect/distributed/ir/dist_type.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
1616
#include "paddle/fluid/pir/dialect/distributed/ir/type_storage.h"
17+
#include "paddle/pir/include/core/ir_context.h"
1718

1819
namespace paddle {
1920
namespace dialect {
@@ -57,6 +58,15 @@ common::DDim InferLocalDDim(const common::DDim& global_ddim,
5758
return local_ddim;
5859
}
5960

61+
auto DistDenseTensorType::local_type() const -> Type {
62+
return pir::DenseTensorType::get(pir::IrContext::Instance(),
63+
dtype(),
64+
local_ddim(),
65+
data_layout(),
66+
lod(),
67+
offset());
68+
}
69+
6070
} // namespace dialect
6171
} // namespace paddle
6272

paddle/fluid/pir/dialect/distributed/ir/dist_type.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
18+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h"
1819
#include "paddle/pir/include/core/builtin_type.h"
1920
#include "paddle/pir/include/core/type.h"
2021

@@ -29,18 +30,23 @@ class DistDenseTensorType
2930
: public pir::Type::TypeBase<DistDenseTensorType,
3031
pir::Type,
3132
DistDenseTensorTypeStorage,
32-
pir::WrapTypeInterface> {
33+
pir::WrapTypeInterface,
34+
DistTypeInterface> {
3335
public:
3436
using Base::Base;
37+
using LoD = pir::DenseTensorTypeStorage::LoD;
3538

3639
pir::DenseTensorType dense_tensor_type() const;
3740
TensorDistAttribute tensor_dist_attr() const;
3841
const common::DDim& global_ddim() const { return dense_tensor_type().dims(); }
3942
const common::DDim& local_ddim() const;
4043
Type dtype() const { return dense_tensor_type().dtype(); }
4144
DataLayout data_layout() const { return dense_tensor_type().data_layout(); }
45+
const LoD& lod() const { return dense_tensor_type().lod(); }
46+
size_t offset() const { return dense_tensor_type().offset(); }
4247

4348
Type prim_type() { return dense_tensor_type(); }
49+
Type local_type() const;
4450

4551
ProcessMeshAttribute process_mesh_attr() const {
4652
return tensor_dist_attr().process_mesh_attr();

test/cpp/pir/distributed/dist_dialect_test.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
1818
#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
19+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h"
1920
#include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h"
2021
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
2122
#include "paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h"
@@ -167,6 +168,53 @@ TEST(dist_dense_tensor_type_test, warp_type_interface) {
167168
dense_tensor_type);
168169
}
169170

171+
TEST(dist_dense_tensor_type_test, dist_interface) {
172+
pir::IrContext* ctx = pir::IrContext::Instance();
173+
ctx->GetOrRegisterDialect<DistDialect>();
174+
ctx->GetOrRegisterDialect<OperatorDialect>();
175+
std::vector<int64_t> mesh_shape = {2, 3};
176+
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
177+
std::vector<std::string> dim_names = {"x", "y"};
178+
phi::distributed::ProcessMesh process_mesh(
179+
mesh_shape, process_ids, dim_names);
180+
auto mesh_attr = ProcessMeshAttribute::get(ctx, process_mesh);
181+
182+
std::vector<int64_t> dims_mapping = {0, -1};
183+
paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status{
184+
{1, phi::ReduceType::kRedSum}};
185+
// construct a TensorDistAttribute.
186+
auto tensor_dist_attr =
187+
TensorDistAttribute::get(ctx, mesh_attr, dims_mapping, partial_status);
188+
189+
pir::Type fp32_dtype = pir::Float32Type::get(ctx);
190+
common::DDim dims = {4, 8};
191+
common::DDim local_dims = {2, 8};
192+
common::DataLayout data_layout = common::DataLayout::NCHW;
193+
pir::LoD lod = {{0, 1, 2}};
194+
size_t offset = 0;
195+
pir::DenseTensorType dense_tensor_type = pir::DenseTensorType::get(
196+
ctx, fp32_dtype, dims, data_layout, lod, offset);
197+
198+
pir::Type dist_densor_type =
199+
DistDenseTensorType::get(ctx, dense_tensor_type, tensor_dist_attr);
200+
201+
EXPECT_TRUE(dist_densor_type.isa<pir::DenseTensorType>());
202+
EXPECT_EQ(dist_densor_type.dyn_cast<pir::DenseTensorType>(),
203+
dense_tensor_type);
204+
205+
// test local cast
206+
auto local_dense_tensor_type = dist_densor_type.dyn_cast<DistTypeInterface>()
207+
.local_type()
208+
.dyn_cast<pir::DenseTensorType>();
209+
EXPECT_TRUE(local_dense_tensor_type.isa<pir::DenseTensorType>());
210+
EXPECT_FALSE(local_dense_tensor_type.isa<DistDenseTensorType>());
211+
EXPECT_EQ(local_dense_tensor_type.dtype().isa<pir::Float32Type>(), true);
212+
EXPECT_EQ(local_dense_tensor_type.dims(), local_dims);
213+
EXPECT_EQ(local_dense_tensor_type.data_layout(), data_layout);
214+
EXPECT_EQ(local_dense_tensor_type.lod(), lod);
215+
EXPECT_EQ(local_dense_tensor_type.offset(), offset);
216+
}
217+
170218
TEST(operation_dist_attr_test, base) {
171219
pir::IrContext* ctx = pir::IrContext::Instance();
172220
ctx->GetOrRegisterDialect<DistDialect>();

0 commit comments

Comments
 (0)