Merged
Changes from all commits
134 commits
b985745
add auto_parallel dir
Jun 28, 2021
b79e749
mv to paddle.distributed
Jun 28, 2021
1671850
add shard_xx api
Jul 1, 2021
ec55a43
add distributed attrs for var
Jul 8, 2021
25abc00
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 9, 2021
bf24fb7
add ut, test=develop
Jul 9, 2021
8ea9363
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 18, 2021
9e4b3d8
add dist
Jul 21, 2021
e65f77e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 22, 2021
8b95c1e
update
Jul 26, 2021
ccae6ae
update
Jul 26, 2021
d107751
update
Jul 27, 2021
f7e70ea
update
Jul 27, 2021
3111159
update
Jul 27, 2021
70cdb69
update, test=develop
Jul 27, 2021
9e5b0f0
update, test=develop
Jul 27, 2021
59936ef
update, test=develop
Jul 27, 2021
27ee413
update, test=develop
Jul 27, 2021
3a8ceef
update, test=develop
Jul 27, 2021
d11f317
update, test=develop
Jul 28, 2021
f5ef245
update, test=develop
Jul 28, 2021
7293b4f
update
Jul 28, 2021
1240edc
update
Jul 28, 2021
05455fb
update
Jul 28, 2021
3e1b3a0
update
Jul 28, 2021
8950c35
update
Jul 28, 2021
b94a9f2
update, test=develop
Jul 28, 2021
e121349
update, test=develop
Jul 28, 2021
fe51aa3
update
Jul 28, 2021
4563d42
update
Jul 28, 2021
192580d
Merge branch 'develop' into auto_parallel_basic
Jul 28, 2021
2e69980
delete unused proto
Jul 28, 2021
608dd3f
resotre op_desc
Jul 28, 2021
cb9b6bf
restore type_defs
Jul 28, 2021
8e6559e
update var_desc
Jul 28, 2021
00f5f4d
remove dimss_mapping for proto_pybind
Jul 28, 2021
1aa94da
update interface.py
Jul 28, 2021
97a446c
update framework.py
Jul 28, 2021
c586fc6
update
Jul 28, 2021
fc6cde9
update
Jul 29, 2021
9d1a664
add auto_parallel dir
Jun 28, 2021
5d1b472
mv to paddle.distributed
Jun 28, 2021
d1aabad
add shard_xx api
Jul 1, 2021
e6ba855
add distributed attrs for var
Jul 8, 2021
3bf613c
add ut, test=develop
Jul 9, 2021
8942a99
[WIP] Add the auto completion feature and related codes
aoyulong Jul 16, 2021
6916cf2
[WIP] Improve the auto completion and related codes
aoyulong Jul 18, 2021
cafdd18
[WIP] Make the auto completion to support data-parallel
aoyulong Jul 19, 2021
4d6dd52
[WIP] Make the completion support mp and dp+mp
aoyulong Jul 19, 2021
3f05d09
[WIP] Refactor auto completion unit test for MLP
aoyulong Jul 20, 2021
2c56e12
[WIP] Refactor the implementation of DistributedOperatorImpl
aoyulong Jul 21, 2021
a83e9cd
[WIP] Improve dims_mapping update rule and fix a bug
aoyulong Jul 21, 2021
203ea14
[WIP] Support auto completion for one transformer decoder layer
aoyulong Jul 21, 2021
bbc2c39
[WIP] Add a minor change
aoyulong Jul 21, 2021
2b6f992
[WIP] Fix a bug within the uint test
aoyulong Jul 22, 2021
921c53d
Shard XShape tensor, add embedding completion and refactor code
aoyulong Jul 27, 2021
a03d503
Add the distributed_operators dir to setup.py.in
aoyulong Jul 28, 2021
3770f13
Improve the completion process and add the unittest for gpt
aoyulong Jul 29, 2021
967d0e7
fix process_mesh ut
Jul 29, 2021
cd1e390
fix process_mesh ut
Jul 29, 2021
f48ec91
update
Jul 29, 2021
b07affa
update, test=develop
Jul 30, 2021
f304b47
Add support for automatically completing distributed attrs of special…
aoyulong Jul 30, 2021
a00fe9e
update
Jul 30, 2021
da9fe30
update
Aug 2, 2021
3daecf2
update
Aug 2, 2021
5640879
fix doc sample codes, test=develop
Aug 2, 2021
05b0f82
improve coverage, test=develop
Aug 2, 2021
fe93d0e
add static_mode check, test=develop
Aug 2, 2021
033c541
Model the cluster for cost model and physical mapping
aoyulong Aug 4, 2021
9856d47
update, test=develop
Aug 4, 2021
890c70c
add set_placement, test=develop
Aug 5, 2021
6291697
Add the check to make sure the candidate tensors' size is great than …
aoyulong Aug 5, 2021
4b90b03
update doc, test=develop
Aug 5, 2021
c395b84
update doc, test=develop
Aug 5, 2021
8390e01
update doc, test=develop
Aug 5, 2021
f7d5631
update doc, test=develop
Aug 6, 2021
3a2666e
update, test=develop
Aug 6, 2021
fa98e39
Auto mark dist attrs annotated by user
aoyulong Aug 9, 2021
b5b8b9b
Merge branch 'PaddlePaddle:develop' into develop
aoyulong Aug 9, 2021
70bc589
Merge branch 'PaddlePaddle:develop' into develop
aoyulong Aug 9, 2021
b9bd421
Merge PR#33804
aoyulong Aug 9, 2021
b59bc33
Merge branch 'PaddlePaddle:develop' into develop
aoyulong Aug 9, 2021
773516b
update ndarray to nested list, test=develop
Aug 10, 2021
685504f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Aug 10, 2021
632eeac
Merge branch 'PaddlePaddle:develop' into develop
aoyulong Aug 10, 2021
87abb4b
Merge branch 'pr_33804' into auto_parallel
aoyulong Aug 10, 2021
7ac6299
update, test=develop
Aug 10, 2021
c724593
Add auto-completion module for auto-parallel (based on PR#33804)
aoyulong Aug 11, 2021
63e66bc
Merge branch 'pr_33804' into auto_parallel
aoyulong Aug 11, 2021
7087b1e
Merge branch 'PaddlePaddle:develop' into develop
aoyulong Aug 11, 2021
1908acf
Merge branch 'develop' of https://github.com/aoyulong/Paddle into aut…
aoyulong Aug 11, 2021
86ccd47
Remove unnecessary files
aoyulong Aug 11, 2021
3f7dca2
Remove unrelated files for the auto completion pr
aoyulong Aug 11, 2021
ed02152
Update the unit test to improve the coverage
aoyulong Aug 12, 2021
88e9e23
Modify codes based on reviews
aoyulong Aug 16, 2021
63a6ec6
Minor changes for CI
aoyulong Aug 17, 2021
6b77bc8
Improve some codes based on new comments
aoyulong Aug 17, 2021
411507d
Merge branch 'auto_parallel_completion' of https://github.com/aoyulon…
aoyulong Aug 19, 2021
2ef97ba
support shard reader
JZ-LIANG Aug 24, 2021
5993a30
support shard reader
JZ-LIANG Aug 24, 2021
a8a26de
add parallel mode
JZ-LIANG Aug 24, 2021
93348eb
update process mesh
JZ-LIANG Aug 24, 2021
e957043
add method to compute comm_group
JZ-LIANG Aug 24, 2021
291d1b7
implement dist_embedding forward func
JZ-LIANG Aug 24, 2021
c74ee5a
implement dist matmul forward func
JZ-LIANG Aug 24, 2021
4c00571
implement dist reshape forward func
JZ-LIANG Aug 24, 2021
b75ceca
add transpiler framework
JZ-LIANG Aug 24, 2021
67abec3
add transpiler forward
JZ-LIANG Aug 24, 2021
cd2526c
implement transpiler forward
JZ-LIANG Aug 24, 2021
d7d3b74
implement transpiler backward & update
JZ-LIANG Aug 24, 2021
52d054c
add process
JZ-LIANG Aug 24, 2021
1b8ddfb
add unitest
JZ-LIANG Aug 24, 2021
ae3e506
chmod
JZ-LIANG Aug 24, 2021
e2fa7cd
chmod
JZ-LIANG Aug 24, 2021
de53039
chmod
JZ-LIANG Aug 24, 2021
fbe3356
update unitest
JZ-LIANG Aug 24, 2021
d0798cb
add unitest for gpt
JZ-LIANG Aug 25, 2021
fbc42d6
remove unused print
JZ-LIANG Aug 25, 2021
f0f58dc
rename transpiler --> partitioner
JZ-LIANG Aug 25, 2021
f5cd926
rename transpiler --> partitioner
JZ-LIANG Aug 25, 2021
2ebece8
chmod
JZ-LIANG Aug 25, 2021
b22ea19
chmod
JZ-LIANG Aug 25, 2021
cc694b1
bug fixed
JZ-LIANG Aug 25, 2021
1cc96ca
remove amp function
JZ-LIANG Aug 26, 2021
4fb30ef
update case for dp mode
JZ-LIANG Aug 27, 2021
56ff62e
update case for dp mode
JZ-LIANG Aug 27, 2021
4ec9f80
Merge branch 'pr_35117' into auto_parallel_integration
aoyulong Aug 29, 2021
0cb34e2
[Auto Parallel] Integrate all parts with the newest code
aoyulong Aug 29, 2021
89b467f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Sep 2, 2021
d7286cb
Integrate all parts of auto parallel and improve codes
aoyulong Sep 6, 2021
4fb96e6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Sep 6, 2021
00d699e
Modify distributed_strategy.proto to conform the main stream
aoyulong Sep 6, 2021
79f1025
Restore parts of distributed_strategy to conform the develop branch
aoyulong Sep 6, 2021
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -202,6 +202,7 @@ message DistributedStrategy {
optional bool calc_comm_same_stream = 32 [ default = false ];
optional bool asp = 33 [ default = false ];
optional bool fuse_grad_merge = 34 [ default = false ];
optional bool semi_auto = 35 [ default = false ];

optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
Expand Down
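The one-line proto change above adds a `semi_auto` switch to `DistributedStrategy`, following the same pattern as the existing boolean feature flags (off by default, enabled explicitly by the user). A minimal Python stand-in sketches the intended usage pattern; the dataclass below is a hypothetical mirror of the proto message, not Paddle's actual `DistributedStrategy` class:

```python
from dataclasses import dataclass

@dataclass
class DistributedStrategy:
    # Mirrors the proto defaults: every feature flag is off unless enabled.
    calc_comm_same_stream: bool = False
    asp: bool = False
    fuse_grad_merge: bool = False
    semi_auto: bool = False  # new field 35: enable semi-automatic parallelism

strategy = DistributedStrategy()
# The user annotates a few tensors/ops; completion fills in the rest.
strategy.semi_auto = True
```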
170 changes: 143 additions & 27 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -253,6 +253,9 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
if (not tensor_node.is_var()) or (tensor_node.var() is None):
return False
tensor_desc = tensor_node.var()
# Skip reader tensor
if tensor_desc.type() == core.VarDesc.VarType.READER:
return False
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
assert tensor_dist_attr is not None
@@ -263,6 +266,10 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
dims_mapping_list = []
for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None:
if pred_op_node.op().type() == "create_py_reader" \
or pred_op_node.op().type() == "create_double_buffer_reader" \
or pred_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
pred_op_node)
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
@@ -279,6 +286,10 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
dims_mapping_list = []
for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None:
if succ_op_node.op().type() == "create_py_reader" \
or succ_op_node.op().type() == "create_double_buffer_reader" \
or succ_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
succ_op_node)
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
@@ -298,11 +309,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
changed = False
if (not op_node.is_op()) or (op_node.op() is None):
return False
# Skip reader op
op_desc = op_node.op()
if op_desc.type() == "create_py_reader" \
or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "read":
return False
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node)
if fwd:
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER:
continue
tensor_desc = tensor_node.var()
if op_dist_attr.is_annotated_input_dims_mapping(
tensor_desc.name()):
@@ -344,6 +362,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER:
continue
tensor_desc = tensor_node.var()
if op_dist_attr.is_annotated_output_dims_mapping(
tensor_desc.name()):
@@ -400,47 +420,143 @@ def complete_annotation(program, dist_context=None):
if dist_context is None:
dist_context = get_default_distributed_context()

    # Initialize distributed attributes for all var and op node in program
dist_context.initialize_distributed_attr_for_program(program)
# print_program_with_distributed_attr(program, dist_context)

# Convert program to graph
graph = framework.IrGraph(core.Graph(program.desc))

# Initialize distributed attributes for all var and op node in graph
dist_context.initialize_distributed_attr_for_graph(graph)

    # Complete process mesh for each node
all_nodes = list(graph.all_nodes())

def sort_key_fun(node):
first = -1
if node.is_op():
first = 0
else:
first = 1
second = -1
if node.is_op() and node.op() is not None:
second = node.op().id()
if node.is_var() and node.var() is not None:
second = node.var().id()
return (first, second)

all_nodes.sort(key=sort_key_fun)

reach_fix_point = False
while not reach_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=True)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=True)
if op_changed:
changed = True
for node in reversed(all_nodes):
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=False)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=False)
if op_changed:
changed = True
if changed:
total_changed = False
reach_fwd_fix_point = False
reach_bwd_fix_point = False
while not reach_fwd_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=True)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=True)
if op_changed:
changed = True
if changed:
reach_fwd_fix_point = False
total_changed = True
else:
reach_fwd_fix_point = True
while not reach_bwd_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=False)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=False)
if op_changed:
changed = True
if changed:
reach_bwd_fix_point = False
total_changed = True
else:
reach_bwd_fix_point = True
if total_changed:
reach_fix_point = False
else:
reach_fix_point = True
    # Validate the completed process meshes; this check should be moved to a proper location
is_wrong = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
node)
if tensor_dist_attr.get_process_mesh() is None:
msg_str = ""
for op_node in node.inputs:
if op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_node)
msg_str += "{} [{}], ".format(
op_node.op().type(),
op_dist_attr.get_process_mesh())
else:
msg_str += "{} [{}], ".format(op_node.name(),
None)
for op_node in node.outputs:
if op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_node)
msg_str += "{} [{}], ".format(
op_node.op().type(),
op_dist_attr.get_process_mesh())
else:
msg_str += "{} [{}], ".format(op_node.name(),
None)
msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_tensor api explicitly to annotate it".format(
node.var().name(), msg_str[:-2])
is_wrong = True
print(msg_str)
if node.is_op() and node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
node)
if op_dist_attr.get_process_mesh() is None:
msg_str = ""
for tensor_node in node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
msg_str += "{} [{}], ".format(
tensor_node.var().name(),
tensor_dist_attr.get_process_mesh())
else:
msg_str += "{} [{}], ".format(
tensor_node.name(), None)
for tensor_node in node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
msg_str += "{} [{}], ".format(
tensor_node.var().name(),
tensor_dist_attr.get_process_mesh())
else:
msg_str += "{} [{}], ".format(
tensor_node.name(), None)
msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_op api explicitly to annotate it".format(
node.op().type(), msg_str[:-2])
is_wrong = True
print(msg_str)
if node.is_op() and node.op() is None:
print("op op is None", node.name())
if is_wrong:
assert False, "Cannot complete process_meshes of the program."

# Complete dims_mapping for each node
reach_fix_point = False
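The rewritten loop in `complete_annotation` above runs the forward sweep to its own fixed point, then the backward sweep to its own fixed point, and repeats both until neither sweep changes anything, while the new guards skip reader ops (`create_py_reader`, `create_double_buffer_reader`, `read`) entirely. A toy sketch of that propagation scheme on a linear op graph — `propagate`/`complete` and the list-of-names graph are hypothetical stand-ins for the real `IrGraph` traversal, not Paddle's API:

```python
# Reader ops are skipped by the completion pass, exactly as in the diff above.
READER_OPS = {"create_py_reader", "create_double_buffer_reader", "read"}

def propagate(nodes, meshes, reverse=False):
    """One sweep over the graph; returns True if any node's mesh changed."""
    changed = False
    prev = None
    for name in (reversed(nodes) if reverse else nodes):
        if name in READER_OPS:       # never propagate through a reader op
            prev = None
            continue
        if meshes.get(name) is None and prev is not None:
            meshes[name] = meshes[prev]   # inherit the neighbor's mesh
            changed = True
        prev = name
    return changed

def complete(nodes, meshes):
    while True:
        fwd = propagate(nodes, meshes, reverse=False)
        bwd = propagate(nodes, meshes, reverse=True)
        if not (fwd or bwd):         # fixed point: nothing changed either way
            break
    # Validation step: every non-reader node must end up with a mesh,
    # mirroring the "Cannot decide ProcessMesh" check in the diff.
    missing = [n for n in nodes if n not in READER_OPS and meshes.get(n) is None]
    assert not missing, "Cannot decide ProcessMesh of {}".format(missing)
    return meshes

nodes = ["read", "x", "matmul", "y", "softmax", "z"]
meshes = {n: None for n in nodes}
meshes["matmul"] = [0, 1]            # user-annotated, e.g. via shard_op
complete(nodes, meshes)              # x gets its mesh from the backward sweep
```

Note why both directions are needed: in the forward sweep `x` has no annotated predecessor (the reader is skipped), so only the backward sweep from `matmul` can fill it in.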
45 changes: 30 additions & 15 deletions python/paddle/distributed/auto_parallel/context.py
@@ -142,12 +142,15 @@ def initialize_distributed_attr_for_program(self, program):
tensor.desc, tensor_dist_attr)
self.set_tensor_distributed_attr_for_program(
tensor, tensor_dist_attr)
tensor_dist_attr.set_shape(tensor.desc.shape())
if tensor.type == core.VarDesc.VarType.READER:
tensor_dist_attr.set_shape([])
else:
tensor_dist_attr.set_shape(tensor.desc.shape())
if tensor_dist_attr.get_process_mesh() is not None:
tensor_dist_attr.mark_as_annotated("process_mesh")
if tensor_dist_attr.get_dims_mapping() is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
-1 for _ in range(len(tensor_dist_attr.get_shape()))
]
tensor_dist_attr.set_dims_mapping(tensor_dims_mapping)
else:
@@ -168,12 +171,18 @@ def initialize_distributed_attr_for_program(self, program):
op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names:
# There may be a better way to find the tensor by name
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_input_shape(tensor_name,
tensor.desc.shape())
if op.type == "create_py_reader" \
or tensor.type == core.VarDesc.VarType.READER:
op_dist_attr.set_input_shape(tensor_name, [])
else:
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_input_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
-1
for _ in range(
len(op_dist_attr.get_input_shape(tensor_name)))
]
op_dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
@@ -184,12 +193,18 @@ def initialize_distributed_attr_for_program(self, program):
op_dist_attr.mark_as_parameter(tensor_name)
for tensor_name in op.output_arg_names:
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_output_shape(tensor_name,
tensor.desc.shape())
if tensor.type == core.VarDesc.VarType.READER:
op_dist_attr.set_output_shape(tensor_name, [])
else:
op_dist_attr.set_output_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_output_dims_mapping(
tensor_name) is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
-1
for _ in range(
len(
op_dist_attr.get_output_shape(tensor_name)))
]
op_dist_attr.set_output_dims_mapping(
tensor_name, tensor_dims_mapping)
@@ -378,8 +393,8 @@ def amend_distributed_attr_for_program(self):
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[
i]] > tensor_shape[i]:
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1

for attr in self._op_distributed_attr_map_for_program.values():
@@ -392,8 +407,8 @@ def amend_distributed_attr_for_program(self):
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1

for arg_name in attr.get_owner_op().desc.output_arg_names():
@@ -403,8 +418,8 @@ def amend_distributed_attr_for_program(self):
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1

def _get_data_parallel_info(self):
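The three parallel hunks in `amend_distributed_attr_for_program` above all apply the same amended rule: a sharded dimension is reset to -1 (replicated) when the process-mesh dimension it maps to has more processes than the tensor dimension has elements, and the new `tensor_shape[i] > 0` guard leaves dynamic/unknown dimensions untouched. A self-contained sketch of that rule, extracted from the diff into a hypothetical free function:

```python
def amend_dims_mapping(dims_mapping, tensor_shape, process_mesh_shape):
    """Reset over-sharded dims to -1 (replicated), as in the diff above.

    dims_mapping[i] == -1 means dim i is replicated; otherwise it names the
    process-mesh axis that dim i is sharded over.
    """
    for i in range(len(tensor_shape)):
        # The tensor_shape[i] > 0 guard is the fix in this PR: dynamic or
        # unknown dims (shape -1 or 0) keep whatever mapping they were given.
        if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
            dims_mapping[i] = -1
    return dims_mapping

# dim 0 is sharded over 8 processes but holds only 4 elements -> replicate it;
# dim 1 has a dynamic shape (-1), so the new guard preserves its mapping.
print(amend_dims_mapping([1, 0], [4, -1], [2, 8]))  # -> [-1, 0]
```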