189 changes: 146 additions & 43 deletions python/paddle/distributed/auto_parallel/static/completion.py
@@ -1057,30 +1057,51 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id):
dist_op = self._dist_context.get_dist_op_for_program(op)
dist_op.dist_attr.chunk_id = chunk_id
for name in op.input_arg_names + op.output_arg_names:
var = block._find_var_recursive(name)
if "lod_tensor_blocking_queue" in name:
continue
if name not in var_to_chunk_id:
op_dist_attr = (
self._dist_context.get_op_dist_attr_for_program(op)
var = block._find_var_recursive(name)
dist_tensor = (
self._dist_context.get_dist_tensor_for_program(var)
)
tensor_dist_attr = (
self._dist_context.get_tensor_dist_attr_for_program(var)
if (
dist_op.dist_attr.process_mesh
== dist_tensor.dist_attr.process_mesh
):
dist_tensor.dist_attr.chunk_id = chunk_id
var_to_chunk_id[var.name] = chunk_id

def set_process_mesh(block, op, process_mesh, var_to_process_mesh):
dist_op = self._dist_context.get_dist_op_for_program(op)
for name in op.input_arg_names:
if name not in var_to_process_mesh:
var = block._find_var_recursive(name)
dist_tensor = (
self._dist_context.get_dist_tensor_for_program(var)
)
if (
op_dist_attr.process_mesh
== tensor_dist_attr.process_mesh
dist_op.dist_attr.process_mesh
== dist_tensor.dist_attr.process_mesh
):
tensor_dist_attr.chunk_id = op_dist_attr.chunk_id
var_to_chunk_id[var.name] = op_dist_attr.chunk_id
dist_tensor.dist_attr.process_mesh = process_mesh
var_to_process_mesh[var.name] = process_mesh
for name in op.output_arg_names:
if name not in var_to_process_mesh:
var = block._find_var_recursive(name)
dist_tensor = (
self._dist_context.get_dist_tensor_for_program(var)
)
dist_tensor.dist_attr.process_mesh = process_mesh
var_to_process_mesh[var.name] = process_mesh
dist_op.dist_attr.process_mesh = process_mesh

if (
not self._dist_context.strategy
or not self._dist_context.strategy.pipeline.enable
):
return

pp_degree = get_pp_degree(self._dist_context)
pp_degree, sub_process_meshes = get_pp_degree(self._dist_context)
vpp_degree = self._dist_context.strategy.pipeline.vpp_degree
seg_method = self._dist_context.strategy.pipeline.vpp_seg_method
schedule_mode = self._dist_context.strategy.pipeline.schedule_mode
@@ -1099,8 +1120,11 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id):
block = serial_main_program.global_block()
ops = block.ops

# 1. search seg_method in op's struct_name, and get all ops of segments
seg_op_deps = collections.OrderedDict()
# Step1: search seg_method in op's struct_name
# 1. get op_idx of each segment
# 2. get process_mesh of each segment
seg_op_deps = collections.OrderedDict() # struct_name -> [idx]
seg_op_mesh = collections.OrderedDict() # struct_name -> process_mesh
regex = re.compile(seg_method, re.IGNORECASE)
for i, op in enumerate(ops):
struct_name = op.struct_name
@@ -1109,59 +1133,93 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id):
continue

struct_name = struct_name[m.start(0) :].split("/")[0]
dist_op = self._dist_context.get_dist_op_for_program(op)
if struct_name not in seg_op_deps:
seg_op_deps[struct_name] = [i]
seg_op_mesh[struct_name] = dist_op.dist_attr.process_mesh
else:
assert (
seg_op_deps[struct_name][-1] + 1 == i
), "The segment's ops should be continuous."
pre_op = ops[seg_op_deps[struct_name][-1]]
pre_dist_op = self._dist_context.get_dist_op_for_program(pre_op)
dist_op = self._dist_context.get_dist_op_for_program(op)
pre_mesh = seg_op_mesh[struct_name]
assert (
pre_dist_op.dist_attr.process_mesh
== dist_op.dist_attr.process_mesh
pre_mesh == dist_op.dist_attr.process_mesh
), "The segment's ops should have same process_mesh."
seg_op_deps[struct_name].extend([i])

# the num of chunk is equal to vpp_degree
num_parts = pp_degree * vpp_degree
num_chunks = pp_degree * vpp_degree
assert (
len(seg_op_deps.keys()) % num_parts == 0
), "number of layers[{}] ({}) should be devided by part number ({}).".format(
seg_method, len(seg_op_deps.keys()), num_parts
len(seg_op_deps) % num_chunks == 0
), "The number of layers[{}] ({}) should be devided by part number ({}).".format(
seg_method, len(seg_op_deps), num_chunks
)

part_size = len(seg_op_deps.keys()) // vpp_degree
# Step2: analyze whether the pp_stage is non-decreasing among segments
# 1. if non_decreasing is True, the ops' process_mesh will be changed by the vpp strategy
# 2. if non_decreasing is False, the ops' process_mesh will not be changed.
non_decreasing = True
seg_pp_stages = [-1]
for seg_pm in seg_op_mesh.values():
assert seg_pm in sub_process_meshes
pp_stage = sub_process_meshes.index(seg_pm)
if seg_pp_stages[-1] > pp_stage:
non_decreasing = False
break
seg_pp_stages.append(pp_stage)
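# A minimal standalone sketch of this Step2 check, using plain lists in place
# of the dist_context objects; `seg_meshes`/`sub_meshes` are illustrative names,
# not identifiers from this PR.
def _stages_non_decreasing(seg_meshes, sub_meshes):
    # seg_meshes: process_mesh of each segment, in program order
    # sub_meshes: per-stage sub process meshes, as returned by get_pp_degree
    last_stage = -1
    for mesh in seg_meshes:
        stage = sub_meshes.index(mesh)  # pipeline stage owning this segment
        if stage < last_stage:
            return False  # a later segment jumped back to an earlier stage
        last_stage = stage
    return True
# e.g. stage order [0, 0, 1, 1] keeps Auto VPP enabled; [0, 1, 0, 1] disables it.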

# 2. get boundary index of each chunk
results = [0] * (vpp_degree + 1)
memory_counter = 0
result_idx = 1
for struct_name, idxs in seg_op_deps.items():
if not non_decreasing:
_logger.info("Cannot Use Auto VPP")
else:
_logger.info("Using Auto VPP")

# Step3: Get op index boundary, pp_stage, chunk_id, struct_names of each segment
seg_pp_stages = [i % pp_degree for i in range(num_chunks)]
seg_chunk_ids = [i // pp_degree for i in range(num_chunks)]
part_size = len(seg_op_deps) // num_chunks
segment_struct_names = []
segment_parts = [0] * (num_chunks + 1)
memory_counter, seg_idx = 0, 1
struct_name = []
for name, idxs in seg_op_deps.items():
struct_name.append(name)
memory_counter += 1
if memory_counter == part_size:
results[result_idx] = idxs[-1] + 1
result_idx += 1
memory_counter = 0
results[vpp_degree] = len(ops)
segment_parts[seg_idx] = idxs[-1] + 1
memory_counter, seg_idx = 0, seg_idx + 1
segment_struct_names.append(struct_name)
struct_name = []
segment_parts[num_chunks] = len(ops)
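# For intuition, a toy version of the Step3 bookkeeping above, with hypothetical
# sizes (pp_degree=2, vpp_degree=2, 8 matched layers); every name and number
# below is illustrative and not taken from this PR.
def _toy_segmentation(pp_degree=2, vpp_degree=2, num_layers=8):
    num_chunks = pp_degree * vpp_degree                    # 4 segments
    stages = [i % pp_degree for i in range(num_chunks)]    # [0, 1, 0, 1]
    chunks = [i // pp_degree for i in range(num_chunks)]   # [0, 0, 1, 1]
    part_size = num_layers // num_chunks                   # 2 layers per segment
    layers = list(range(num_layers))
    segments = [
        layers[i : i + part_size] for i in range(0, num_layers, part_size)
    ]
    # stage 0 owns chunk 0 (layers [0, 1]) and chunk 1 (layers [4, 5]);
    # stage 1 owns chunk 0 (layers [2, 3]) and chunk 1 (layers [6, 7]).
    return list(zip(stages, chunks, segments))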

# 3. set right chunk_id for each op
# Step4: set the right chunk_id and process_mesh for each op and var
var_to_chunk_id = {}
for chunk_id in range(len(results) - 1):
start_idx = results[chunk_id]
end_idx = results[chunk_id + 1]
var_to_process_mesh = {}
for seg_id in range(len(segment_parts) - 1):
start_idx = segment_parts[seg_id]
end_idx = segment_parts[seg_id + 1]
pp_stage = seg_pp_stages[seg_id]
chunk_id = seg_chunk_ids[seg_id]
process_mesh = sub_process_meshes[pp_stage]
struct_names = segment_struct_names[seg_id]
seg_op_idx = []
for name in struct_names:
seg_op_idx.extend(seg_op_deps[name])

_logger.info(
"[chunk_{}] start op: [{}]: [{}] [{}]".format(
"stage=[{}], chunk_id=[{}], layer_name=[{}]".format(
pp_stage,
chunk_id,
struct_names,
)
)
_logger.info(
"start op: [{}]: [{}] [{}]".format(
ops[start_idx].type,
ops[start_idx].input_arg_names,
ops[start_idx].output_arg_names,
)
)
_logger.info(
"[chunk_{}] end op: [{}]: [{}] [{}]".format(
chunk_id,
"end op: [{}]: [{}] [{}]".format(
ops[end_idx - 1].type,
ops[end_idx - 1].input_arg_names,
ops[end_idx - 1].output_arg_names,
@@ -1173,9 +1231,28 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id):
if op.has_attr("sub_block"):
block_id = op.attr('sub_block').id
sub_block = serial_main_program.blocks[block_id]
for op in sub_block.ops:
set_chunk_id(sub_block, op, chunk_id, var_to_chunk_id)
if non_decreasing and idx in seg_op_idx:
set_process_mesh(
block, op, process_mesh, var_to_process_mesh
)
set_chunk_id(block, op, chunk_id, var_to_chunk_id)

for sub_op in sub_block.ops:
if non_decreasing and idx in seg_op_idx:
set_process_mesh(
sub_block,
sub_op,
process_mesh,
var_to_process_mesh,
)
set_chunk_id(
sub_block, sub_op, chunk_id, var_to_chunk_id
)
else:
if non_decreasing and idx in seg_op_idx:
set_process_mesh(
block, op, process_mesh, var_to_process_mesh
)
set_chunk_id(block, op, chunk_id, var_to_chunk_id)
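# A condensed sketch of the Step4 dispatch above: ops that carry a `sub_block`
# (e.g. control-flow ops) need the same attributes applied to their inner ops.
# `annotate_op` and its parameter bundle are hypothetical names, not from this
# PR; `set_chunk_id`/`set_process_mesh` stand for the helpers defined earlier.
def annotate_op(block, program, op, chunk_id, process_mesh, apply_mesh,
                set_chunk_id, set_process_mesh,
                var_to_chunk_id, var_to_process_mesh):
    targets = [(block, op)]
    if op.has_attr("sub_block"):
        sub_block = program.blocks[op.attr("sub_block").id]
        targets += [(sub_block, sub_op) for sub_op in sub_block.ops]
    for blk, cur_op in targets:
        if apply_mesh:  # only when the Auto VPP remapping is enabled
            set_process_mesh(blk, cur_op, process_mesh, var_to_process_mesh)
        set_chunk_id(blk, cur_op, chunk_id, var_to_chunk_id)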

def _update_dist_attr_for_dp(self):
@@ -1915,8 +1992,34 @@ def infer_backward_op_partial_status(
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping
)
grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
grad_op_dist_attr.chunk_id = ref_fwd_chunk_id
# NOTE(zhaoyingli):
# The sum op is used to accumulate the grads' value of the same forward var;
# the sum op's chunk_id is the same as that of the last op which generates the grad.
ref_chunk_id = None
ref_process_mesh = None
for pre_idx in range(
idx - 1, first_backward_op_idx + 1, -1
):
pre_grad_op = ops[pre_idx]
inter_arg_name = list(
set(pre_grad_op.output_arg_names)
& set(grad_op.input_arg_names)
)
if len(inter_arg_name) > 0:
pre_op_dist_attr = (
self._dist_context.get_op_dist_attr_for_program(
pre_grad_op
)
)
ref_chunk_id = pre_op_dist_attr.chunk_id
ref_process_mesh = pre_op_dist_attr.process_mesh
break
assert (
ref_chunk_id is not None
and ref_process_mesh is not None
)
grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.chunk_id = ref_chunk_id
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr
)
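# A self-contained sketch of the backward scan added above: starting from the
# sum op, walk the grad ops in reverse and reuse the chunk_id/process_mesh of
# the nearest op that produces one of the sum op's inputs. The function name
# and lightweight parameters are illustrative, not part of this PR.
def _ref_attrs_for_sum_op(ops, sum_idx, first_backward_op_idx, get_op_dist_attr):
    sum_inputs = set(ops[sum_idx].input_arg_names)
    for pre_idx in range(sum_idx - 1, first_backward_op_idx + 1, -1):
        pre_op = ops[pre_idx]
        if sum_inputs & set(pre_op.output_arg_names):
            pre_attr = get_op_dist_attr(pre_op)  # e.g. a dist_context lookup
            return pre_attr.process_mesh, pre_attr.chunk_id
    raise AssertionError("no grad producer found before the sum op")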
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/static/utils.py
@@ -2335,7 +2335,7 @@ def get_pp_degree(dist_context):
for idx in reversed(global_pm_idx):
process_meshes.pop(idx)

return len(process_meshes)
return len(process_meshes), process_meshes


def get_pp_stage(dist_context, rank):
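Since get_pp_degree now returns the sub process meshes alongside their count, callers unpack two values. A hedged usage sketch follows; the helper name and its arguments are illustrative, assuming the module path matches the file shown above.

def stage_of_segment(dist_context, segment_mesh):
    from paddle.distributed.auto_parallel.static.utils import get_pp_degree

    # get_pp_degree now returns (pp_degree, sub_process_meshes)
    pp_degree, sub_process_meshes = get_pp_degree(dist_context)
    assert segment_mesh in sub_process_meshes
    # a segment's pipeline stage is its mesh's index among the sub meshes
    return sub_process_meshes.index(segment_mesh)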