
Commit bc43b5d

wangxicoding authored and AnnaTrainingG committed
[hybrid] out data parallel as optimizer sharding parallel (PaddlePaddle#35593)
1 parent eb373c9 commit bc43b5d

16 files changed, +967 -177 lines

paddle/fluid/framework/distributed_strategy.proto

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ message ShardingConfig {
   optional bool pp_allreduce_in_optimize = 10 [ default = false ];
   optional int32 pp_degree = 11 [ default = 1 ];
   optional bool optimize_cast = 12 [ default = false ];
+  // Optimizer sharding. Temporary plans and may be deprecated
+  optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
 }
 
 message HybridConfig {
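
The new _dp_as_optimizer_sharding switch is read through the regular fleet sharding configs. A minimal sketch of how a user might turn it on, assuming the usual hybrid-parallel fleet setup (the degrees below are illustrative, not from this commit):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {
        "sharding_degree": 2,
        "dp_degree": 4,                      # outer data-parallel group
        "_dp_as_optimizer_sharding": True,   # flag added by this commit
    }
    # optimizer = fleet.distributed_optimizer(optimizer, strategy)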

paddle/fluid/operators/amp/check_finite_and_unscale_op.cc

Lines changed: 16 additions & 15 deletions
@@ -26,27 +26,28 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
-                   "check_finite_and_unscale");
-    OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
-                   "check_finite_and_unscale");
-    PADDLE_ENFORCE_EQ(
-        ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
-        platform::errors::InvalidArgument(
-            "The input(X) and output(Out) should have same size in "
-            "Operator(check_finite_and_unscale), size of input(X) is %d "
-            "and size of output(Out) is %d.",
-            ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
-    auto x_dims = ctx->GetInputsDim("X");
-    ctx->SetOutputsDim("Out", x_dims);
+    if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
+      PADDLE_ENFORCE_EQ(
+          ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
+          platform::errors::InvalidArgument(
+              "The input(X) and output(Out) should have same size in "
+              "Operator(check_finite_and_unscale), size of input(X) is %d "
+              "and size of output(Out) is %d.",
+              ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
+      auto x_dims = ctx->GetInputsDim("X");
+      ctx->SetOutputsDim("Out", x_dims);
+    }
     ctx->SetOutputDim("FoundInfinite", {1});
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    auto dtype = framework::proto::VarType::FP32;
+    if (ctx.MultiInputVar("X").size() >= 1) {
+      dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    }
+    return framework::OpKernelType(dtype, ctx.GetPlace());
   }
 };
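
With the relaxed checks, a rank that owns no parameters under optimizer sharding can still carry this op with empty X/Out lists. A hedged sketch of what such a program looks like when built by hand (variable names are invented for illustration):

    import paddle

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        scale = paddle.static.create_global_var(
            shape=[1], value=1024.0, dtype='float32', name='loss_scaling_0')
        found_inf = paddle.static.create_global_var(
            shape=[1], value=0, dtype='bool', name='found_inf_0')
        # With this commit, 'X'/'Out' may be empty lists; previously the
        # OP_INOUT_CHECK on X/Out rejected the op during InferShape.
        main.global_block().append_op(
            type='check_finite_and_unscale',
            inputs={'X': [], 'Scale': scale},
            outputs={'Out': [], 'FoundInfinite': found_inf})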

paddle/fluid/operators/amp/check_finite_and_unscale_op.cu

Lines changed: 2 additions & 0 deletions
@@ -97,6 +97,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
         scale_data, inverse_scale_v, found_inf_data);
 
     size_t xs_size = xs.size();
+    if (xs_size == 0) return;
+
     const auto& cpu_place = platform::CPUPlace();
     // calculate each tensor's start index and copy to device
     auto h_starts_tensor =

paddle/fluid/operators/amp/update_loss_scaling_op.cc

Lines changed: 19 additions & 7 deletions
@@ -26,7 +26,6 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "update_loss_scaling");
     OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"), "Input", "FoundInfinite",
                    "update_loss_scaling");
     OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"), "Input", "PrevLossScaling",
@@ -35,16 +34,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
                    "update_loss_scaling");
     OP_INOUT_CHECK(ctx->HasInput("InBadSteps"), "Input", "InBadSteps",
                    "update_loss_scaling");
-    OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
-                   "update_loss_scaling");
     OP_INOUT_CHECK(ctx->HasOutput("LossScaling"), "Output", "LossScaling",
                    "update_loss_scaling");
     OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"), "Output", "OutGoodSteps",
                    "update_loss_scaling");
     OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"), "Output", "OutBadSteps",
                    "update_loss_scaling");
-    auto x_dims = ctx->GetInputsDim("X");
-    ctx->SetOutputsDim("Out", x_dims);
+
+    if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
+      PADDLE_ENFORCE_EQ(
+          ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
+          platform::errors::InvalidArgument(
+              "The input(X) and output(Out) should have same size in "
+              "Operator(update_loss_scaling), size of input(X) is %d "
+              "and size of output(Out) is %d.",
+              ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
+      auto x_dims = ctx->GetInputsDim("X");
+      ctx->SetOutputsDim("Out", x_dims);
+    }
+
     ctx->SetOutputDim("LossScaling", {1});
     ctx->SetOutputDim("OutGoodSteps", {1});
     ctx->SetOutputDim("OutBadSteps", {1});
@@ -53,8 +61,12 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    auto dtype = framework::proto::VarType::FP32;
+    if (ctx.MultiInputVar("X").size() >= 1) {
+      dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    }
+
+    return framework::OpKernelType(dtype, ctx.GetPlace());
   }
 };
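
Even on a rank whose X list is empty, the op still maintains the scaling state from the synchronized FoundInfinite flag. Roughly, dynamic loss scaling follows this rule (a conceptual Python sketch, not the kernel; parameter names mirror the op's incr/decr attributes):

    def update_loss_scaling(found_inf, scale, good, bad,
                            incr_every_n_steps=1000, decr_every_n_nan_or_inf=2,
                            incr_ratio=2.0, decr_ratio=0.5):
        # conceptual sketch of the dynamic loss-scaling update
        if found_inf:
            good, bad = 0, bad + 1
            if bad == decr_every_n_nan_or_inf:
                scale, bad = max(scale * decr_ratio, 1.0), 0
        else:
            good, bad = good + 1, 0
            if good == incr_every_n_steps:
                scale, good = scale * incr_ratio, 0
        return scale, good, bad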

paddle/fluid/operators/amp/update_loss_scaling_op.cu

Lines changed: 2 additions & 0 deletions
@@ -95,6 +95,8 @@ class LazyZeros<platform::CUDADeviceContext, T> {
                   const std::vector<const framework::Tensor*>& xs,
                   const std::vector<framework::Tensor*>& outs) const {
     size_t xs_size = xs.size();
+    if (xs_size == 0) return;
+
     const auto& cpu_place = platform::CPUPlace();
     // alloc each tensor's start index and copy to device
     auto h_in_starts_mem =

python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py

Lines changed: 2 additions & 11 deletions
@@ -105,7 +105,6 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
         if op.type == "update_loss_scaling":
             update_loss_scaling_op_idx = idx
             inf_var_name = op.desc.input('FoundInfinite')[0]
-            op._rename_input(inf_var_name, inf_var_name + "@sharding")
         if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
             reversed_x = []
             reversed_x_paramname = []
@@ -142,10 +141,6 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
             name=inf_var_name + "@cast_int32",
             shape=inf_var.shape,
             dtype=core.VarDesc.VarType.INT32)
-        inf_var_sharding = block.create_var(
-            name=inf_var_name + "@sharding",
-            shape=inf_var.shape,
-            dtype=inf_var.dtype)
 
         block._insert_op_without_sync(
             update_loss_scaling_op_idx,
@@ -179,10 +174,10 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
             update_loss_scaling_op_idx,
             type='cast',
             inputs={'X': inf_var_int32},
-            outputs={'Out': inf_var_sharding},
+            outputs={'Out': inf_var},
             attrs={
                 "in_dtype": inf_var_int32.dtype,
-                "out_dtype": inf_var_sharding.dtype,
+                "out_dtype": inf_var.dtype,
                 OP_ROLE_KEY: OpRole.Optimize
             })
         update_loss_scaling_op_idx += 1
@@ -210,10 +205,6 @@ def sync_amp_check_nan_inf(block, ring_ids):
             name=inf_var_name + "@cast_int32",
             shape=inf_var.shape,
             dtype=core.VarDesc.VarType.INT32)
-        inf_var_global = block.create_var(
-            name=inf_var_name + "@GLOBAL_WORLD",
-            shape=inf_var.shape,
-            dtype=inf_var.dtype)
         block._insert_op_without_sync(
             update_loss_scaling_op_idx,
             type='cast',
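
Dropping the "@sharding"/"@GLOBAL_WORLD" copies works because the found-infinite flag is now synchronized in place: it is cast to int32, allreduced across the ring(s), and cast back into the original variable. A small numerical sketch of why combining the per-rank flags is enough (values invented for illustration):

    import numpy as np

    # each rank's bool FoundInfinite cast to int32; any nonzero rank means overflow
    rank_flags = np.array([0, 1, 0, 0], dtype=np.int32)
    found_inf_global = bool(rank_flags.max())   # True; a sum > 0 test behaves the same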

python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py

Lines changed: 28 additions & 2 deletions
@@ -39,6 +39,7 @@ def prune_gradient_clip(self, block, shard, ring_ids):
             if not self._is_gradient_clip_op(op):
                 continue
             if op.type == "sum":
+                global_norm_sum_op_idx = idx
                 continue
             deperate_op = False
             for input_name in op.desc.input_arg_names():
@@ -61,7 +62,10 @@ def prune_gradient_clip(self, block, shard, ring_ids):
                 if output_name not in op.desc.input_arg_names():
                     deperated_vars.add(output_name)
 
-        if not deperated_vars:
+        # NOTE(wangxi): If only have 2 sharding, and 1 param.
+        # sharding 0 will not deperated_vars, will return, only
+        # sharding 1 will insert allreduce, then hang.
+        if not deperated_vars and global_norm_sum_op_idx == -1:
             # got no gradient_clip op
             return
 
@@ -71,8 +75,8 @@ def prune_gradient_clip(self, block, shard, ring_ids):
             if idx in deperate_op_idx:
                 block._remove_op(idx, sync=False)
                 continue
-            reversed_inputs = []
             if op.type == "sum":
+                reversed_inputs = []
                 global_norm_sum_op_idx = idx
                 for input_name in op.desc.input_arg_names():
                     if input_name not in deperated_vars:
@@ -82,6 +86,28 @@ def prune_gradient_clip(self, block, shard, ring_ids):
                 assert (len(op.desc.output_arg_names()) == 1)
                 sum_res = op.desc.output_arg_names()[0]
 
+                # NOTE(wangxi): If we have 2 param, but sharding is 4,
+                # then the sum op in some cards will not have input.
+                # So we use fill_constant_op to set `sum_var` to zero,
+                # which does not affect correctness.
+                if len(reversed_inputs) == 0:
+                    sum_var = block.var(sum_res)
+                    namescope = op.attr("op_namescope")
+
+                    block._remove_op(idx, sync=False)
+                    op = block._insert_op_without_sync(
+                        idx,
+                        type='fill_constant',
+                        inputs={},
+                        outputs={'Out': sum_res},
+                        attrs={
+                            'shape': sum_var.shape,
+                            'dtype': sum_var.dtype,
+                            'value': 0.0,
+                            OP_ROLE_KEY: OpRole.Optimize
+                        })
+                    op._set_attr('op_namescope', namescope)
+
                 # allreduce(mp)->allreduce(sharding)->allreduce(pp)
                 idx_offset = 1
                 for ring_id in ring_ids:
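
The zero fill is safe because the per-rank partial sums are combined by an allreduce(sum) before the square root, so a rank with no local gradient-clip inputs simply contributes nothing. A quick numerical check (values invented for illustration):

    import numpy as np

    # per-rank sums of squared gradient norms; ranks 2 and 3 own no params here,
    # so their fill_constant contribution is 0.0
    partial_sq_sums = np.array([9.0, 16.0, 0.0, 0.0], dtype=np.float32)
    global_norm = np.sqrt(partial_sq_sums.sum())   # allreduce(sum) then sqrt -> 5.0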

python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py

Lines changed: 12 additions & 5 deletions
@@ -117,21 +117,28 @@ def crop_output_var_from_op(self, op_idx, var_name):
                     var_name] == []:
                 self._block._remove_var(var_name, sync=False)
 
-    def remove_op(self, op_idx):
+    def remove_op(self, op_idx, reserved_vars=None):
         # update deps
         op = self._block.ops[op_idx]
         for input_name in op.desc.input_arg_names():
+            if reserved_vars is not None and input_name in reserved_vars:
+                continue
             self.crop_input_var_from_op(op_idx, input_name)
         for output_name in op.desc.output_arg_names():
+            if reserved_vars is not None and output_name in reserved_vars:
+                continue
             self.crop_output_var_from_op(op_idx, output_name)
         self._block._remove_op(op_idx, sync=False)
 
     def should_remove_op(self, op_idx):
         op = self._block.ops[op_idx]
-        # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-        # remove check_finite_and_unscale op if its input 'X' is empty
-        if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
-            return True
+
+        # NOTE: At present, it is found that the OP without output is
+        # only send_v2 and partial_send op, which will be used in
+        # all device
+        if len(op.desc.output_arg_names()) == 0:
+            return False
+
         for output_name in op.desc.output_arg_names():
             if output_name not in self._should_removed_var:
                 return False
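
With the new reserved_vars argument, a caller can delete an op from the dependency graph while keeping selected variables alive for later consumers. A hypothetical call (the deps object and variable name are invented for illustration):

    # program_deps is a ProgramDeps instance built for the current block;
    # keep 'found_inf_0' because an op after op_idx still reads it
    program_deps.remove_op(op_idx, reserved_vars={'found_inf_0'})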

python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py

Lines changed: 6 additions & 4 deletions
@@ -24,7 +24,8 @@ def __init__(self, ):
         self.global_params = set([])
         self.worker_idx = -1
         self.worker_num = -1
-        self.global_param2device = {}
+        self.global_param2device = dict()
+        self.device2global_params = dict()
 
     def setup(self, params_grads, worker_idx, worker_num):
         # param names of all devices
@@ -33,8 +34,9 @@ def setup(self, params_grads, worker_idx, worker_num):
         self.worker_idx = worker_idx
         self.worker_num = worker_num
         # global_param2device contains fp32 params and fp16 params
-        self.global_param2device = self._split_params(params_grads, worker_idx,
-                                                      worker_num)
+        # device2global_params only contains fp32 params
+        self.global_param2device, self.device2global_params \
+            = self._split_params(params_grads, worker_idx, worker_num)
 
     def has_param(self, var_name):
         return var_name in self.global_param2device and \
@@ -64,7 +66,7 @@ def _split_params(self, params_grads, worker_idx, worker_num):
                 device2params[device_idx].append(param_name)
                 param2device[param_name] = device_idx
                 mem_accu += mem
-        return param2device
+        return param2device, device2params
 
     def _var_device_id(self, var_name):
         if var_name in self.global_param2device:
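
The first mapping records which device owns each parameter, while the new second mapping groups the (fp32) parameter names by owning device. Hypothetical contents for four parameters split across two workers (names invented for illustration):

    global_param2device = {'fc_0.w_0': 0, 'fc_0.b_0': 0, 'fc_1.w_0': 1, 'fc_1.b_0': 1}
    device2global_params = {0: ['fc_0.w_0', 'fc_0.b_0'], 1: ['fc_1.w_0', 'fc_1.b_0']}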
