2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -43,6 +43,8 @@ message ShardingConfig {
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plan; may be deprecated.
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
Contributor:
Why not add a new config called `stage` and allow two values for now: stage=1 and stage=3?

Contributor (Author):
update later

}

message HybridConfig {
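For context, a minimal sketch of how this new flag might be toggled through fleet's sharding configs (the key comes from this diff; the other values are hypothetical):

```python
# A minimal sketch, assuming the fleet DistributedStrategy sharding_configs
# dict accepts the new key introduced in this proto change.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 2,                 # hypothetical value
    "_dp_as_optimizer_sharding": True,    # the new flag from this diff
}
```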
31 changes: 16 additions & 15 deletions paddle/fluid/operators/amp/check_finite_and_unscale_op.cc
@@ -26,27 +26,28 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
"check_finite_and_unscale");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"check_finite_and_unscale");
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {

Reviewer:
Why does this op need to support having no input/output?

Contributor (Author):
[image]

PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
}
ctx->SetOutputDim("FoundInfinite", {1});
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
}
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};

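A rough Python sketch of the guard added above (function and names hypothetical), illustrating the empty-input case that presumably arises on ranks that own no parameters after sharding:

```python
# A minimal sketch, assuming the intent is: when a rank owns no parameters,
# X/Out are empty and only the FoundInfinite flag shape is set.
def infer_shape(x_dims, out_names):
    if x_dims or out_names:
        assert len(x_dims) == len(out_names), (
            "input(X) and output(Out) must have the same size")
        out_dims = list(x_dims)      # Out mirrors X element-wise
    else:
        out_dims = []                # empty-input rank: nothing to unscale
    found_infinite_dim = [1]         # always a scalar flag
    return out_dims, found_infinite_dim
```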
2 changes: 2 additions & 0 deletions paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
@@ -97,6 +97,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
scale_data, inverse_scale_v, found_inf_data);

size_t xs_size = xs.size();
if (xs_size == 0) return;

const auto& cpu_place = platform::CPUPlace();
// calculate each tensor's start index and copy to device
auto h_starts_tensor =
26 changes: 19 additions & 7 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cc
@@ -26,7 +26,6 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"), "Input", "FoundInfinite",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"), "Input", "PrevLossScaling",
@@ -35,16 +34,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("InBadSteps"), "Input", "InBadSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("LossScaling"), "Output", "LossScaling",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"), "Output", "OutGoodSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"), "Output", "OutBadSteps",
"update_loss_scaling");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);

if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(update_loss_scaling), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
}

ctx->SetOutputDim("LossScaling", {1});
ctx->SetOutputDim("OutGoodSteps", {1});
ctx->SetOutputDim("OutBadSteps", {1});
@@ -53,8 +61,12 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
}

return framework::OpKernelType(dtype, ctx.GetPlace());
}
};

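The kernel-type change above falls back to FP32 when a rank has no X inputs; a small illustrative sketch (names hypothetical):

```python
# A minimal sketch of the fallback: infer the kernel dtype from X when X is
# non-empty, otherwise default to FP32 so the op still runs on empty ranks.
def expected_kernel_dtype(x_vars, infer_dtype):
    if len(x_vars) >= 1:
        return infer_dtype(x_vars)   # normal path: dtype comes from X
    return "float32"                 # empty-input rank: safe default
```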
2 changes: 2 additions & 0 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cu
@@ -95,6 +95,8 @@ class LazyZeros<platform::CUDADeviceContext, T> {
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
size_t xs_size = xs.size();
if (xs_size == 0) return;

const auto& cpu_place = platform::CPUPlace();
// alloc each tensor's start index and copy to device
auto h_in_starts_mem =
@@ -105,7 +105,6 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
reversed_x_paramname = []
@@ -142,10 +141,6 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_sharding = block.create_var(
name=inf_var_name + "@sharding",
shape=inf_var.shape,
dtype=inf_var.dtype)

block._insert_op_without_sync(
update_loss_scaling_op_idx,
@@ -179,10 +174,10 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding},
outputs={'Out': inf_var},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_sharding.dtype,
"out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
@@ -210,10 +205,6 @@ def sync_amp_check_nan_inf(block, ring_ids):
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_global = block.create_var(
name=inf_var_name + "@GLOBAL_WORLD",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
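With the `@sharding` / `@GLOBAL_WORLD` temporaries removed, the cast now writes straight back into the original flag variable. A conceptual sketch of the resulting pattern (not the actual pass code; names and the allreduce helper are assumptions):

```python
# Conceptual sketch: cast the FoundInfinite flag to int32, allreduce it over
# each communication ring, then cast back into the same variable instead of
# into a separate "@sharding"/"@GLOBAL_WORLD" copy.
def sync_found_infinite(inf_flag, ring_ids, allreduce):
    flag_i32 = int(inf_flag)           # cast bool -> int32
    for ring_id in ring_ids:           # e.g. mp -> sharding -> pp rings
        flag_i32 = allreduce(flag_i32, ring_id)
    return bool(flag_i32)              # cast int32 -> bool, back into inf_var
```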
@@ -39,6 +39,7 @@ def prune_gradient_clip(self, block, shard, ring_ids):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
global_norm_sum_op_idx = idx
continue
deperate_op = False
for input_name in op.desc.input_arg_names():
@@ -61,7 +62,10 @@ def prune_gradient_clip(self, block, shard, ring_ids):
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name)

if not deperated_vars:
# NOTE(wangxi): With only 2 sharding ranks and 1 param, sharding rank 0
# has no deperated_vars and would return here, while sharding rank 1
# inserts an allreduce, which then hangs.
if not deperated_vars and global_norm_sum_op_idx == -1:
# got no gradient_clip op
return

@@ -71,8 +75,8 @@ def prune_gradient_clip(self, block, shard, ring_ids):
if idx in deperate_op_idx:
block._remove_op(idx, sync=False)
continue
reversed_inputs = []
if op.type == "sum":
reversed_inputs = []
global_norm_sum_op_idx = idx
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
@@ -82,6 +86,28 @@ def prune_gradient_clip(self, block, shard, ring_ids):
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]

# NOTE(wangxi): If we have 2 params but the sharding degree is 4,
# the sum op on some cards will have no input. So we use a
# fill_constant op to set `sum_var` to zero, which does not
# affect correctness.
if len(reversed_inputs) == 0:
sum_var = block.var(sum_res)
namescope = op.attr("op_namescope")

block._remove_op(idx, sync=False)
op = block._insert_op_without_sync(
idx,
type='fill_constant',
inputs={},
outputs={'Out': sum_res},
attrs={
'shape': sum_var.shape,
'dtype': sum_var.dtype,
'value': 0.0,
OP_ROLE_KEY: OpRole.Optimize
})
op._set_attr('op_namescope', namescope)

# allreduce(mp)->allreduce(sharding)->allreduce(pp)
idx_offset = 1
for ring_id in ring_ids:
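To illustrate the scenario from the NOTE(wangxi) comment above (the round-robin split below is purely illustrative, not the actual assignment strategy):

```python
# Purely illustrative: with 2 params and sharding degree 4, two ranks end
# up owning nothing, so their global-norm "sum" op would have no inputs.
params = ["w0", "w1"]
sharding_degree = 4
owned = {rank: [p for i, p in enumerate(params) if i % sharding_degree == rank]
         for rank in range(sharding_degree)}
print(owned)  # {0: ['w0'], 1: ['w1'], 2: [], 3: []}
```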
17 changes: 12 additions & 5 deletions python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
@@ -117,21 +117,28 @@ def crop_output_var_from_op(self, op_idx, var_name):
var_name] == []:
self._block._remove_var(var_name, sync=False)

def remove_op(self, op_idx):
def remove_op(self, op_idx, reserved_vars=None):
# update deps
op = self._block.ops[op_idx]
for input_name in op.desc.input_arg_names():
if reserved_vars is not None and input_name in reserved_vars:
continue
self.crop_input_var_from_op(op_idx, input_name)
for output_name in op.desc.output_arg_names():
if reserved_vars is not None and output_name in reserved_vars:
continue
self.crop_output_var_from_op(op_idx, output_name)
self._block._remove_op(op_idx, sync=False)

def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True

# NOTE: At present, the only ops found without any output are send_v2
# and partial_send, which are used on all devices
if len(op.desc.output_arg_names()) == 0:
return False

for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
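A hypothetical call site for the new `reserved_vars` parameter of `remove_op` above (the variable name is an assumption; the point is that listed variables keep their dependency entries even when the op is removed):

```python
# Hypothetical usage: drop the op but keep inf_var alive in the dependency
# map, since other ops (e.g. update_loss_scaling) still reference it.
pruner.remove_op(op_idx, reserved_vars=[inf_var_name])
```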
@@ -24,7 +24,8 @@ def __init__(self, ):
self.global_params = set([])
self.worker_idx = -1
self.worker_num = -1
self.global_param2device = {}
self.global_param2device = dict()
self.device2global_params = dict()

def setup(self, params_grads, worker_idx, worker_num):
# param names of all devices
@@ -33,8 +34,9 @@ def setup(self, params_grads, worker_idx, worker_num):
self.worker_idx = worker_idx
self.worker_num = worker_num
# global_param2device contains fp32 params and fp16 params
self.global_param2device = self._split_params(params_grads, worker_idx,
worker_num)
# device2global_params only contains fp32 params
self.global_param2device, self.device2global_params \
= self._split_params(params_grads, worker_idx, worker_num)

def has_param(self, var_name):
return var_name in self.global_param2device and \
@@ -64,7 +66,7 @@ def _split_params(self, params_grads, worker_idx, worker_num):
device2params[device_idx].append(param_name)
param2device[param_name] = device_idx
mem_accu += mem
return param2device
return param2device, device2params

def _var_device_id(self, var_name):
if var_name in self.global_param2device:
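A simplified sketch of what `_split_params` now returns (the least-loaded placement below is an assumption, not the exact memory-accumulation logic used here):

```python
# Simplified sketch: assign each param to a worker and return both the
# param -> device map and the device -> params map, as in the change above.
def split_params(param_sizes, worker_num):
    param2device, device2params = {}, {i: [] for i in range(worker_num)}
    loads = [0] * worker_num
    for name, size in param_sizes.items():
        device_idx = loads.index(min(loads))   # least-loaded rank (assumption)
        param2device[name] = device_idx
        device2params[device_idx].append(name)
        loads[device_idx] += size
    return param2device, device2params
```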