@@ -40,8 +40,23 @@ COMMON_DECLARE_bool(print_ir);
namespace paddle {
namespace dialect {

-inline pir::Type CastToLocalType(pir::Type dist_type) {
-  return dist_type.dyn_cast<DistTypeInterface>().local_type();
+pir::Type CastToLocalType(pir::Type type) {
+  if (auto dist_type = type.dyn_cast<DistTypeInterface>()) {
+    return dist_type.local_type();
+  } else if (auto vec_type = type.dyn_cast<pir::VectorType>()) {
+    std::vector<pir::Type> local_types;
+    for (size_t i = 0; i < vec_type.size(); ++i) {
+      local_types.push_back(CastToLocalType(vec_type[i]));
+    }
+    return pir::VectorType::get(vec_type.ir_context(), local_types);
+  } else if (!type) {
+    // skip if <<NULL TYPE>>
+    return nullptr;
+  } else {
+    // TODO(2024-Q2) not all value are dist type
+    PADDLE_THROW(common::errors::PreconditionNotMet(
+        "The type[%s] is not Dist type.", type));
+  }
}

inline bool IsDistType(pir::Type type) { return type.isa<DistTypeInterface>(); }
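Editor's note: to spell out the recursion the new CastToLocalType introduces, here is a rough, hypothetical Python analogue (illustration only; the local_type attribute and the list stand-in for pir::VectorType are assumptions, not Paddle APIs):

```python
# Hypothetical sketch of the control flow in the new CastToLocalType:
# dist types are lowered to their local type, vector types are lowered
# element-wise, a null type passes through, anything else is rejected.
def cast_to_local_type(ty):
    if ty is None:  # corresponds to the <<NULL TYPE>> branch
        return None
    if hasattr(ty, "local_type"):  # stands in for DistTypeInterface
        return ty.local_type()
    if isinstance(ty, (list, tuple)):  # stands in for pir::VectorType
        return [cast_to_local_type(t) for t in ty]
    raise TypeError(f"{ty!r} is not a dist type")
```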
@@ -53,17 +68,15 @@ void ProcessDistBlock(pir::Block* block) {

    for (size_t i = 0; i < op_item->num_results(); ++i) {
      auto result = op_item->result(i);
-      auto origin_type = result.type();
-      if (IsDistType(origin_type)) {
-        auto local_type = CastToLocalType(origin_type);
-        result.set_type(local_type);
-      } else if (origin_type) {  // skip if <<NULL TYPE>>
-        // TODO(2024-Q2) not all value are dist type
-        PADDLE_THROW(platform::errors::PreconditionNotMet(
-            "The op [%s]'s [%d]th result is not Dist type.",
-            op_item->name(),
-            i));
-      }
+      result.set_type(CastToLocalType(result.type()));
    }
+    if (op_item->isa<DataOp>()) {
+      auto dense_tensor_type =
+          op_item->result(0).type().dyn_cast<pir::DenseTensorType>();
+      auto shape = common::vectorize(dense_tensor_type.dims());
+      pir::Attribute attr_shape = IntArrayAttribute::get(
+          pir::IrContext::Instance(), phi::IntArray(shape));
+      op_item->set_attribute("shape", attr_shape);
+    }
// TODO(2024-Q2) not all op are dist type
// PADDLE_ENFORCE_EQ(
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/static/helper.py
@@ -70,7 +70,7 @@ def __init__(self, layer, loss_func, metrics):
@paddle.jit.not_to_static
def append_loss_to_shadow_output(self, mode):
name = paddle.utils.unique_name.generate('loss')
-paddle._pir_ops.set_persistable_value(self._loss_vars[mode], name)
+paddle._C_ops.set_persistable_value(self._loss_vars[mode], name)
self._loss_names[mode] = name

def _train(self, inputs, labels):
@@ -48,7 +48,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
reduce_mean = True

group = new_process_group(src_mesh.process_ids)
-reduced_value = paddle._pir_ops.c_allreduce_sum_(
+reduced_value = paddle._C_ops.c_allreduce_sum_(
src_value, group.id, True, False
)

@@ -89,8 +89,8 @@ def reshard_s_to_r_with_padding(
num_of_process = len(src_mesh.process_ids)
dtype = src_value.dtype
group = new_process_group(src_mesh.process_ids)
-allgather_value = paddle._pir_ops.c_allgather(
-    src_value, group.id, num_of_process, False
+allgather_value = paddle._C_ops.c_allgather(
+    src_value, group.id, num_of_process, True
)
allgather_value.set_type(dst_type)
Contributor: We should think about a better way to remove this hack, e.g. have communication operations in PIR use a special infermeta_local.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as discussed, I will systematically solve this problem in the next pr.


@@ -109,10 +109,10 @@ def reshard_s_to_r_with_padding(
if split_axis != 0 or padding_num != 0:
allgather_op = allgather_value.get_defining_op()
paddle.pir.set_insertion_point_after(allgather_op)
-split_value = paddle._pir_ops.split_with_num(
+split_value = paddle._C_ops.split_with_num(
allgather_op.result(0), num_of_process, 0
)
-concat_value = paddle._pir_ops.concat(split_value, split_axis)
+concat_value = paddle._C_ops.concat(split_value, split_axis)
return concat_value
return allgather_value
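
Editor's note on the hack discussed above: c_allgather always stacks the gathered shards along axis 0, so when the tensor was sharded along a different axis the result has to be re-split along axis 0 and concatenated along the original split axis. A minimal NumPy sketch of that reordering (illustration only, ignoring the padding case; not Paddle code):

```python
import numpy as np

# Two "ranks" each hold a shard of a (2, 4) tensor that was split along axis 1.
full = np.arange(8).reshape(2, 4)
shards = np.split(full, 2, axis=1)

# allgather stacks the shards along axis 0, giving shape (4, 2) instead of (2, 4).
gathered = np.concatenate(shards, axis=0)

# Re-split along axis 0 and concatenate along the original split axis to recover
# the full tensor -- the same split_with_num + concat pattern used above.
restored = np.concatenate(np.split(gathered, 2, axis=0), axis=1)
assert (restored == full).all()
```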

@@ -50,7 +50,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
for src, dst in zip(src_mesh.process_ids, dst_mesh.process_ids):
if src == cur_global_rank:
dst_local_rank = all_process_ids.index(dst)
-paddle._pir_ops.send_v2(
+paddle._C_ops.send_v2(
src_value,
comm_group.id,
dst_local_rank,
@@ -73,7 +73,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
assert (
-1 not in dst_type.shape
), "dynamic shape is not supported by pir-auto parallel yet."
-recv_value = paddle._pir_ops.recv_v2(
+recv_value = paddle._C_ops.recv_v2(
dst_type.shape,
dst_type.dtype,
src_local_rank,
2 changes: 1 addition & 1 deletion test/auto_parallel/pir/CMakeLists.txt
@@ -24,7 +24,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_mlp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT
60)
set_tests_properties(test_semi_auto_parallel_dist_to_static_pir
-PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 30)
+PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 60)
py_test_modules(
test_eliminate_transpose_pass MODULES test_eliminate_transpose_pass ENVS
FLAGS_enable_pir_in_executor=1)
34 changes: 34 additions & 0 deletions test/auto_parallel/pir/mlp_demo.py
@@ -65,6 +65,40 @@ def forward(self, x):
return out


class DPDemoNet(nn.Layer):
def __init__(
self,
mesh,
):
super().__init__()
self._mesh = mesh
self.linear_0 = nn.Linear(IMAGE_SIZE, IMAGE_SIZE, bias_attr=False)
self.linear_1 = nn.Linear(IMAGE_SIZE, CLASS_NUM, bias_attr=False)
self.linear_0.weight = dist.shard_tensor(
self.linear_0.weight,
self._mesh,
[dist.Replicate()],
stop_gradient=False,
)
self.linear_1.weight = dist.shard_tensor(
self.linear_1.weight,
self._mesh,
[dist.Replicate()],
stop_gradient=False,
)
self.relu_0 = nn.ReLU()
self.relu_1 = nn.ReLU()
self.relu_2 = nn.ReLU()

def forward(self, x):
out = self.relu_0(x)
out = self.linear_0(out)
out = self.relu_1(out)
out = self.linear_1(out)
out = self.relu_2(out)
return out


class TestMLPTensorParallel(unittest.TestCase):
def test_to_static_program(self):
paddle.base.set_flags({'FLAGS_enable_pir_api': 1})
@@ -16,7 +16,7 @@
import random

import numpy as np
-from mlp_demo import PPDemoNet
+from mlp_demo import DPDemoNet, PPDemoNet
from test_to_static_pir_program import DemoNet

import paddle
@@ -160,6 +160,47 @@ def test_mp_demo_net(self):
dy_losses = self.run_dynamic(dy_layer, dy_opt, dist_dataloader)
np.testing.assert_allclose(dy_losses, dy2static_losses, atol=1e-7)

def test_dp_demo_net(self):
paddle.disable_static()
self.set_random_seed(self._seed)
data_loader = self.create_data_loader()

self.set_random_seed(self._seed)
dy_layer = DPDemoNet(self.mesh)
dy_opt = paddle.optimizer.SGD(
learning_rate=0.1, parameters=dy_layer.parameters()
)

paddle.base.set_flags({'FLAGS_enable_pir_api': 1})
self.set_random_seed(self._seed)
dy2static_layer = DPDemoNet(self.mesh)
dy2static_opt = paddle.optimizer.SGD(
learning_rate=0.1, parameters=dy2static_layer.parameters()
)
dist_dataloader = dist.shard_dataloader(
dataloader=data_loader,
meshes=[self.mesh],
input_keys=["image", "label"],
shard_dims=['x'],
)
dy2static_losses, _ = self.run_dy2static(
dy2static_layer, dy2static_opt, dist_dataloader
)

dy_losses = self.run_dynamic(dy_layer, dy_opt, dist_dataloader)
# Check the loss values. Different from dygraph mode, when
# the model is trained in dy2static mode, the loss values
# are not the average of the losses of all processes, so
# we should get the average loss first.
paddle.disable_static()
pd_partial_loss = paddle.to_tensor(dy2static_losses)
pd_loss_list = []
dist.all_gather(pd_loss_list, pd_partial_loss)
np_dy2static_loss_list = [loss.numpy() for loss in pd_loss_list]
np_dy2static_loss = np.array(np_dy2static_loss_list)
np_dy2static_loss = np.mean(np_dy2static_loss, axis=0)
np.testing.assert_array_equal(dy_losses, np_dy2static_loss)

def test_pp_demo_net(self):
paddle.disable_static()
self.set_random_seed(self._seed)
@@ -194,6 +235,7 @@ def test_pp_demo_net(self):
def run_test_case(self):
self.test_mp_demo_net()
self.test_pp_demo_net()
self.test_dp_demo_net()


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion test/auto_parallel/reshard_p_to_r_cross_mesh.py
@@ -82,7 +82,7 @@ def run_pir_static_test_case(self):
input_tensor = dist.shard_tensor(
w0, self._in_mesh, [dist.Partial(dist.ReduceType.kRedSum)]
)
-reshard_tensor = paddle._pir_ops.reshard(
+reshard_tensor = paddle._C_ops.reshard(
input_tensor, self._out_mesh, [dist.Replicate()]
)
