Merged

Commits (25)
14b07ba
recompute: fixed bug with BOOL VarType
JZ-LIANG Jan 29, 2021
e6b489d
Sharding support Megatron
JZ-LIANG Jan 30, 2021
c8b0f92
sharding-megatron support amp, sharding dp init broadcast
JZ-LIANG Feb 3, 2021
5444c3b
sharding-megatron support gradclipbyglobalnorm
JZ-LIANG Feb 4, 2021
527bc96
Sharding allreduce --> reduce
JZ-LIANG Feb 19, 2021
7e35d31
sharding optimize init speed
JZ-LIANG Mar 5, 2021
2cc1b7c
recompute remove useless log
JZ-LIANG Mar 10, 2021
f273420
sharding: segment strategy
JZ-LIANG Mar 11, 2021
6a18b38
temp change for ernie_10b_two_branch
JZ-LIANG Mar 15, 2021
98baf20
sharding: gradient merge
JZ-LIANG Mar 16, 2021
9ece14f
sharding gradient merge: fix OOM
JZ-LIANG Mar 19, 2021
add91b7
sharding: revise save logic for gradient merge
JZ-LIANG Mar 19, 2021
e01e22a
Sharding: revise code format
JZ-LIANG Mar 26, 2021
ffb492b
sharding: update anchor segment strategy
JZ-LIANG Mar 26, 2021
0abe6e9
sharding: revise anchor segment logic
JZ-LIANG Mar 26, 2021
726525c
sharding: revise api
JZ-LIANG Mar 29, 2021
cb788cf
sharding: remove debug log
JZ-LIANG Mar 29, 2021
cf5b1c9
sharding: add sync in startup prog, uniform parallelism switch
JZ-LIANG Mar 29, 2021
ceae74b
sharding: update unitest
JZ-LIANG Mar 30, 2021
dde7d24
sharding: add more comments
JZ-LIANG Mar 31, 2021
60be6ec
recompute: fixed bug in create vars
JZ-LIANG Mar 31, 2021
b414ddf
sharding temp to check ci bug
JZ-LIANG Apr 1, 2021
620f138
sharding: revise comm _wait func
JZ-LIANG Apr 1, 2021
609859a
Merge remote-tracking branch 'upstream/develop' into sharding-ERNIE16…
JZ-LIANG Apr 1, 2021
0732442
sharding: revise comm init
JZ-LIANG Apr 1, 2021
9 changes: 7 additions & 2 deletions paddle/fluid/framework/distributed_strategy.proto
100644 → 100755
@@ -29,9 +29,14 @@ message RecomputeConfig {
}

message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional float segment_broadcast_MB = 1 [ default = 32.0 ];
Contributor:

Add simple comments?

Contributor (author):

Those args decide how sharding segments the program, which affects the comm-calc overlap logic in sharding. For now, we support two segment strategies:
"segment_broadcast_MB": segment by broadcast volume
"segment_anchors": segment by user-defined anchors (an op's output)
I will add a detailed explanation in fluidDoc.

Contributor:

Also, the next PR can add comments or a link in this code.

Contributor (author):

Recorded.
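The "segment_broadcast_MB" strategy described in this thread can be sketched in plain Python. This is an illustrative sketch only (the function name and the list-of-sizes input are made up, not Paddle's actual implementation): walk the parameters in program order and cut a new segment whenever the accumulated broadcast volume reaches the threshold.

```python
def segment_by_broadcast_mb(param_sizes_mb, threshold_mb=32.0):
    """Hypothetical sketch: group params into segments by broadcast volume."""
    segments, current, acc = [], [], 0.0
    for size in param_sizes_mb:
        current.append(size)
        acc += size
        if acc >= threshold_mb:
            # accumulated broadcast volume hit the threshold: close the segment
            segments.append(current)
            current, acc = [], 0.0
    if current:
        # leftover params form the final segment
        segments.append(current)
    return segments

print(segment_by_broadcast_mb([10.0, 10.0, 15.0, 40.0, 5.0], threshold_mb=32.0))
# → [[10.0, 10.0, 15.0], [40.0], [5.0]]
```

A larger `segment_broadcast_MB` yields fewer, coarser segments (less scheduling overhead, less comm-calc overlap); a smaller value yields finer segments with more overlap opportunity.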

optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
optional int32 sharding_degree = 3 [ default = 8 ];
Contributor:

Why is the default set to 8? Maybe it could be set to -1 or 0, which would mean taking the value from world_size.

Contributor (author):

Recorded; will update in the next PR.

optional int32 mp_degree = 4 [ default = 1 ];
optional string sharding_segment_strategy = 5
[ default = 'segment_broadcast_MB' ];
Contributor:

Add all enum values in comments. Also, can the strategy names be simplified? One could be called by_broadcast (or by_param_size, or another better name) and the other by_anchors (or another better name).

Contributor (author):

Great suggestion! Recorded.

repeated string segment_anchors = 6;
optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
}

message AMPConfig {
2 changes: 2 additions & 0 deletions python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
100644 → 100755
@@ -58,7 +58,9 @@ def _init_wrapped_opt(self):
# computation by split the check_finite_and_unscale op.
is_distributed = self.role_maker._worker_num() > 1
if self.user_defined_strategy.sharding:
# if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel:
Contributor:

I don't get this line!

Contributor (author):

Sharding, as well as sharding-megatron, does not support the pure_fp16 allreduce logic introduced in the amp & pure_fp16 meta-optimizers, so we need to set is_distributed = False while sharding or megatron is enabled. This PR adds the logic for sharding-megatron support, but the megatron meta-optimizer will be added in another PR, so I will remove this line for now.

# FIXME(wangxi). sharding failed when split check_finite_and_unscale
# FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior
is_distributed = False
self.wrapped_opt._set_distributed(is_distributed)

49 changes: 33 additions & 16 deletions python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
100644 → 100755
@@ -73,7 +73,7 @@ def remove_cast_op(block, params, segment, offset):
@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"""
1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard
1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
2. revise amp inifine grad checking for sharding
"""
# remove cast
@@ -103,6 +103,7 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD")
if param_name not in shard.global_params:
@@ -111,12 +112,24 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"be grads, but {} is not a grad".format(input_name))
if shard.has_param(param_name):
reversed_x.append(input_name)
reversed_x_paramname.append(param_name)
op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x)

# the grad checking should take the all and only param in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(
set([param for param, worker_idx in shard.global_param2device.items() \
if worker_idx == shard.worker_idx]))
assert to_check_param == should_check_param, "amp \
check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
Contributor:

check_finite_and_unscale --> {op.type}

Contributor (author):

Recorded; will update in the next PR.

should_check_param - to_check_param,
to_check_param - should_check_param)
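The assertion above checks that the params fed to check_finite_and_unscale on this worker are exactly the params the shard placement assigns to it — no more, no fewer. The same set logic in isolation (the helper name and placement dict are hypothetical, not Paddle's API):

```python
def check_shard_coverage(checked_params, global_param2device, worker_idx):
    """Assert the checked params match this worker's shard exactly."""
    should = {p for p, w in global_param2device.items() if w == worker_idx}
    to_check = set(checked_params)
    # report both directions: params we missed and params we should not have
    assert to_check == should, "checking miss [{}] and got unexpected [{}]".format(
        should - to_check, to_check - should)

placement = {"w1": 0, "w2": 1, "b1": 0}
check_shard_coverage(["w1", "b1"], placement, worker_idx=0)  # passes
```

Passing `["w1"]` for worker 0 would trip the assertion, since `b1` is also placed on that worker.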

if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_fp32 = block.create_var(
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
@@ -128,32 +141,36 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_fp32},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_fp32.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
[inf_var_fp32])
# this allreduce communication should not overlap with calc
# insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
# [inf_var_int32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize})
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})

comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32])
# comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
Contributor:

Clean up?

Contributor (author):

Done.

# ring_id, [inf_var_int32])

block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
update_loss_scaling_op_idx + 2,
type='cast',
inputs={'X': inf_var_fp32},
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding},
attrs={
"in_dtype": inf_var_fp32.dtype,
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_sharding.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
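The cast → c_allreduce_max → cast sequence in this hunk implements a global logical OR of each shard's found-inf flag: since every rank holds 0 or 1, the max across ranks is 1 iff any shard overflowed (the int32 round-trip is presumably needed because the collective has no bool dtype). A single-process sketch of the reduction semantics (the helper name is made up):

```python
def global_found_inf(local_flags):
    """Each rank holds a bool flag; allreduce_max over {0, 1} == logical OR."""
    return max(int(f) for f in local_flags) == 1

# one shard saw inf/nan, so every rank must skip the update
print(global_found_inf([False, False, True, False]))  # → True
```

After the allreduce, every rank casts the shared int32 result back to the loss-scaling op's expected dtype, so all shards agree on whether to skip the parameter update.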
56 changes: 38 additions & 18 deletions python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
100644 → 100755
@@ -16,21 +16,22 @@


class GradientClipHelper(object):
def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id
def __init__(self, mp_ring_id):
self.mp_ring_id = mp_ring_id
Contributor:

This needs a comment: mp_ring is (sharding + mp).

Contributor (author):

Recorded.


def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")

def prune_gradient_clip(self, block, shard):
def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
deperated_vars = set()
deperate_op_idx = set()
reversed_x_paramname = []
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
@@ -44,6 +45,8 @@ def prune_gradient_clip(self, block, shard):
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
elif shard.is_param(param_name):
reversed_x_paramname.append(param_name)

if deperate_op:
deperate_op_idx.add(idx)
@@ -65,31 +68,48 @@ def prune_gradient_clip(self, block, shard):
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)

op.desc.set_input("X", reversed_inputs)
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_sync_comm_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})

# this allreduce should not overlap with calc and should be scheduled in calc stream
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': self.sharding_ring_id,
OP_ROLE_KEY: OpRole.Optimize
'ring_id': self.mp_ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={OP_ROLE_KEY: OpRole.Optimize})

# global norm should only be summed within each model parallelism world size when using a global group
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
Contributor:

Maybe it would be better to scale before c_allreduce_sum, since c_allreduce_sum may produce inf.

Contributor (author):

Since grads tend to be small decimal values, I think scaling after the allreduce is better, to avoid arithmetic underflow.

inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
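The `scale` op inserted here compensates for over-counting: when the communication group also spans pure data-parallel replicas, each dp replica contributes an identical partial norm, so the allreduce_sum result is pure_dp_degree times too large. Illustrative arithmetic (function and argument names are made up for this sketch):

```python
def global_norm_sq(partial_norms_per_rank, pure_dp_degree):
    """Sketch: sum partial squared norms, then undo the dp duplication."""
    summed = sum(partial_norms_per_rank)   # what c_allreduce_sum produces
    return summed / float(pure_dp_degree)  # the compensating scale op

# 2 sharding ranks x 2 dp replicas: each dp replica repeats the same partials
partials = [3.0, 5.0] * 2
print(global_norm_sq(partials, pure_dp_degree=2))  # → 8.0
```

This also shows why the review thread above debates ordering: dividing before the allreduce would shrink already-small partial norms (underflow risk), while dividing after risks the sum itself overflowing — the author chose the latter.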

# the grad sum here should take the all and only param in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(set(
[param for param, worker_idx in shard.global_param2device.items() \
if worker_idx == shard.worker_idx]))
assert to_check_param == should_check_param, "amp check_finite_and_unscale \
checking miss [{}] and got unexpected [{}]".format(
should_check_param - to_check_param,
to_check_param - should_check_param)

for var_name in deperated_vars:
block._remove_var(var_name, sync=False)