4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -77,7 +77,7 @@ repos:

# | python/paddle/de.+

# | python/paddle/distributed/a.+
| python/paddle/distributed/a.+

# | python/paddle/distributed/[b-e].+

@@ -133,7 +133,7 @@ repos:

| python/paddle/de.+

| python/paddle/distributed/a.+
# | python/paddle/distributed/a.+

| python/paddle/distributed/[b-e].+

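The Python hunks below all apply the same mechanical rewrite, presumably produced by the pre-commit formatting hook that now covers python/paddle/distributed/a.+ per the config change above: each reformatted assert keeps its condition on the assert line and wraps the message, rather than the condition, in parentheses. A minimal runnable sketch of the two layouts, reusing one message from api.py inside hypothetical helper functions that are not part of this diff:

def check_sharding_degree_old_style(dim_size: int, sharding_degree: int) -> None:
    # Layout before this change: the condition is parenthesized and the
    # message trails the closing parenthesis.
    assert (
        dim_size == sharding_degree
    ), "The sharding degree of all parameters must be equal currently."


def check_sharding_degree_new_style(dim_size: int, sharding_degree: int) -> None:
    # Layout after this change: the condition stays on the assert line and
    # the message is parenthesized instead.
    assert dim_size == sharding_degree, (
        "The sharding degree of all parameters must be equal currently."
    )


# Both layouts are equivalent at runtime; only the formatting differs.
check_sharding_degree_old_style(8, 8)
check_sharding_degree_new_style(8, 8)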
1 change: 0 additions & 1 deletion python/paddle/_paddle_docs.py
@@ -74,7 +74,6 @@ def _parse_function_signature(
if func_def.args.defaults and len(func_def.args.defaults) > (
len(func_def.args.args) - len(func_def.args.defaults)
):

idx = count - (
len(func_def.args.args) - len(func_def.args.defaults)
)
104 changes: 57 additions & 47 deletions python/paddle/distributed/auto_parallel/api.py
@@ -298,18 +298,18 @@ def shard_tensor(
stop_gradient = getattr(data, "stop_gradient", True)

if paddle.framework.in_pir_mode():
assert isinstance(
data, (type(None), pir.Value)
), "input tensor is not pir value."
assert (
data.is_dense_tensor_type()
), "shard_tensor() input data only supported dense tensor type right."
assert isinstance(data, (type(None), pir.Value)), (
"input tensor is not pir value."
)
assert data.is_dense_tensor_type(), (
"shard_tensor() input data only supported dense tensor type right."
)
tensor = data
else:
if isinstance(data, EagerParamBase) and not data._is_initialized():
assert (
data._init_func is not None
), "Get an uninitialized param with an unregistered init_func."
assert data._init_func is not None, (
"Get an uninitialized param with an unregistered init_func."
)
tensor = data
elif isinstance(data, paddle.Tensor) and dtype is None:
# if place is not equal, it is handled in paddle.Tensor()
@@ -620,7 +620,9 @@ def forward(
)
assert check_placements_equal(
global_placements, dist_tensor.placements
), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
), (
f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
)
local_shape = _cal_local_shape(
dist_tensor.shape, global_mesh, global_placements
)
@@ -890,9 +892,9 @@ def reshard(
elif in_pir_mode():
return paddle._C_ops.reshard(dist_tensor, mesh, placements)
else:
assert isinstance(
dist_tensor, Variable
), f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
assert isinstance(dist_tensor, Variable), (
f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
)
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
main_program = default_main_program()
default_dist_ctx = get_default_distributed_context()
@@ -1113,12 +1115,14 @@ def is_dist_tensor(tensor) -> bool:

class _ShardOptimizer(Optimizer):
def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
assert (
optimizer is not None
), "The argument `optimizer` cannot be empty."
assert optimizer is not None, (
"The argument `optimizer` cannot be empty."
)
assert isinstance(
optimizer, (paddle.optimizer.AdamW, paddle.optimizer.SGD)
), "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
), (
"`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
)

# self.target_block = (
# paddle.base.framework.default_main_program().global_block()
@@ -1146,7 +1150,9 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
assert isinstance(
self._shard_fn,
(_ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3),
), "shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3"
), (
"shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3"
)

if isinstance(
self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
@@ -1219,7 +1225,9 @@ def _set_and_check_sharding_prop_from_param(self):
else:
assert (
mesh.dim_size(self._sharding_axis) == self._sharding_degree
), "The sharding degree of all parameters must be equal currently."
), (
"The sharding degree of all parameters must be equal currently."
)

def _shard_accumulator(self, param):
# Note (luchang): Some models may have parameters whose first dimension is 1,
@@ -1988,9 +1996,9 @@ def shard_master_weight(
)
if isinstance(master_weight, pir.Value):
data_op = master_weight.get_defining_op()
assert (
data_op.name() == "pd_op.data"
), "The master weight must be a result of data op."
assert data_op.name() == "pd_op.data", (
"The master weight must be a result of data op."
)
dim_map, partial_status = to_dim_map(
placements, len(master_weight.shape)
)
@@ -3254,9 +3262,9 @@ def state_dict(
suffix = _get_suffix(param, fused_param)
if suffix is not None:
value = dist_state_dict[param]
assert (
value.is_dist()
), f"key {param} value:{value} is not a dist tensor."
assert value.is_dist(), (
f"key {param} value:{value} is not a dist tensor."
)
mesh = value.process_mesh
placements = value.placements
if "_pow_acc" in suffix:
@@ -3328,12 +3336,12 @@ def build_distributed_tensor(local_tensor, dist_attr):
)
if not isinstance(local_tensor, paddle.Tensor):
local_tensor = paddle.Tensor(local_tensor)
assert isinstance(
local_tensor, paddle.Tensor
), f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor."
assert len(local_tensor.shape) == len(
dist_attr["dims_mapping"]
), f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
assert isinstance(local_tensor, paddle.Tensor), (
f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor."
)
assert len(local_tensor.shape) == len(dist_attr["dims_mapping"]), (
f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
)
global_shape = local_tensor.shape
mesh = ProcessMesh(
np.array(dist_attr["process_group"]).reshape(
@@ -3343,18 +3351,18 @@
)
placements = to_placements(dist_attr["dims_mapping"], mesh)
dist_tensor = dtensor_from_local(local_tensor, mesh, placements)
assert (
dist_tensor._local_value().shape == local_tensor.shape
), f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
assert dist_tensor._local_value().shape == local_tensor.shape, (
f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
)
paddle.assign(local_tensor, dist_tensor._local_value())
return dist_tensor

global_state_dict = {}
with paddle.base.dygraph.guard():
for var_name, tensor in local_state_dict.items():
assert (
var_name in dist_attrs
), f"var {var_name} not in dist attrs:{dist_attrs}."
assert var_name in dist_attrs, (
f"var {var_name} not in dist attrs:{dist_attrs}."
)
global_state_dict[var_name] = build_distributed_tensor(
tensor, dist_attrs[var_name]
)
@@ -3386,7 +3394,9 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
k
].process_mesh or check_placements_equal(
v.placements, cur_v.placements
), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
), (
f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
)
param_name = (
self._structured_to_parameter_name[k]
if k in self._structured_to_parameter_name
@@ -3472,9 +3482,9 @@ def _get_shard_stage1_optimizer(self):
):
optimizer = optimizer._optimizer

assert isinstance(
optimizer, ShardingOptimizerStage1
), "The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled."
assert isinstance(optimizer, ShardingOptimizerStage1), (
"The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled."
)

return optimizer

@@ -3485,9 +3495,9 @@ def _convert_state_dict_tensor_fusion(self, state_dict, optimizer_function):
else False
)

assert (
enable_tensor_fusion
), "Can only convert state_dict when tensor fusion is enabled."
assert enable_tensor_fusion, (
"Can only convert state_dict when tensor fusion is enabled."
)
optimizer = self._get_shard_stage1_optimizer()
assert optimizer is not None, "The optimizer should not be None."

@@ -3690,9 +3700,9 @@ def to_static(
# Deduce sharding degree for static
# Note: Because limitation of architecture, we need to ensure that
# all parameters are sharded by the same mesh axis
assert (
sharding_degree is not None
), "Sharding degree can not be None."
assert sharding_degree is not None, (
"Sharding degree can not be None."
)

if isinstance(shard_fn, ShardingStage1):
strategy.sharding.enable = True
6 changes: 3 additions & 3 deletions python/paddle/distributed/auto_parallel/auto_dp_utils.py
@@ -21,9 +21,9 @@

def _fake_replicate_grad_to_partial(grad, partial_axis):
new_placements = grad.placements
assert (
new_placements[partial_axis] == dist.Replicate()
), "when reshard fake replicated grad to partial, the partial axis of grad should be Replicate"
assert new_placements[partial_axis] == dist.Replicate(), (
"when reshard fake replicated grad to partial, the partial axis of grad should be Replicate"
)

new_placements[partial_axis] = dist.Partial(dist.ReduceType.kRedSum)

40 changes: 21 additions & 19 deletions python/paddle/distributed/auto_parallel/high_level_api.py
@@ -34,9 +34,9 @@ def __init__(self):

def cost_model(matched_programs, device_num, node_num):
# TODO(jeff41404): multi-node will be supported later
assert (
node_num == 1
), "we only support single node now, multi-node will be supported later"
assert node_num == 1, (
"we only support single node now, multi-node will be supported later"
)

# TODO(jeff41404): will evaluate the best combination of parallel strategies
# based on cost_model and return global_mesh, currently using pre-defined parallel strategy
@@ -224,7 +224,9 @@ def record_program_ops_post_hook(layer, inputs, outputs):
assert (
layer._op_recorder.start >= 0
and layer._op_recorder.is_valid is True
), f"{layer._full_name} has not recorded the start of the corresponding ops before"
), (
f"{layer._full_name} has not recorded the start of the corresponding ops before"
)
end = len(default_main_program().global_block().ops)
# some layers, such as rotary_embedding, will not add new ops to program
# assert end > layer._op_recorder.start, f"{layer._full_name} has not added new ops to the program"
@@ -754,19 +756,19 @@ def to_distributed(
for pattern_name, matched_patterns in results.items():
# process one pattern
pattern_ops_dist_infos = get_pattern(pattern_name).ops_dist_infos
assert (
pattern_ops_dist_infos is not None
), f"{pattern_name} does not contain ops_dist_infos, cannot reshard, please check"
assert pattern_ops_dist_infos is not None, (
f"{pattern_name} does not contain ops_dist_infos, cannot reshard, please check"
)
processed_patterns = []
for matched_pattern in matched_patterns:
# convert pattern_ops_dist_infos to program_ops_dist_infos
program_ops_dist_infos = {}
for pattern_ops_id, op_dist_info in pattern_ops_dist_infos.items():
program_ops_id = []
for pattern_op_id in pattern_ops_id:
assert (
pattern_op_id in matched_pattern.keys()
), f"please check ops_dist_infos of {pattern_name}, {pattern_op_id} not in matched_pattern: {matched_pattern.keys()}"
assert pattern_op_id in matched_pattern.keys(), (
f"please check ops_dist_infos of {pattern_name}, {pattern_op_id} not in matched_pattern: {matched_pattern.keys()}"
)
program_op_id = matched_pattern[pattern_op_id]
program_ops_id.append(program_op_id)
program_ops_dist_infos[tuple(program_ops_id)] = op_dist_info
@@ -789,9 +791,9 @@ def to_distributed(
if with_mp:
num_hidden_layers = len(matched_programs[DECODER_LAYER_NAME])
for pattern_name, processed_patterns in matched_programs.items():
assert (
len(processed_patterns) == num_hidden_layers
), "transformer patterns matched are incomplete"
assert len(processed_patterns) == num_hidden_layers, (
"transformer patterns matched are incomplete"
)
for idx, processed_pattern in enumerate(processed_patterns):
local_mesh = mesh
if with_pp:
@@ -801,9 +803,9 @@ def to_distributed(
local_mesh = mesh.get_mesh_with_dim("pp", pp_stage_id)

for program_ops_id, dist_infos in processed_pattern.items():
assert (
program_ops_id in ops_id_to_layer.keys()
), f"program_ops: {program_ops_id} is not corresponding to a dynamic layer"
assert program_ops_id in ops_id_to_layer.keys(), (
f"program_ops: {program_ops_id} is not corresponding to a dynamic layer"
)
dynamic_layer = ops_id_to_layer[program_ops_id]
mesh_num_dims = len(local_mesh.shape)
sharding_info = dist_infos.get_dist_info(mesh_num_dims)
@@ -832,9 +834,9 @@ def to_distributed(

if decoder_layers is not None:
num_decoder_blocks = len(decoder_layers)
assert (
num_decoder_blocks == num_hidden_layers
), f"decoder pattern layers matched are incomplete, num_decoder_blocks: {num_decoder_blocks} should be equal to num_hidden_layers: {num_hidden_layers}"
assert num_decoder_blocks == num_hidden_layers, (
f"decoder pattern layers matched are incomplete, num_decoder_blocks: {num_decoder_blocks} should be equal to num_hidden_layers: {num_hidden_layers}"
)

pp_degree = mesh.get_dim_size("pp")
num_blocks_per_stage = num_decoder_blocks // pp_degree