|
16 | 16 |
|
17 | 17 | #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" |
18 | 18 | #include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h" |
| 19 | +#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h" |
19 | 20 | #include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h" |
20 | 21 | #include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" |
21 | 22 | #include "paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h" |
@@ -167,6 +168,53 @@ TEST(dist_dense_tensor_type_test, warp_type_interface) { |
167 | 168 | dense_tensor_type); |
168 | 169 | } |
169 | 170 |
|
| 171 | +TEST(dist_dense_tensor_type_test, dist_interface) { |
| 172 | + pir::IrContext* ctx = pir::IrContext::Instance(); |
| 173 | + ctx->GetOrRegisterDialect<DistDialect>(); |
| 174 | + ctx->GetOrRegisterDialect<OperatorDialect>(); |
| 175 | + std::vector<int64_t> mesh_shape = {2, 3}; |
| 176 | + std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5}; |
| 177 | + std::vector<std::string> dim_names = {"x", "y"}; |
| 178 | + phi::distributed::ProcessMesh process_mesh( |
| 179 | + mesh_shape, process_ids, dim_names); |
| 180 | + auto mesh_attr = ProcessMeshAttribute::get(ctx, process_mesh); |
| 181 | + |
| 182 | + std::vector<int64_t> dims_mapping = {0, -1}; |
| 183 | + paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status{ |
| 184 | + {1, phi::ReduceType::kRedSum}}; |
| 185 | + // construct a TensorDistAttribute. |
| 186 | + auto tensor_dist_attr = |
| 187 | + TensorDistAttribute::get(ctx, mesh_attr, dims_mapping, partial_status); |
| 188 | + |
| 189 | + pir::Type fp32_dtype = pir::Float32Type::get(ctx); |
| 190 | + common::DDim dims = {4, 8}; |
| 191 | + common::DDim local_dims = {2, 8}; |
| 192 | + common::DataLayout data_layout = common::DataLayout::NCHW; |
| 193 | + pir::LoD lod = {{0, 1, 2}}; |
| 194 | + size_t offset = 0; |
| 195 | + pir::DenseTensorType dense_tensor_type = pir::DenseTensorType::get( |
| 196 | + ctx, fp32_dtype, dims, data_layout, lod, offset); |
| 197 | + |
| 198 | + pir::Type dist_densor_type = |
| 199 | + DistDenseTensorType::get(ctx, dense_tensor_type, tensor_dist_attr); |
| 200 | + |
| 201 | + EXPECT_TRUE(dist_densor_type.isa<pir::DenseTensorType>()); |
| 202 | + EXPECT_EQ(dist_densor_type.dyn_cast<pir::DenseTensorType>(), |
| 203 | + dense_tensor_type); |
| 204 | + |
| 205 | + // test local cast |
| 206 | + auto local_dense_tensor_type = dist_densor_type.dyn_cast<DistTypeInterface>() |
| 207 | + .local_type() |
| 208 | + .dyn_cast<pir::DenseTensorType>(); |
| 209 | + EXPECT_TRUE(local_dense_tensor_type.isa<pir::DenseTensorType>()); |
| 210 | + EXPECT_FALSE(local_dense_tensor_type.isa<DistDenseTensorType>()); |
| 211 | + EXPECT_EQ(local_dense_tensor_type.dtype().isa<pir::Float32Type>(), true); |
| 212 | + EXPECT_EQ(local_dense_tensor_type.dims(), local_dims); |
| 213 | + EXPECT_EQ(local_dense_tensor_type.data_layout(), data_layout); |
| 214 | + EXPECT_EQ(local_dense_tensor_type.lod(), lod); |
| 215 | + EXPECT_EQ(local_dense_tensor_type.offset(), offset); |
| 216 | +} |
| 217 | + |
170 | 218 | TEST(operation_dist_attr_test, base) { |
171 | 219 | pir::IrContext* ctx = pir::IrContext::Instance(); |
172 | 220 | ctx->GetOrRegisterDialect<DistDialect>(); |
|
0 commit comments