Merged
18 changes: 11 additions & 7 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -29,14 +29,18 @@ message RecomputeConfig {
}

message ShardingConfig {
optional float segment_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_degree = 3 [ default = 8 ];
optional int32 mp_degree = 4 [ default = 1 ];
optional string sharding_segment_strategy = 5
optional string sharding_segment_strategy = 1
Contributor

Enum comments

Contributor Author

Recorded; documentation will be added in fluiddoc and fleetx.

Contributor

Also need to add comments to this code.

[ default = 'segment_broadcast_MB' ];
repeated string segment_anchors = 6;
optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
optional float segment_broadcast_MB = 2 [ default = 32.0 ];
repeated string segment_anchors = 3;
optional int32 sharding_degree = 4 [ default = 8 ];
optional int32 mp_degree = 5 [ default = 1 ];
optional int32 dp_degree = 6 [ default = 1 ];
optional bool hybrid_dp = 7 [ default = false ];
optional int32 gradient_merge_acc_step = 8 [ default = 1 ];
optional bool optimize_offload = 9 [ default = false ];
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
Contributor

Add some comments: in 3D or 4D parallelism, allreduce_in_optimize=True can reduce communication, while allreduce_in_optimize=False can reduce memory.

Contributor Author

Recorded; documentation will be added in fluiddoc, fleetx, and the .py file where the feature is called.

But I think this should remain an internal-project feature for now; should we not expose it to users?

optional int32 pp_degree = 11 [ default = 1 ];
}

message AMPConfig {
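For context on how the renumbered ShardingConfig fields above reach user code, here is a minimal sketch of setting them through fleet.DistributedStrategy().sharding_configs; the exact key set accepted depends on the Paddle version, so treat the dictionary below as illustrative rather than exhaustive.

```python
# Hedged sketch (not part of this PR) of populating the ShardingConfig fields
# from user code; keys mirror the proto field names shown in the diff.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",  # field 1
    "segment_broadcast_MB": 32.0,                          # field 2
    "sharding_degree": 8,                                  # field 4
    "mp_degree": 1,                                        # field 5
    "dp_degree": 1,                                        # field 6 (new in this PR)
    "gradient_merge_acc_step": 1,                          # field 8
    "pp_degree": 1,                                        # field 11 (new in this PR)
}
```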
5 changes: 5 additions & 0 deletions python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
100644 → 100755
@@ -45,11 +45,16 @@ def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
'accumulate_steps']
self.schedule_mode = user_defined_strategy.pipeline_configs[
'schedule_mode']
self.use_sharding = user_defined_strategy.sharding

def _can_apply(self):
if not self.role_maker._is_collective:
return False

# FIXME revise for hybrid parallelism
if self.use_sharding:
return False

if self.user_defined_strategy.pipeline == True:
return True
return False
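The change above wires the sharding flag into the pipeline meta-optimizer so it backs off when sharding drives hybrid parallelism itself. A simplified, hedged sketch of that guard pattern (the real class carries more state and inherits from the meta-optimizer base):

```python
# Illustrative standalone class, not the actual PipelineOptimizer.
class PipelineLikeOptimizer:
    def __init__(self, role_maker, user_defined_strategy):
        self.role_maker = role_maker
        self.user_defined_strategy = user_defined_strategy
        # Remember whether the sharding optimizer is active for this strategy.
        self.use_sharding = user_defined_strategy.sharding

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False
        # When sharding is on, the sharding optimizer schedules the pipeline
        # itself, so the standalone pipeline pass must not apply.
        if self.use_sharding:
            return False
        return self.user_defined_strategy.pipeline is True
```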
@@ -81,7 +81,10 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
param_name = output_name.strip(
"@GRAD@MERGED"
) if "@MERGED" in output_name else output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
@@ -105,7 +108,11 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if "@MERGED" in input_name:
param_name = input_name.strip("@GRAD@MERGED")
else:
param_name = input_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError(
"Input 'X' of check_finite_and_unscale must"
@@ -169,3 +176,58 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()

# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@staticmethod
def sync_amp_check_nan_inf(block, ring_id):
update_loss_scaling_op_idx = -1

for idx, op in reversed(list(enumerate(block.ops))):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")

# amp is not used
if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_global = block.create_var(
name=inf_var_name + "@GLOBAL_WORLD",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_global},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_global.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()
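sync_amp_check_nan_inf above casts the boolean FoundInfinite flag to int32, takes a max-allreduce over the given ring, and casts back, so every rank observes the global overflow status. A minimal sketch of the same reduction outside the Paddle program IR; the mpi4py usage is purely illustrative and not from this PR:

```python
import numpy as np
from mpi4py import MPI

def sync_found_infinite(local_found_inf: bool) -> bool:
    # Cast the local flag to int32, reduce with max across all ranks,
    # then cast back to bool: the analogue of cast / c_allreduce_max / cast.
    comm = MPI.COMM_WORLD
    local_flag = np.array([1 if local_found_inf else 0], dtype=np.int32)
    global_flag = np.zeros_like(local_flag)
    comm.Allreduce(local_flag, global_flag, op=MPI.MAX)
    return bool(global_flag[0])
```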
@@ -32,6 +32,7 @@ def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
deperated_vars = set()
deperate_op_idx = set()
reversed_x_paramname = []
global_norm_sum_op_idx = -1
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
@@ -41,7 +42,11 @@ def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
for input_name in op.desc.input_arg_names():
if input_name in deperated_vars:
deperate_op = True
param_name = input_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if "@MERGED" in input_name:
param_name = input_name.strip("@GRAD@MERGED")
else:
param_name = input_name.strip("@GRAD")
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
@@ -51,7 +56,8 @@ def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
if deperate_op:
deperate_op_idx.add(idx)
for output_name in op.desc.output_arg_names():
deperated_vars.add(output_name)
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name)

if not deperated_vars:
# got no gradient_clip op
@@ -65,6 +71,7 @@ def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
continue
reversed_inputs = []
if op.type == "sum":
global_norm_sum_op_idx = idx
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
@@ -86,20 +93,20 @@ def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
OP_ROLE_KEY: OpRole.Optimize,
})

# global norm should only be summed within each model parallelism world size when using the global group
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
# global norm should only be summed within each model parallelism world size when using the global group
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})

# the grad sum here should include all and only the params in the current shard
to_check_param = set(reversed_x_paramname)
@@ -115,3 +122,45 @@ def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
return

# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
"""
prune gradient_clip related ops for params that do not belong to the current shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue

if op.type == "sum":
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})

# global norm should only be summed within each model parallelism world size
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})

return
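For the arithmetic behind the allreduce_sum plus scale pair in sync_global_norm: with a communication group that spans pure data-parallel replicas, each parameter's squared gradient is summed pure_dp_degree times, so the result is rescaled by 1/pure_dp_degree before the square root. A hedged sketch, simulating the allreduce with a plain Python sum:

```python
# Illustrative only: one entry per rank in the ring stands in for c_allreduce_sum.
def synced_global_norm(local_sq_sums, pure_dp_degree=1):
    total = sum(local_sq_sums)          # c_allreduce_sum analogue
    if pure_dp_degree > 1:
        total *= 1.0 / pure_dp_degree   # the scale op inserted after the sum
    return total ** 0.5                 # sqrt applied by the clip-by-global-norm pass

# Example: 4 data-parallel replicas each report the same squared sum 9.0;
# without the scale the norm would be sqrt(36) = 6, with it sqrt(9) = 3.
print(synced_global_norm([9.0, 9.0, 9.0, 9.0], pure_dp_degree=4))  # 3.0
```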