python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py (84 changes: 77 additions & 7 deletions)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
-from paddle.fluid import core
+from paddle.fluid import core, unique_name
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
@@ -333,26 +333,96 @@ def insert_allreduce_ops(block,
ring_id,
allreduce_vars,
op_role=OpRole.Backward,
-                         use_calc_stream=False):
+                         use_calc_stream=False,
+                         user_defined_strategy=None):
"""
_add_allreduce_ops
"""
if len(allreduce_vars) == 0:
return

if user_defined_strategy and user_defined_strategy.fuse_all_reduce_ops:
insert_fused_allreduce_ops(block, insert_idx, ring_id, allreduce_vars,
op_role, use_calc_stream,
user_defined_strategy.fuse_grad_size_in_MB)
else:
for var in allreduce_vars:
block._insert_op_without_sync(
insert_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})

return


def insert_fused_allreduce_ops(block,
insert_idx,
ring_id,
allreduce_vars,
op_role=OpRole.Backward,
use_calc_stream=False,
fuse_grad_size_in_MB=32):
segments = []
cur_size = 0.
last_dtype = None
for var in allreduce_vars:
real_var = block.var(var)
var_size = get_var_size(real_var)
if cur_size + var_size > fuse_grad_size_in_MB \
or len(segments) == 0 \
or real_var.dtype != last_dtype:
segments.append([real_var])
cur_size = var_size
last_dtype = real_var.dtype
else:
segments[-1].append(real_var)
cur_size += var_size

fused_vars = []
for segment in segments:
tmp_var = block.create_var(
name=unique_name.generate('FusedOutput_{}'.format(segment[0].name)),
dtype=segment[0].dtype,
persistable=False,
stop_gradient=True)
fused_vars.append(tmp_var)
block._insert_op_without_sync(
insert_idx,
type="coalesce_tensor",
inputs={"Input": segment},
outputs={"Output": segment,
"FusedOutput": tmp_var},
attrs={
"copy_data": True,
"use_align": True,
"dtype": segment[0].dtype,
OP_ROLE_KEY: op_role
})

for fused_var in fused_vars:
block._insert_op_without_sync(
insert_idx + len(fused_vars),
type='c_allreduce_sum',
-            inputs={'X': var},
-            outputs={'Out': var},
+            inputs={'X': fused_var},
+            outputs={'Out': fused_var},
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})

-    return
if not use_calc_stream:
block._insert_op_without_sync(
insert_idx + len(fused_vars),
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: op_role})


def insert_reduce_ops(block,
Expand Down Expand Up @@ -528,7 +598,7 @@ def add_sync_comm(program, sharding_ring_id):
add the sync_comm op for the test prog.

"""
    # NOTE (liangjianzhong): only one comm stream is supported for now; using more
    # than one comm stream will cause errors and should be revised in the future.

    assert sharding_ring_id >= 0, "sharding_ring_id should be greater than or equal to zero"
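
The greedy grouping rule that insert_fused_allreduce_ops applies above can be shown in isolation. The sketch below is illustrative only: build_segments and the (name, size_in_MB, dtype) tuples are hypothetical stand-ins for block.var() and get_var_size(), which is assumed to report sizes in MB.

def build_segments(grads, fuse_grad_size_in_MB=32):
    # grads: list of (name, size_in_MB, dtype) tuples.
    # Pack gradients into the current segment until the MB budget would be
    # exceeded; always start a new segment when the dtype changes.
    segments = []
    cur_size = 0.0
    last_dtype = None
    for name, size_mb, dtype in grads:
        if cur_size + size_mb > fuse_grad_size_in_MB \
                or len(segments) == 0 \
                or dtype != last_dtype:
            segments.append([name])
            cur_size = size_mb
            last_dtype = dtype
        else:
            segments[-1].append(name)
            cur_size += size_mb
    return segments

# With a 2 MB budget (as in the test below), the dtype switch forces a new segment:
# build_segments([("w0", 1.5, "fp32"), ("w1", 0.4, "fp32"),
#                 ("w2", 0.4, "fp16"), ("w3", 0.1, "fp16")], 2)
# -> [['w0', 'w1'], ['w2', 'w3']]

Each segment is then handed to a coalesce_tensor op that copies its members into one FusedOutput buffer, so a single c_allreduce_sum per segment replaces one allreduce per gradient.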
@@ -322,7 +322,8 @@ def minimize_impl(self,
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
-            use_calc_stream=True)
+            use_calc_stream=True,
+            user_defined_strategy=self.user_defined_strategy)

# if not use sharding, adapt amp/clip, for remain parallelism.
# cast --> amp --> clip --> opt
@@ -778,8 +779,12 @@ def _add_broadcast_allreduce(self, block):
shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
-            insert_allreduce_ops(block, self._segments[-1]._end_idx,
-                                 self.dp_ring_id, shard_allredue_vars)
+            insert_allreduce_ops(
+                block,
+                self._segments[-1]._end_idx,
+                self.dp_ring_id,
+                shard_allredue_vars,
+                user_defined_strategy=self.user_defined_strategy)
# gradient merge
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self.create_persistable_gradients_and_insert_merge_ops(
@@ -896,8 +901,12 @@ def _add_broadcast_allreduce(self, block):
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
-                insert_allreduce_ops(block, segment._start_idx,
-                                     self.dp_ring_id, shard_allredue_vars)
+                insert_allreduce_ops(
+                    block,
+                    segment._start_idx,
+                    self.dp_ring_id,
+                    shard_allredue_vars,
+                    user_defined_strategy=self.user_defined_strategy)
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# gradient merge
@@ -586,6 +586,36 @@ def test_sharding_with_pp(self):

self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])

def test_sharding_dp_with_allreduce_fuse(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, _ = self.net(train_prog, startup_prog)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.1,
"segment_anchors": None,
"sharding_degree": 2,
"dp_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 1,
"mp_degree": 1
}
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 2
self.optimizer(avg_cost, strategy, train_prog, startup_prog)

main_prog_ops = train_prog.global_block().ops
main_prog_op_types = [op.type for op in main_prog_ops]

assert 'c_allreduce_sum' in main_prog_op_types
assert 'coalesce_tensor' in main_prog_op_types

for op in main_prog_ops:
if op.type == 'c_allreduce_sum':
assert 'FusedOutput' in op.input_arg_names[0]


if __name__ == "__main__":
unittest.main()
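
For reference, a hypothetical end-user configuration that exercises the same path as the test above (the strategy attributes are the ones used in the diff; the program setup and a fleetrun launch across 4 GPUs are assumed):

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "segment_broadcast_MB": 0.1,
    "sharding_degree": 2,
    "dp_degree": 2,
    "hybrid_dp": True,
}
strategy.fuse_all_reduce_ops = True   # route the dp allreduce through the fused path
strategy.fuse_grad_size_in_MB = 32    # MB budget per coalesced segment

optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
# optimizer.minimize(avg_cost) then emits coalesce_tensor + c_allreduce_sum ops
# on FusedOutput buffers, as asserted in the test above.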