Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions python/paddle/distributed/auto_parallel/pipelining/microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,29 @@ def _split_tensor(x, num_chunks, split_axis=0):
chunk_tensors = paddle.tensor_split(x, num_chunks, split_axis)
# dp_degree > 1 , placements of model input is [S(0), R, ...]
else:
if dist.in_auto_parallel_align_mode():

def _reorder_data_for_align():
nonlocal x
assert x.placements[0] == dist.Shard(
0
), "inputs should be placed on S(0)."

shardings = x.process_mesh.shape[0]

rows_per_shard = x.shape[0] // shardings
new_indices = []
for s_id in range(shardings):
for row_in_shard in range(rows_per_shard):
new_indices.append(s_id + row_in_shard * shardings)
tmp = x[new_indices]
x = dist.reshard(tmp, x.process_mesh, x.placements)

_reorder_data_for_align()
mesh = x.process_mesh
placements = x.placements
x = dtensor_to_local(x, mesh, placements)
chunk_tensors = paddle.tensor_split(x, num_chunks, split_axis)
dense_x = dtensor_to_local(x, mesh, placements)
chunk_tensors = paddle.tensor_split(dense_x, num_chunks, split_axis)
for i in range(num_chunks):
chunk_tensors[i] = dtensor_from_local(
chunk_tensors[i], mesh, placements
Expand Down
89 changes: 87 additions & 2 deletions test/auto_parallel/PP_Schedules_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def test_dp_pp(self):
loader = DataLoader(dataset, batch_size=8)
losses_by_step = []
num_iterations = 20
all_losses_in_one_step_md5sum = []
for iter_idx in range(num_iterations):
losses_by_micro_batch = []
for i, (data, label) in enumerate(loader):
Expand All @@ -399,6 +400,10 @@ def test_dp_pp(self):
local_loss, op=dist.ReduceOp.AVG, group=dp_group
)
reduced_losses.append(local_loss)
if iter_idx == 0:
all_losses_in_one_step_md5sum.append(
local_loss._md5sum()
)

if self.rank == 3:
# Calculate mean using reduced losses
Expand All @@ -407,7 +412,7 @@ def test_dp_pp(self):
)
opt.step()
opt.clear_grad()
return losses_by_step
return losses_by_step, all_losses_in_one_step_md5sum

def test_pp_model_with_ClipGradByGlobalNorm(self):
"""Test pipeline parallel model with ClipGradByGlobalNorm using PPMyModel as the baseline"""
Expand Down Expand Up @@ -470,6 +475,75 @@ def test_ScheduleFThenB_with_ClipGradByGlobalNorm(self):
opt.clear_grad()
return losses_by_step

def test_dp_pp_align_mode(self):
    """Run the DP+PP schedule with auto-parallel align mode enabled.

    Mirrors ``test_dp_pp`` on a 2x2 (pp, dp) mesh, but with
    ``FLAGS_enable_auto_parallel_align_mode`` switched on. Returns the
    per-step mean losses together with the md5 checksums of the first
    step's reduced micro-batch losses so the caller can compare them
    against the non-align-mode run.
    """
    fix_seeds()
    paddle.set_flags({'FLAGS_enable_auto_parallel_align_mode': True})
    # 2x2 process mesh: rows are pipeline stages, columns are data-parallel replicas.
    mesh2d = paddle.distributed.ProcessMesh(
        [[0, 2], [1, 3]], dim_names=["pp", "dp"]
    )
    fleet.auto.set_mesh(mesh2d)
    self.model = PP_DP_MyModel()
    stage0_mesh = paddle.distributed.ProcessMesh([0, 2], dim_names=["dp"])
    stage1_mesh = paddle.distributed.ProcessMesh([1, 3], dim_names=["dp"])
    shard0_placement = [dist.Shard(0)]
    # Communication groups: pipeline groups pair the two stages of each
    # dp replica; the dp group pairs the last-stage ranks that hold the loss.
    first_pp_group = paddle.distributed.new_group([0, 1])
    second_pp_group = paddle.distributed.new_group([2, 3])
    dp_group = paddle.distributed.new_group([1, 3])
    self.micro_batches = 4
    comm_group = first_pp_group if self.rank < 2 else second_pp_group
    self.stage = PipelineStage(
        self.model, self.rank % 2, 2, group=comm_group
    )
    self.stage.has_backward = True
    mse_loss = nn.MSELoss()
    schedule = ScheduleFThenB(self.stage, self.micro_batches, loss_fn=mse_loss)
    opt = paddle.optimizer.AdamW(
        learning_rate=0.001, parameters=self.model.parameters()
    )
    dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
    loader = DataLoader(dataset, batch_size=8)
    losses_by_step = []
    all_losses_in_one_step_md5sum = []
    num_iterations = 20
    for iter_idx in range(num_iterations):
        losses_by_micro_batch = []
        for data, label in loader:
            dist_data = dist.shard_tensor(data, stage0_mesh, shard0_placement)
            dist_label = dist.shard_tensor(label, stage1_mesh, shard0_placement)
            schedule.step(
                dist_data, target=dist_label, losses=losses_by_micro_batch
            )
        # Losses from two dp paths are in Partial(AVG) state, need to do all_reduce
        if self.rank in (1, 3):
            reduced_losses = []
            for micro_loss in losses_by_micro_batch:
                local_loss = micro_loss._local_value()
                dist.all_reduce(
                    local_loss, op=dist.ReduceOp.AVG, group=dp_group
                )
                reduced_losses.append(local_loss)
                if iter_idx == 0:
                    # Checksum the first step's losses for bitwise comparison
                    # against the non-align-mode run.
                    all_losses_in_one_step_md5sum.append(
                        local_loss._md5sum()
                    )

        if self.rank == 3:
            # Calculate mean using reduced losses
            losses_by_step.append(
                np.array(reduced_losses, dtype=np.float32).mean()
            )
        opt.step()
        opt.clear_grad()
    return losses_by_step, all_losses_in_one_step_md5sum

def run_test(self):
"""Compare losses between three training methods"""
self.setUpClass()
Expand All @@ -483,7 +557,10 @@ def run_test(self):
scheduleFThenB_with_ClipGradByGlobalNorm_losses = (
self.test_ScheduleFThenB_with_ClipGradByGlobalNorm()
)
dp_pp_losses = self.test_dp_pp()
dp_pp_losses, dp_pp_losses_md5sum = self.test_dp_pp()
dp_pp_align_mode_losses, dp_pp_align_mode_losses_md5sum = (
self.test_dp_pp_align_mode()
)

if self.rank == 3:
np.testing.assert_allclose(
Expand Down Expand Up @@ -516,6 +593,14 @@ def run_test(self):
rtol=1e-5,
)

np.testing.assert_allclose(
dp_pp_align_mode_losses,
dp_pp_losses,
rtol=1e-5,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the feeding order is exactly the same, should the md5 of the first step's forward loss be bitwise identical between the two cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

image image image


assert dp_pp_losses_md5sum == dp_pp_align_mode_losses_md5sum


if __name__ == '__main__':
Test_Schedules().run_test()