diff --git a/paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.cc b/paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.cc
index b94cd3e2e084c0..f0944c710aabb9 100644
--- a/paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.cc
+++ b/paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.cc
@@ -40,8 +40,23 @@ COMMON_DECLARE_bool(print_ir);
 
 namespace paddle {
 namespace dialect {
-inline pir::Type CastToLocalType(pir::Type dist_type) {
-  return dist_type.dyn_cast().local_type();
+pir::Type CastToLocalType(pir::Type type) {
+  if (auto dist_type = type.dyn_cast()) {
+    return dist_type.local_type();
+  } else if (auto vec_type = type.dyn_cast()) {
+    std::vector local_types;
+    for (size_t i = 0; i < vec_type.size(); ++i) {
+      local_types.push_back(CastToLocalType(vec_type[i]));
+    }
+    return pir::VectorType::get(vec_type.ir_context(), local_types);
+  } else if (!type) {
+    // skip if <>
+    return nullptr;
+  } else {
+    // TODO(2024-Q2) not all value are dist type
+    PADDLE_THROW(common::errors::PreconditionNotMet(
+        "The type[%s] is not Dist type.", type));
+  }
 }
 
 inline bool IsDistType(pir::Type type) { return type.isa(); }
@@ -53,17 +68,15 @@ void ProcessDistBlock(pir::Block* block) {
     for (size_t i = 0; i < op_item->num_results(); ++i) {
       auto result = op_item->result(i);
-      auto origin_type = result.type();
-      if (IsDistType(origin_type)) {
-        auto local_type = CastToLocalType(origin_type);
-        result.set_type(local_type);
-      } else if (origin_type) {  // skip if <>
-        // TODO(2024-Q2) not all value are dist type
-        PADDLE_THROW(platform::errors::PreconditionNotMet(
-            "The op [%s]'s [%d]th result is not Dist type.",
-            op_item->name(),
-            i));
-      }
+      result.set_type(CastToLocalType(result.type()));
+    }
+    if (op_item->isa()) {
+      auto dense_tensor_type =
+          op_item->result(0).type().dyn_cast();
+      auto shape = common::vectorize(dense_tensor_type.dims());
+      pir::Attribute attr_shape = IntArrayAttribute::get(
+          pir::IrContext::Instance(), phi::IntArray(shape));
+      op_item->set_attribute("shape", attr_shape);
     }
     // TODO(2024-Q2) not all op are dist type
     // PADDLE_ENFORCE_EQ(
diff --git a/python/paddle/distributed/auto_parallel/static/helper.py b/python/paddle/distributed/auto_parallel/static/helper.py
index a2c6873f2372e8..f091b70755ab87 100644
--- a/python/paddle/distributed/auto_parallel/static/helper.py
+++ b/python/paddle/distributed/auto_parallel/static/helper.py
@@ -70,7 +70,7 @@ def __init__(self, layer, loss_func, metrics):
     @paddle.jit.not_to_static
     def append_loss_to_shadow_output(self, mode):
         name = paddle.utils.unique_name.generate('loss')
-        paddle._pir_ops.set_persistable_value(self._loss_vars[mode], name)
+        paddle._C_ops.set_persistable_value(self._loss_vars[mode], name)
         self._loss_names[mode] = name
 
     def _train(self, inputs, labels):
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py
index 6c0c9445449938..8956cc2535d9b0 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py
@@ -48,7 +48,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
             reduce_mean = True
 
         group = new_process_group(src_mesh.process_ids)
-        reduced_value = paddle._pir_ops.c_allreduce_sum_(
+        reduced_value = paddle._C_ops.c_allreduce_sum_(
            src_value, group.id, True, False
        )
 
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py
index 6cb3fe4916e08d..1f4d01970ebdac 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py
@@ -89,8 +89,8 @@ def reshard_s_to_r_with_padding(
         num_of_process = len(src_mesh.process_ids)
         dtype = src_value.dtype
         group = new_process_group(src_mesh.process_ids)
-        allgather_value = paddle._pir_ops.c_allgather(
-            src_value, group.id, num_of_process, False
+        allgather_value = paddle._C_ops.c_allgather(
+            src_value, group.id, num_of_process, True
         )
         allgather_value.set_type(dst_type)
 
@@ -109,10 +109,10 @@ def reshard_s_to_r_with_padding(
         if split_axis != 0 or padding_num != 0:
             allgather_op = allgather_value.get_defining_op()
             paddle.pir.set_insertion_point_after(allgather_op)
-            split_value = paddle._pir_ops.split_with_num(
+            split_value = paddle._C_ops.split_with_num(
                 allgather_op.result(0), num_of_process, 0
             )
-            concat_value = paddle._pir_ops.concat(split_value, split_axis)
+            concat_value = paddle._C_ops.concat(split_value, split_axis)
             return concat_value
         return allgather_value
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py
index 1e79afbaf0ab06..c69aeed5e2e6d4 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py
@@ -50,7 +50,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
         for src, dst in zip(src_mesh.process_ids, dst_mesh.process_ids):
             if src == cur_global_rank:
                 dst_local_rank = all_process_ids.index(dst)
-                paddle._pir_ops.send_v2(
+                paddle._C_ops.send_v2(
                     src_value,
                     comm_group.id,
                     dst_local_rank,
@@ -73,7 +73,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
                 assert (
                     -1 not in dst_type.shape
                 ), "dynamic shape is not supported by pir-auto parallel yet."
-                recv_value = paddle._pir_ops.recv_v2(
+                recv_value = paddle._C_ops.recv_v2(
                     dst_type.shape,
                     dst_type.dtype,
                     src_local_rank,
diff --git a/test/auto_parallel/pir/CMakeLists.txt b/test/auto_parallel/pir/CMakeLists.txt
index f88e1388e679a4..b012a7065baef8 100644
--- a/test/auto_parallel/pir/CMakeLists.txt
+++ b/test/auto_parallel/pir/CMakeLists.txt
@@ -24,7 +24,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   set_tests_properties(test_mlp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT
                                            60)
   set_tests_properties(test_semi_auto_parallel_dist_to_static_pir
-                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 30)
+                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 60)
   py_test_modules(
     test_eliminate_transpose_pass MODULES test_eliminate_transpose_pass ENVS
     FLAGS_enable_pir_in_executor=1)
diff --git a/test/auto_parallel/pir/mlp_demo.py b/test/auto_parallel/pir/mlp_demo.py
index e8c3532830c87c..aa92c3ffe489d0 100644
--- a/test/auto_parallel/pir/mlp_demo.py
+++ b/test/auto_parallel/pir/mlp_demo.py
@@ -65,6 +65,40 @@ def forward(self, x):
         return out
 
 
+class DPDemoNet(nn.Layer):
+    def __init__(
+        self,
+        mesh,
+    ):
+        super().__init__()
+        self._mesh = mesh
+        self.linear_0 = nn.Linear(IMAGE_SIZE, IMAGE_SIZE, bias_attr=False)
+        self.linear_1 = nn.Linear(IMAGE_SIZE, CLASS_NUM, bias_attr=False)
+        self.linear_0.weight = dist.shard_tensor(
+            self.linear_0.weight,
+            self._mesh,
+            [dist.Replicate()],
+            stop_gradient=False,
+        )
+        self.linear_1.weight = dist.shard_tensor(
+            self.linear_1.weight,
+            self._mesh,
+            [dist.Replicate()],
+            stop_gradient=False,
+        )
+        self.relu_0 = nn.ReLU()
+        self.relu_1 = nn.ReLU()
+        self.relu_2 = nn.ReLU()
+
+    def forward(self, x):
+        out = self.relu_0(x)
+        out = self.linear_0(out)
+        out = self.relu_1(out)
+        out = self.linear_1(out)
+        out = self.relu_2(out)
+        return out
+
+
 class TestMLPTensorParallel(unittest.TestCase):
     def test_to_static_program(self):
         paddle.base.set_flags({'FLAGS_enable_pir_api': 1})
diff --git a/test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py b/test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py
index 6d71d9f3d85b5b..b7ebff1f54df39 100644
--- a/test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py
+++ b/test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py
@@ -16,7 +16,7 @@
 import random
 
 import numpy as np
-from mlp_demo import PPDemoNet
+from mlp_demo import DPDemoNet, PPDemoNet
 from test_to_static_pir_program import DemoNet
 
 import paddle
@@ -160,6 +160,47 @@ def test_mp_demo_net(self):
         dy_losses = self.run_dynamic(dy_layer, dy_opt, dist_dataloader)
         np.testing.assert_allclose(dy_losses, dy2static_losses, atol=1e-7)
 
+    def test_dp_demo_net(self):
+        paddle.disable_static()
+        self.set_random_seed(self._seed)
+        data_loader = self.create_data_loader()
+
+        self.set_random_seed(self._seed)
+        dy_layer = DPDemoNet(self.mesh)
+        dy_opt = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=dy_layer.parameters()
+        )
+
+        paddle.base.set_flags({'FLAGS_enable_pir_api': 1})
+        self.set_random_seed(self._seed)
+        dy2static_layer = DPDemoNet(self.mesh)
+        dy2static_opt = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=dy2static_layer.parameters()
+        )
+        dist_dataloader = dist.shard_dataloader(
+            dataloader=data_loader,
+            meshes=[self.mesh],
+            input_keys=["image", "label"],
+            shard_dims=['x'],
+        )
+        dy2static_losses, _ = self.run_dy2static(
+            dy2static_layer, dy2static_opt, dist_dataloader
+        )
+
+        dy_losses = self.run_dynamic(dy_layer, dy_opt, dist_dataloader)
+        # Check the loss values. Different from dygraph mode, when
+        # the model is trained in dy2static mode, the loss values
+        # are not the average of the losses of all processes, so
+        # we should get the average loss first.
+        paddle.disable_static()
+        pd_partial_loss = paddle.to_tensor(dy2static_losses)
+        pd_loss_list = []
+        dist.all_gather(pd_loss_list, pd_partial_loss)
+        np_dy2static_loss_list = [loss.numpy() for loss in pd_loss_list]
+        np_dy2static_loss = np.array(np_dy2static_loss_list)
+        np_dy2static_loss = np.mean(np_dy2static_loss, axis=0)
+        np.testing.assert_array_equal(dy_losses, np_dy2static_loss)
+
     def test_pp_demo_net(self):
         paddle.disable_static()
         self.set_random_seed(self._seed)
@@ -194,6 +194,7 @@ def test_pp_demo_net(self):
     def run_test_case(self):
         self.test_mp_demo_net()
         self.test_pp_demo_net()
+        self.test_dp_demo_net()
 
 
 if __name__ == '__main__':
diff --git a/test/auto_parallel/reshard_p_to_r_cross_mesh.py b/test/auto_parallel/reshard_p_to_r_cross_mesh.py
index 42a34a478a7ffb..6960530bf3bb31 100644
--- a/test/auto_parallel/reshard_p_to_r_cross_mesh.py
+++ b/test/auto_parallel/reshard_p_to_r_cross_mesh.py
@@ -82,7 +82,7 @@ def run_pir_static_test_case(self):
             input_tensor = dist.shard_tensor(
                 w0, self._in_mesh, [dist.Partial(dist.ReduceType.kRedSum)]
            )
-            reshard_tensor = paddle._pir_ops.reshard(
+            reshard_tensor = paddle._C_ops.reshard(
                 input_tensor, self._out_mesh, [dist.Replicate()]
             )
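
Note: for readability, here is a minimal sketch of the new CastToLocalType from the first hunk with the dyn_cast/isa template arguments and the std::vector element type spelled out explicitly. DistTypeInterface and pir::VectorType are assumed names for the dialect types involved and may not match the actual sources exactly; treat this as an illustration of the recursive local-type conversion, not as the verbatim patch content.

    // Sketch only -- assumes the distributed type interface is
    // paddle::dialect::DistTypeInterface and that vector elements are
    // pir::Type; both names are assumptions.
    pir::Type CastToLocalType(pir::Type type) {
      if (auto dist_type = type.dyn_cast<DistTypeInterface>()) {
        // A distributed tensor type knows its per-rank dense ("local") type.
        return dist_type.local_type();
      } else if (auto vec_type = type.dyn_cast<pir::VectorType>()) {
        // A vector type is converted element-wise, recursively.
        std::vector<pir::Type> local_types;
        for (size_t i = 0; i < vec_type.size(); ++i) {
          local_types.push_back(CastToLocalType(vec_type[i]));
        }
        return pir::VectorType::get(vec_type.ir_context(), local_types);
      } else if (!type) {
        // A null type is passed through untouched.
        return nullptr;
      } else {
        // Anything else is unexpected at this point of the pass.
        PADDLE_THROW(common::errors::PreconditionNotMet(
            "The type[%s] is not Dist type.", type));
      }
    }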