From 7cc198086355eb4ad9b0ff9595bf92b1a371e18f Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Mon, 18 Aug 2025 21:47:59 +0800 Subject: [PATCH 1/2] [CodeStyle] `black -> ruff format` migration - part 26 --- .pre-commit-config.yaml | 4 +- python/paddle/_paddle_docs.py | 1 - .../paddle/distributed/auto_parallel/api.py | 104 +++++----- .../auto_parallel/auto_dp_utils.py | 6 +- .../auto_parallel/high_level_api.py | 40 ++-- .../distributed/auto_parallel/interface.py | 36 ++-- .../intermediate/context_parallel.py | 66 +++--- .../intermediate/parallel_base.py | 6 +- .../auto_parallel/intermediate/parallelize.py | 36 ++-- .../intermediate/pipeline_parallel.py | 48 ++--- .../intermediate/sharded_data_parallel.py | 12 +- .../intermediate/tensor_parallel.py | 24 +-- .../distributed/auto_parallel/local_layer.py | 12 +- .../distributed/auto_parallel/local_map.py | 18 +- .../distributed/auto_parallel/moe_utils.py | 24 +-- .../operators/dist_flash_attn.py | 6 +- .../auto_parallel/pipelining/_backward.py | 12 +- .../auto_parallel/pipelining/microbatch.py | 24 +-- .../auto_parallel/pipelining/schedules.py | 18 +- .../auto_parallel/pipelining/stage.py | 116 ++++++----- .../auto_parallel/pipelining/utils.py | 6 +- .../auto_parallel/placement_type.py | 6 +- .../distributed/auto_parallel/process_mesh.py | 36 ++-- .../distributed/auto_parallel/random.py | 24 +-- .../distributed/auto_parallel/sharding.py | 82 ++++---- .../auto_parallel/static/auto_align_tool.py | 6 +- .../auto_parallel/static/cluster_v2.py | 24 +-- .../auto_parallel/static/completion.py | 76 +++---- .../auto_parallel/static/cost/base_cost.py | 18 +- .../static/cost/estimate_cost.py | 8 +- .../static/cost/op_runtime_cost.py | 16 +- .../auto_parallel/static/cost_model.py | 6 +- .../auto_parallel/static/dist_context.py | 60 +++--- .../auto_parallel/static/dist_loader.py | 6 +- .../auto_parallel/static/dist_op.py | 20 +- .../auto_parallel/static/dist_tensor.py | 12 +- .../auto_parallel/static/engine.py | 118 +++++------ .../auto_parallel/static/helper.py | 36 ++-- .../auto_parallel/static/mapper.py | 18 +- .../auto_parallel/static/operators/common.py | 24 +-- .../dist_check_finite_and_unscale.py | 30 +-- .../static/operators/dist_concat.py | 6 +- .../static/operators/dist_cross_entropy.py | 124 +++++------ .../static/operators/dist_default.py | 24 +-- .../static/operators/dist_dropout.py | 14 +- .../static/operators/dist_eltwise.py | 12 +- .../static/operators/dist_embedding.py | 92 ++++----- .../static/operators/dist_flash_attn.py | 6 +- .../static/operators/dist_fused_attention.py | 12 +- .../operators/dist_fused_dropout_add.py | 6 +- .../operators/dist_fused_feedforward.py | 12 +- .../static/operators/dist_matmul.py | 116 +++++------ .../static/operators/dist_reduce_sum_p.py | 12 +- .../static/operators/dist_reshape.py | 24 +-- .../static/operators/dist_split.py | 24 +-- .../static/operators/dist_tile.py | 6 +- .../static/operators/dist_transpose.py | 6 +- .../operators/dist_update_loss_scaling.py | 70 +++---- .../auto_parallel/static/parallelizer.py | 12 +- .../auto_parallel/static/partitioner.py | 24 +-- .../auto_parallel/static/pir_pass.py | 96 +++++---- .../auto_parallel/static/planner.py | 45 ++-- .../auto_parallel/static/process_group.py | 12 +- .../auto_parallel/static/process_mesh_v2.py | 24 +-- .../auto_parallel/static/reshard.py | 42 ++-- .../reshard_funcs/nd_mesh_reshard_func.py | 6 +- .../reshard_funcs/p_to_r_reshard_func.py | 6 +- .../reshard_funcs/p_to_s_reshard_func.py | 6 +- .../reshard_funcs/r_to_s_reshard_func.py | 
6 +- .../reshard_funcs/s_to_r_reshard_func.py | 6 +- .../reshard_funcs/same_status_reshard_func.py | 6 +- .../auto_parallel/static/tuner/algorithms.py | 6 +- .../static/tuner/optimization_tuner.py | 6 +- .../static/tuner/rule_based_tuner.py | 122 ++++++----- .../distributed/auto_parallel/static/utils.py | 192 +++++++++--------- .../paddle/distributed/auto_tuner/recorder.py | 6 +- .../paddle/distributed/auto_tuner/search.py | 12 +- python/paddle/distributed/auto_tuner/utils.py | 4 +- 78 files changed, 1240 insertions(+), 1210 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 64aa9927963414..8aa4c1c4fa189c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,7 +77,7 @@ repos: # | python/paddle/de.+ - # | python/paddle/distributed/a.+ + | python/paddle/distributed/a.+ # | python/paddle/distributed/[b-e].+ @@ -133,7 +133,7 @@ repos: | python/paddle/de.+ - | python/paddle/distributed/a.+ + # | python/paddle/distributed/a.+ | python/paddle/distributed/[b-e].+ diff --git a/python/paddle/_paddle_docs.py b/python/paddle/_paddle_docs.py index a5b76559dce62d..9c54b16d825b8f 100644 --- a/python/paddle/_paddle_docs.py +++ b/python/paddle/_paddle_docs.py @@ -74,7 +74,6 @@ def _parse_function_signature( if func_def.args.defaults and len(func_def.args.defaults) > ( len(func_def.args.args) - len(func_def.args.defaults) ): - idx = count - ( len(func_def.args.args) - len(func_def.args.defaults) ) diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 93effbdde8bad2..f9a96f2d0dc1ca 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -298,18 +298,18 @@ def shard_tensor( stop_gradient = getattr(data, "stop_gradient", True) if paddle.framework.in_pir_mode(): - assert isinstance( - data, (type(None), pir.Value) - ), "input tensor is not pir value." - assert ( - data.is_dense_tensor_type() - ), "shard_tensor() input data only supported dense tensor type right." + assert isinstance(data, (type(None), pir.Value)), ( + "input tensor is not pir value." + ) + assert data.is_dense_tensor_type(), ( + "shard_tensor() input data only supported dense tensor type right." + ) tensor = data else: if isinstance(data, EagerParamBase) and not data._is_initialized(): - assert ( - data._init_func is not None - ), "Get an uninitialized param with an unregistered init_func." + assert data._init_func is not None, ( + "Get an uninitialized param with an unregistered init_func." + ) tensor = data elif isinstance(data, paddle.Tensor) and dtype is None: # if place is not equal, it is handled in paddle.Tensor() @@ -620,7 +620,9 @@ def forward( ) assert check_placements_equal( global_placements, dist_tensor.placements - ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})." + ), ( + f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})." 
+ ) local_shape = _cal_local_shape( dist_tensor.shape, global_mesh, global_placements ) @@ -890,9 +892,9 @@ def reshard( elif in_pir_mode(): return paddle._C_ops.reshard(dist_tensor, mesh, placements) else: - assert isinstance( - dist_tensor, Variable - ), f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]" + assert isinstance(dist_tensor, Variable), ( + f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]" + ) sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim) main_program = default_main_program() default_dist_ctx = get_default_distributed_context() @@ -1113,12 +1115,14 @@ def is_dist_tensor(tensor) -> bool: class _ShardOptimizer(Optimizer): def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1): - assert ( - optimizer is not None - ), "The argument `optimizer` cannot be empty." + assert optimizer is not None, ( + "The argument `optimizer` cannot be empty." + ) assert isinstance( optimizer, (paddle.optimizer.AdamW, paddle.optimizer.SGD) - ), "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now." + ), ( + "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now." + ) # self.target_block = ( # paddle.base.framework.default_main_program().global_block() @@ -1146,7 +1150,9 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1): assert isinstance( self._shard_fn, (_ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3), - ), "shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3" + ), ( + "shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3" + ) if isinstance( self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3) @@ -1219,7 +1225,9 @@ def _set_and_check_sharding_prop_from_param(self): else: assert ( mesh.dim_size(self._sharding_axis) == self._sharding_degree - ), "The sharding degree of all parameters must be equal currently." + ), ( + "The sharding degree of all parameters must be equal currently." + ) def _shard_accumulator(self, param): # Note (luchang): Some models may have parameters whose first dimension is 1, @@ -1988,9 +1996,9 @@ def shard_master_weight( ) if isinstance(master_weight, pir.Value): data_op = master_weight.get_defining_op() - assert ( - data_op.name() == "pd_op.data" - ), "The master weight must be a result of data op." + assert data_op.name() == "pd_op.data", ( + "The master weight must be a result of data op." + ) dim_map, partial_status = to_dim_map( placements, len(master_weight.shape) ) @@ -3254,9 +3262,9 @@ def state_dict( suffix = _get_suffix(param, fused_param) if suffix is not None: value = dist_state_dict[param] - assert ( - value.is_dist() - ), f"key {param} value:{value} is not a dist tensor." + assert value.is_dist(), ( + f"key {param} value:{value} is not a dist tensor." + ) mesh = value.process_mesh placements = value.placements if "_pow_acc" in suffix: @@ -3328,12 +3336,12 @@ def build_distributed_tensor(local_tensor, dist_attr): ) if not isinstance(local_tensor, paddle.Tensor): local_tensor = paddle.Tensor(local_tensor) - assert isinstance( - local_tensor, paddle.Tensor - ), f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor." - assert len(local_tensor.shape) == len( - dist_attr["dims_mapping"] - ), f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}." 
+ assert isinstance(local_tensor, paddle.Tensor), ( + f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor." + ) + assert len(local_tensor.shape) == len(dist_attr["dims_mapping"]), ( + f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}." + ) global_shape = local_tensor.shape mesh = ProcessMesh( np.array(dist_attr["process_group"]).reshape( @@ -3343,18 +3351,18 @@ def build_distributed_tensor(local_tensor, dist_attr): ) placements = to_placements(dist_attr["dims_mapping"], mesh) dist_tensor = dtensor_from_local(local_tensor, mesh, placements) - assert ( - dist_tensor._local_value().shape == local_tensor.shape - ), f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}" + assert dist_tensor._local_value().shape == local_tensor.shape, ( + f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}" + ) paddle.assign(local_tensor, dist_tensor._local_value()) return dist_tensor global_state_dict = {} with paddle.base.dygraph.guard(): for var_name, tensor in local_state_dict.items(): - assert ( - var_name in dist_attrs - ), f"var {var_name} not in dist attrs:{dist_attrs}." + assert var_name in dist_attrs, ( + f"var {var_name} not in dist attrs:{dist_attrs}." + ) global_state_dict[var_name] = build_distributed_tensor( tensor, dist_attrs[var_name] ) @@ -3386,7 +3394,9 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None: k ].process_mesh or check_placements_equal( v.placements, cur_v.placements - ), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match" + ), ( + f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match" + ) param_name = ( self._structured_to_parameter_name[k] if k in self._structured_to_parameter_name @@ -3472,9 +3482,9 @@ def _get_shard_stage1_optimizer(self): ): optimizer = optimizer._optimizer - assert isinstance( - optimizer, ShardingOptimizerStage1 - ), "The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled." + assert isinstance(optimizer, ShardingOptimizerStage1), ( + "The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled." + ) return optimizer @@ -3485,9 +3495,9 @@ def _convert_state_dict_tensor_fusion(self, state_dict, optimizer_function): else False ) - assert ( - enable_tensor_fusion - ), "Can only convert state_dict when tensor fusion is enabled." + assert enable_tensor_fusion, ( + "Can only convert state_dict when tensor fusion is enabled." + ) optimizer = self._get_shard_stage1_optimizer() assert optimizer is not None, "The optimizer should not be None." @@ -3690,9 +3700,9 @@ def to_static( # Deduce sharding degree for static # Note: Because limitation of architecture, we need to ensure that # all parameters are sharded by the same mesh axis - assert ( - sharding_degree is not None - ), "Sharding degree can not be None." + assert sharding_degree is not None, ( + "Sharding degree can not be None." 
+ ) if isinstance(shard_fn, ShardingStage1): strategy.sharding.enable = True diff --git a/python/paddle/distributed/auto_parallel/auto_dp_utils.py b/python/paddle/distributed/auto_parallel/auto_dp_utils.py index 6c2a9da0958a09..b53af6c6a374c1 100644 --- a/python/paddle/distributed/auto_parallel/auto_dp_utils.py +++ b/python/paddle/distributed/auto_parallel/auto_dp_utils.py @@ -21,9 +21,9 @@ def _fake_replicate_grad_to_partial(grad, partial_axis): new_placements = grad.placements - assert ( - new_placements[partial_axis] == dist.Replicate() - ), "when reshard fake replicated grad to partial, the partial axis of grad should be Replicate" + assert new_placements[partial_axis] == dist.Replicate(), ( + "when reshard fake replicated grad to partial, the partial axis of grad should be Replicate" + ) new_placements[partial_axis] = dist.Partial(dist.ReduceType.kRedSum) diff --git a/python/paddle/distributed/auto_parallel/high_level_api.py b/python/paddle/distributed/auto_parallel/high_level_api.py index 202e47512f2821..05742796bba597 100644 --- a/python/paddle/distributed/auto_parallel/high_level_api.py +++ b/python/paddle/distributed/auto_parallel/high_level_api.py @@ -34,9 +34,9 @@ def __init__(self): def cost_model(matched_programs, device_num, node_num): # TODO(jeff41404): multi-node will be supported later - assert ( - node_num == 1 - ), "we only support single node now, multi-node will be supported later" + assert node_num == 1, ( + "we only support single node now, multi-node will be supported later" + ) # TODO(jeff41404): will evaluate the best combination of parallel strategies # based on cost_model and return global_mesh, currently using pre-defined parallel strategy @@ -224,7 +224,9 @@ def record_program_ops_post_hook(layer, inputs, outputs): assert ( layer._op_recorder.start >= 0 and layer._op_recorder.is_valid is True - ), f"{layer._full_name} has not recorded the start of the corresponding ops before" + ), ( + f"{layer._full_name} has not recorded the start of the corresponding ops before" + ) end = len(default_main_program().global_block().ops) # some layers, such as rotary_embedding, will not add new ops to program # assert end > layer._op_recorder.start, f"{layer._full_name} has not added new ops to the program" @@ -754,9 +756,9 @@ def to_distributed( for pattern_name, matched_patterns in results.items(): # process one pattern pattern_ops_dist_infos = get_pattern(pattern_name).ops_dist_infos - assert ( - pattern_ops_dist_infos is not None - ), f"{pattern_name} does not contain ops_dist_infos, cannot reshard, please check" + assert pattern_ops_dist_infos is not None, ( + f"{pattern_name} does not contain ops_dist_infos, cannot reshard, please check" + ) processed_patterns = [] for matched_pattern in matched_patterns: # convert pattern_ops_dist_infos to program_ops_dist_infos @@ -764,9 +766,9 @@ def to_distributed( for pattern_ops_id, op_dist_info in pattern_ops_dist_infos.items(): program_ops_id = [] for pattern_op_id in pattern_ops_id: - assert ( - pattern_op_id in matched_pattern.keys() - ), f"please check ops_dist_infos of {pattern_name}, {pattern_op_id} not in matched_pattern: {matched_pattern.keys()}" + assert pattern_op_id in matched_pattern.keys(), ( + f"please check ops_dist_infos of {pattern_name}, {pattern_op_id} not in matched_pattern: {matched_pattern.keys()}" + ) program_op_id = matched_pattern[pattern_op_id] program_ops_id.append(program_op_id) program_ops_dist_infos[tuple(program_ops_id)] = op_dist_info @@ -789,9 +791,9 @@ def to_distributed( if with_mp: 
num_hidden_layers = len(matched_programs[DECODER_LAYER_NAME]) for pattern_name, processed_patterns in matched_programs.items(): - assert ( - len(processed_patterns) == num_hidden_layers - ), "transformer patterns matched are incomplete" + assert len(processed_patterns) == num_hidden_layers, ( + "transformer patterns matched are incomplete" + ) for idx, processed_pattern in enumerate(processed_patterns): local_mesh = mesh if with_pp: @@ -801,9 +803,9 @@ def to_distributed( local_mesh = mesh.get_mesh_with_dim("pp", pp_stage_id) for program_ops_id, dist_infos in processed_pattern.items(): - assert ( - program_ops_id in ops_id_to_layer.keys() - ), f"program_ops: {program_ops_id} is not corresponding to a dynamic layer" + assert program_ops_id in ops_id_to_layer.keys(), ( + f"program_ops: {program_ops_id} is not corresponding to a dynamic layer" + ) dynamic_layer = ops_id_to_layer[program_ops_id] mesh_num_dims = len(local_mesh.shape) sharding_info = dist_infos.get_dist_info(mesh_num_dims) @@ -832,9 +834,9 @@ def to_distributed( if decoder_layers is not None: num_decoder_blocks = len(decoder_layers) - assert ( - num_decoder_blocks == num_hidden_layers - ), f"decoder pattern layers matched are incomplete, num_decoder_blocks: {num_decoder_blocks} should be equal to num_hidden_layers: {num_hidden_layers}" + assert num_decoder_blocks == num_hidden_layers, ( + f"decoder pattern layers matched are incomplete, num_decoder_blocks: {num_decoder_blocks} should be equal to num_hidden_layers: {num_hidden_layers}" + ) pp_degree = mesh.get_dim_size("pp") num_blocks_per_stage = num_decoder_blocks // pp_degree diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index eb360f063046d2..a17e9d59a5484d 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -73,17 +73,17 @@ def shard_tensor(x, process_mesh=None, shard_spec=None): """ if process_mesh is not None: - assert isinstance( - process_mesh, core.ProcessMesh - ), f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh" + assert isinstance(process_mesh, core.ProcessMesh), ( + f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh" + ) else: process_mesh = get_current_process_mesh() - assert ( - process_mesh is not None - ), "Specify the process mesh argument or use ProcessMesh context manager first." - assert isinstance( - shard_spec, list - ), f"Argument shard_spec {shard_spec} is not an instance of list" + assert process_mesh is not None, ( + "Specify the process mesh argument or use ProcessMesh context manager first." + ) + assert isinstance(shard_spec, list), ( + f"Argument shard_spec {shard_spec} is not an instance of list" + ) if isinstance(x, str): x = ( paddle.static.default_main_program() @@ -100,9 +100,9 @@ def shard_tensor(x, process_mesh=None, shard_spec=None): else: tensor_shape = serial_tensor.shape if shard_spec is not None: - assert verify_shard_spec( - shard_spec, tensor_shape, process_mesh - ), f"For tensor {serial_tensor.name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {process_mesh}." + assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), ( + f"For tensor {serial_tensor.name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {process_mesh}." 
+ ) dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping( shard_spec, process_mesh ) @@ -164,14 +164,14 @@ def shard_op( """ if process_mesh is not None: - assert isinstance( - process_mesh, ProcessMesh - ), f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh" + assert isinstance(process_mesh, ProcessMesh), ( + f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh" + ) else: process_mesh = get_current_process_mesh() - assert ( - process_mesh is not None - ), "Specify the process mesh argument or use ProcessMesh context manager first." + assert process_mesh is not None, ( + "Specify the process mesh argument or use ProcessMesh context manager first." + ) in_dims_mappings = [] if in_shard_specs is not None: assert all( diff --git a/python/paddle/distributed/auto_parallel/intermediate/context_parallel.py b/python/paddle/distributed/auto_parallel/intermediate/context_parallel.py index 424cb1733f094e..9f251a0dc9bbe9 100644 --- a/python/paddle/distributed/auto_parallel/intermediate/context_parallel.py +++ b/python/paddle/distributed/auto_parallel/intermediate/context_parallel.py @@ -138,16 +138,16 @@ def all2all_split_input(layer, args): if isinstance(args, (list, tuple)): all_args = [] for input_tensor in args: - assert ( - input_tensor.is_dist() - ), "Input tensor must be a distributed tensor." - assert ( - len(input_tensor.shape) == 2 - ), f"input_ids should be [batch_size, seq_len], but got {input_tensor.shape}" + assert input_tensor.is_dist(), ( + "Input tensor must be a distributed tensor." + ) + assert len(input_tensor.shape) == 2, ( + f"input_ids should be [batch_size, seq_len], but got {input_tensor.shape}" + ) _, seq_len = input_tensor.shape - assert ( - seq_len % cp_degree == 0 - ), f"sequence length {seq_len} must be divisible by cp degree {cp_degree}" + assert seq_len % cp_degree == 0, ( + f"sequence length {seq_len} must be divisible by cp degree {cp_degree}" + ) reshard_input = shard_tensor(input_tensor, 1) all_args.append(reshard_input) new_args = tuple(all_args) @@ -170,21 +170,21 @@ def p2p_split_input(layer, args): all_args = [] for input_tensor in args: # check input_ids - assert ( - input_tensor.is_dist() - ), "Input tensor must be a distributed tensor." - assert ( - len(input_tensor.shape) == 2 - ), f"input_ids should be [batch_size, seq_len], but got {input_tensor.shape}" + assert input_tensor.is_dist(), ( + "Input tensor must be a distributed tensor." + ) + assert len(input_tensor.shape) == 2, ( + f"input_ids should be [batch_size, seq_len], but got {input_tensor.shape}" + ) placements = input_tensor.placements if placements is None: placements = [ dist.Replicate() for _ in range(len(process_mesh.shape)) ] - assert ( - placements[cp_index] == dist.Replicate() - ), "Input tensor must be a replicated tensor in cp mesh." + assert placements[cp_index] == dist.Replicate(), ( + "Input tensor must be a replicated tensor in cp mesh." + ) reshard_input = shard_seq_load_balance(input_tensor, 1) all_args.append(reshard_input) new_args = tuple(all_args) @@ -319,9 +319,9 @@ def all2all_reshard_hook(layer, args): assert arg.is_dist(), f"arg {arg} must be a distributed tensor." assert len(arg.shape) == 3 or len(arg.shape) == 4 placements = arg.placements - assert placements[cp_index] == dist.Shard( - 1 - ), f"arg {arg} must be sharded in sequence dimension." + assert placements[cp_index] == dist.Shard(1), ( + f"arg {arg} must be sharded in sequence dimension." 
+ ) # reshard [batch_size,seq_len/sep,num_head,head_dim] -> [batch_size,seq_len,num_head/sep,head_dim] placements[cp_index] = dist.Shard(2) target_arg = dist.reshard(arg, process_mesh, placements) @@ -336,13 +336,13 @@ def all2all_reshard_hook(layer, input, output): cp_index = process_mesh.dim_names.index('sep') cp_degree = process_mesh.shape[cp_index] placements = output.placements - assert ( - output.is_dist() - ), f"output {output} must be a distributed tensor." + assert output.is_dist(), ( + f"output {output} must be a distributed tensor." + ) assert len(output.shape) == 4 or len(output.shape) == 3 - assert placements[cp_index] == dist.Shard( - 2 - ), f"output {output} must be Shard(2) in sequence dimension." + assert placements[cp_index] == dist.Shard(2), ( + f"output {output} must be Shard(2) in sequence dimension." + ) # reshard [batch_size,seq_len,num_head/seq,head_dim] -> [batch_size,seq_len/sep,num_head,head_dim] placements[cp_index] = dist.Shard(1) target_output = dist.reshard(output, process_mesh, placements) @@ -356,14 +356,14 @@ def input_hook(layer, args, kwargs): cp_degree = process_mesh.shape[cp_index] for arg in args: # check q k v - assert ( - arg.is_dist() - ), "Input tensor must be a distributed tensor." + assert arg.is_dist(), ( + "Input tensor must be a distributed tensor." + ) assert len(arg.shape) == 3 or len(arg.shape) == 4 placements = arg.placements - assert placements[cp_index] == dist.Shard( - 1 - ), f"arg {arg} must be Shard(1) in sequence dimension." + assert placements[cp_index] == dist.Shard(1), ( + f"arg {arg} must be Shard(1) in sequence dimension." + ) # edit kwarg backend to 'p2p' new_kwargs = kwargs new_kwargs['backend'] = 'p2p' diff --git a/python/paddle/distributed/auto_parallel/intermediate/parallel_base.py b/python/paddle/distributed/auto_parallel/intermediate/parallel_base.py index 8730ffe6fc9ad6..b81adcdf50bff9 100644 --- a/python/paddle/distributed/auto_parallel/intermediate/parallel_base.py +++ b/python/paddle/distributed/auto_parallel/intermediate/parallel_base.py @@ -57,9 +57,9 @@ def __init__( level = str(level) assert level in ("0", "1", "2", "3", None) if optimizer.level is not None: - assert ( - level == optimizer.level - ), f"The level passed in is not identical with previous level. Current level is {level}, previous level is {optimizer.level}" + assert level == optimizer.level, ( + f"The level passed in is not identical with previous level. Current level is {level}, previous level is {optimizer.level}" + ) self.level = level self.sharding_mesh_dim = sharding_mesh_dim else: diff --git a/python/paddle/distributed/auto_parallel/intermediate/parallelize.py b/python/paddle/distributed/auto_parallel/intermediate/parallelize.py index f64005a5e411d1..f4f1058a787875 100644 --- a/python/paddle/distributed/auto_parallel/intermediate/parallelize.py +++ b/python/paddle/distributed/auto_parallel/intermediate/parallelize.py @@ -260,9 +260,9 @@ def parallelize( return model, optimizer assert isinstance(config, dict) if mesh is not None: - assert isinstance( - mesh, core.ProcessMesh - ), "The mesh must be an instance of paddle.distributed.ProcessMesh." + assert isinstance(mesh, core.ProcessMesh), ( + "The mesh must be an instance of paddle.distributed.ProcessMesh." 
+ ) g_mesh = fleet.auto.get_mesh() if g_mesh is not None and g_mesh != mesh: warnings.warn( @@ -322,9 +322,9 @@ def parallelize_model(model, mesh=None, config=None): return model assert isinstance(config, dict) if mesh is not None: - assert isinstance( - mesh, core.ProcessMesh - ), "The mesh must be an instance of paddle.distributed.ProcessMesh." + assert isinstance(mesh, core.ProcessMesh), ( + "The mesh must be an instance of paddle.distributed.ProcessMesh." + ) g_mesh = fleet.auto.get_mesh() if g_mesh is not None and g_mesh != mesh: warnings.warn( @@ -346,9 +346,9 @@ def parallelize_optimizer(optimizer, mesh=None, config=None): return optimizer assert isinstance(config, dict) if mesh is not None: - assert isinstance( - mesh, core.ProcessMesh - ), "The mesh must be an instance of paddle.distributed.ProcessMesh." + assert isinstance(mesh, core.ProcessMesh), ( + "The mesh must be an instance of paddle.distributed.ProcessMesh." + ) g_mesh = fleet.auto.get_mesh() if g_mesh is not None and g_mesh != mesh: warnings.warn( @@ -358,21 +358,21 @@ def parallelize_optimizer(optimizer, mesh=None, config=None): fleet.auto.set_mesh(mesh) global has_parallelized_model - assert ( - has_parallelized_model - ), "Please parallelize the model before parallelize optimizer." + assert has_parallelized_model, ( + "Please parallelize the model before parallelize optimizer." + ) param_list = optimizer._parameter_list if isinstance(param_list[0], dict): for param_group in param_list: for param in param_group['params']: - assert ( - param.is_dist() - ), "Please use model after parallelize to create optimizer." + assert param.is_dist(), ( + "Please use model after parallelize to create optimizer." + ) else: for param in param_list: - assert ( - param.is_dist() - ), "Please use model after parallelize to create optimizer." + assert param.is_dist(), ( + "Please use model after parallelize to create optimizer." 
+ ) dp_config = config.get('dp_config') level = None diff --git a/python/paddle/distributed/auto_parallel/intermediate/pipeline_parallel.py b/python/paddle/distributed/auto_parallel/intermediate/pipeline_parallel.py index 85aac541cd17c9..b742dc010d3719 100644 --- a/python/paddle/distributed/auto_parallel/intermediate/pipeline_parallel.py +++ b/python/paddle/distributed/auto_parallel/intermediate/pipeline_parallel.py @@ -71,9 +71,9 @@ def __init__(self, model, split_spec, global_spec, pipeline_layers=None): self.name_to_layer[layer_name] = layer def get_layer_by_name(self, name): - assert ( - name in self.name_to_layer - ), f"layer name:{name} not in the model, please check the split_spec" + assert name in self.name_to_layer, ( + f"layer name:{name} not in the model, please check the split_spec" + ) return self.name_to_layer[name] def pipeline_parallel_fn(self, model): @@ -135,9 +135,9 @@ def forward_pre_hook(layer, input): pipeline_layer_mark[i] = 1 is_valid = True break - assert ( - is_valid - ), f"the last layer:{split_layer_name} must not be SplitPoint.END, please check the split_spec" + assert is_valid, ( + f"the last layer:{split_layer_name} must not be SplitPoint.END, please check the split_spec" + ) else: raise NotImplementedError( "SplitPoint.BEGINNING is not supported currently" @@ -288,12 +288,12 @@ def pipeline_parallel(model, optimizer=None, config=None): return model, optimizer mesh = fleet.auto.get_mesh() - assert ( - mesh is not None - ), "global mesh must not be None, please call fleet.auto.set_mesh(global_mesh) firstly" - assert ( - "pp" in mesh.dim_names - ), "pp must in the mesh dim_names when use pipeline_parallel" + assert mesh is not None, ( + "global mesh must not be None, please call fleet.auto.set_mesh(global_mesh) firstly" + ) + assert "pp" in mesh.dim_names, ( + "pp must in the mesh dim_names when use pipeline_parallel" + ) global_spec = config.get("global_spec") if isinstance(split_spec, str): @@ -336,12 +336,12 @@ def filter_matched_layer(matched_layer_name): matched_layer_name = filter_matched_layer(matched_layer_name) pp_size = mesh.get_dim_size("pp") layer_num = len(matched_layer_name) - assert ( - layer_num > 0 - ), "No layer match the split_spec, please check its correctness" - assert ( - layer_num >= pp_size - ), "The number of layers must not be less than the pp size" + assert layer_num > 0, ( + "No layer match the split_spec, please check its correctness" + ) + assert layer_num >= pp_size, ( + "The number of layers must not be less than the pp size" + ) if layer_num % pp_size != 0: logger.warning( f"The number of layers({layer_num}) must be divisible by the pp size({pp_size}), but got {layer_num} and {pp_size}" @@ -383,18 +383,18 @@ def divide_list_indices(n, k): sublayer_names = [name for name, _ in model.named_sublayers()] split_spec_dict = split_spec for key, value in split_spec_dict.items(): - assert ( - key in sublayer_names - ), f"wrong split layer, expected one of {sublayer_names}" + assert key in sublayer_names, ( + f"wrong split layer, expected one of {sublayer_names}" + ) assert value is SplitPoint.END, "not supported split point at now." 
if global_spec: if isinstance(global_spec, str): global_spec = [global_spec] else: - assert isinstance( - global_spec, (list, tuple) - ), f"global_spec can only be list or list(str), but got:{type(global_spec)}" + assert isinstance(global_spec, (list, tuple)), ( + f"global_spec can only be list or list(str), but got:{type(global_spec)}" + ) logger.info( f"split_spec_dict: {split_spec_dict}, global_spec: {global_spec}, matched_layer_name: {matched_layer_name}" diff --git a/python/paddle/distributed/auto_parallel/intermediate/sharded_data_parallel.py b/python/paddle/distributed/auto_parallel/intermediate/sharded_data_parallel.py index 6f935a51c1288a..e1ef846515e333 100644 --- a/python/paddle/distributed/auto_parallel/intermediate/sharded_data_parallel.py +++ b/python/paddle/distributed/auto_parallel/intermediate/sharded_data_parallel.py @@ -79,10 +79,10 @@ def sharded_data_parallel(model, optimizer=None, config=None): # check global_mesh mesh = fleet.auto.get_mesh() - assert ( - mesh is not None - ), "global mesh must not be None, please call fleet.auto.set_mesh(global_mesh) firstly" - assert ( - "dp" in mesh.dim_names - ), "dp must in the mesh dim_names when use sharded_data_parallel" + assert mesh is not None, ( + "global mesh must not be None, please call fleet.auto.set_mesh(global_mesh) firstly" + ) + assert "dp" in mesh.dim_names, ( + "dp must in the mesh dim_names when use sharded_data_parallel" + ) return sdp_model, optimizer diff --git a/python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py b/python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py index 8ea0aa0c3d5086..1ff6d5c2cccd54 100644 --- a/python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py +++ b/python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py @@ -821,15 +821,15 @@ def __init__(self, model, parallelize_plan=None): if parallelize_plan is not None: assert isinstance(parallelize_plan, dict) for key, plan in parallelize_plan.items(): - assert isinstance( - key, str - ), "The key of the parallelize plan should be a string." + assert isinstance(key, str), ( + "The key of the parallelize plan should be a string." + ) if not isinstance(plan, list): plan = [plan] for p in plan: - assert isinstance( - p, PlanBase - ), "The value the the parallelize plan should be a instance of PlanBase or a list of PlanBase." + assert isinstance(p, PlanBase), ( + "The value the the parallelize plan should be a instance of PlanBase or a list of PlanBase." 
+ ) self.global_mesh = dist.auto_parallel.get_mesh() self.parallelize_plan = parallelize_plan @@ -934,12 +934,12 @@ def tensor_parallel(model, optimizer=None, config=None): global_mesh = dist.auto_parallel.get_mesh() - assert ( - global_mesh is not None - ), "global mesh must not be None, please call fleet.auto.set_mesh(global_mesh) firstly" - assert ( - "mp" in global_mesh.dim_names - ), "mp must in the mesh dim_names when use tensor_parallel" + assert global_mesh is not None, ( + "global mesh must not be None, please call fleet.auto.set_mesh(global_mesh) firstly" + ) + assert "mp" in global_mesh.dim_names, ( + "mp must in the mesh dim_names when use tensor_parallel" + ) model = TensorParallel(model, parallelize_plan) if optimizer is not None: diff --git a/python/paddle/distributed/auto_parallel/local_layer.py b/python/paddle/distributed/auto_parallel/local_layer.py index c7e24d65225bf3..74456a66ec562b 100644 --- a/python/paddle/distributed/auto_parallel/local_layer.py +++ b/python/paddle/distributed/auto_parallel/local_layer.py @@ -113,9 +113,9 @@ def __call__(self, *inputs: Any, **kwargs: Any) -> Any: outputs back to distributed tensors based on the specified distribution attributes. """ inputs = list(inputs) - assert len(inputs) == len( - self.grad_dist_attrs - ), f"The number of inputs ({len(inputs)}) does not match the number of grad_dist_attrs ({len(self.grad_dist_attrs)})." + assert len(inputs) == len(self.grad_dist_attrs), ( + f"The number of inputs ({len(inputs)}) does not match the number of grad_dist_attrs ({len(self.grad_dist_attrs)})." + ) for idx in range(len(inputs)): if inputs[idx].is_dist(): if self.grad_dist_attrs[idx] is None: @@ -141,9 +141,9 @@ def __call__(self, *inputs: Any, **kwargs: Any) -> Any: outputs = Layer.__call__(self, *inputs, **kwargs) list_outs = paddle.utils.flatten(outputs) - assert len(list_outs) == len( - self.out_dist_attrs - ), f"The number of outputs ({len(list_outs)}) does not match the number of distribution attributes ({len(self.out_dist_attrs)})." + assert len(list_outs) == len(self.out_dist_attrs), ( + f"The number of outputs ({len(list_outs)}) does not match the number of distribution attributes ({len(self.out_dist_attrs)})." 
+ ) dist_outs = [] for idx in range(len(list_outs)): diff --git a/python/paddle/distributed/auto_parallel/local_map.py b/python/paddle/distributed/auto_parallel/local_map.py index e9655064c3dca5..80b9ba0aa7659a 100644 --- a/python/paddle/distributed/auto_parallel/local_map.py +++ b/python/paddle/distributed/auto_parallel/local_map.py @@ -203,9 +203,9 @@ def wrapped(process_mesh: ProcessMesh | None, *args, **kwargs): for out, out_placement in zip(flat_out, out_placements): if paddle.in_dynamic_mode(): if isinstance(out, paddle.Tensor): - assert not dist.auto_parallel.api.is_dist_tensor( - out - ), f"Expected dense tensor output but got {type(out)}: {out}" + assert not dist.auto_parallel.api.is_dist_tensor(out), ( + f"Expected dense tensor output but got {type(out)}: {out}" + ) flat_dist_and_arg_out.append( dist.auto_parallel.api.dtensor_from_local( @@ -220,9 +220,9 @@ def wrapped(process_mesh: ProcessMesh | None, *args, **kwargs): flat_dist_and_arg_out.append(out) else: if isinstance(out, paddle.base.libpaddle.pir.Value): - assert not dist.auto_parallel.api.is_dist_tensor( - out - ), f"Expected dense tensor output but got {type(out)}: {out}" + assert not dist.auto_parallel.api.is_dist_tensor(out), ( + f"Expected dense tensor output but got {type(out)}: {out}" + ) flat_dist_and_arg_out.append( dist.auto_parallel.api.dtensor_from_local( @@ -241,9 +241,9 @@ def wrapped(process_mesh: ProcessMesh | None, *args, **kwargs): flat_dist_and_arg_out = [] for out, out_placement in zip(flat_out, out_placements): if out_placement is not None: - assert ( - process_mesh is not None - ), "process_mesh must be specified when out_placements is not None" + assert process_mesh is not None, ( + "process_mesh must be specified when out_placements is not None" + ) flat_dist_and_arg_out.append( dist.auto_parallel.api.dtensor_from_local( out, process_mesh, out_placement diff --git a/python/paddle/distributed/auto_parallel/moe_utils.py b/python/paddle/distributed/auto_parallel/moe_utils.py index 2c050a45dffe28..7155132e076a0a 100644 --- a/python/paddle/distributed/auto_parallel/moe_utils.py +++ b/python/paddle/distributed/auto_parallel/moe_utils.py @@ -104,12 +104,12 @@ def _dtensor_from_local( # TODO Adopt Mix2Dist Pass to allow the program could be executed actually. elif paddle.framework.in_pir_mode(): - assert isinstance( - local_tensor, (type(None), paddle.pir.Value) - ), "input tensor is not pir value." - assert ( - local_tensor.is_dense_tensor_type() - ), "dtensor_from_local() are only supported dense tensor type right." + assert isinstance(local_tensor, (type(None), paddle.pir.Value)), ( + "input tensor is not pir value." + ) + assert local_tensor.is_dense_tensor_type(), ( + "dtensor_from_local() are only supported dense tensor type right." + ) sharding_specs = ( paddle.distributed.auto_parallel.placement_type.get_shard_spec( mesh, placements, local_tensor.ndim @@ -246,9 +246,9 @@ def infer_positive_shape(src_shape, tgt_shape): minus_one_idx = np.where(ret_shape == -1)[0] if minus_one_idx.size > 0: - assert ( - minus_one_idx.size <= 1 - ), "At most one -1 is allowed in target shape." + assert minus_one_idx.size <= 1, ( + "At most one -1 is allowed in target shape." + ) nelem = np.prod(src_shape) ret_shape[minus_one_idx[0]] = 1 @@ -340,9 +340,9 @@ def _dist_reshape( "dist_reshape is only supported in dynamic and pir mode." ) - assert np.prod(tgt_local_shape) == np.prod( - src_local_shape - ), f"The local shapes {src_local_shape} and {tgt_local_shape} are mismatched." 
+ assert np.prod(tgt_local_shape) == np.prod(src_local_shape), ( + f"The local shapes {src_local_shape} and {tgt_local_shape} are mismatched." + ) if paddle.in_dynamic_mode(): return _local_reshape.apply( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py b/python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py index 2875d91d136059..09460206863aa5 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py @@ -64,9 +64,9 @@ def forward(ctx, *args, **kwargs): and not op_dist_attr.is_recompute and rank_id in op_dist_attr.process_mesh.process_ids ): - assert ( - op_dist_attr is not None - ), f"forward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"forward op [{src_op}] don't have dist attribute !" + ) if ( len(kwargs.get('fixed_seed_offset', [])) > 0 diff --git a/python/paddle/distributed/auto_parallel/pipelining/_backward.py b/python/paddle/distributed/auto_parallel/pipelining/_backward.py index edcec2819c5e73..382cd0f0788a09 100644 --- a/python/paddle/distributed/auto_parallel/pipelining/_backward.py +++ b/python/paddle/distributed/auto_parallel/pipelining/_backward.py @@ -75,17 +75,17 @@ def extract_tensors_with_grads( if isinstance(output_val, paddle.Tensor): if output_val.stop_gradient and output_val.grad_fn is None: return - assert isinstance( - grad_val, (paddle.Tensor, type(None)) - ), f"Expected Tensor or None gradient but got {type(grad_val)}" + assert isinstance(grad_val, (paddle.Tensor, type(None))), ( + f"Expected Tensor or None gradient but got {type(grad_val)}" + ) stage_output_tensors.append(output_val) output_grad_tensors.append(grad_val) elif isinstance(output_val, (tuple, list)): if grad_val is None: return - assert isinstance( - grad_val, (tuple, list) - ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + assert isinstance(grad_val, (tuple, list)), ( + f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + ) assert len(output_val) == len(grad_val) for ov, gv in zip(output_val, grad_val): extract_tensors_with_grads( diff --git a/python/paddle/distributed/auto_parallel/pipelining/microbatch.py b/python/paddle/distributed/auto_parallel/pipelining/microbatch.py index cc3fd292c92df2..30623dfa14baa8 100644 --- a/python/paddle/distributed/auto_parallel/pipelining/microbatch.py +++ b/python/paddle/distributed/auto_parallel/pipelining/microbatch.py @@ -42,9 +42,9 @@ def _split_tensor(x, num_chunks, split_axis=0): def _reorder_data_for_align(): nonlocal x - assert x.placements[0] == dist.Shard( - 0 - ), "inputs should be placed on S(0)." + assert x.placements[0] == dist.Shard(0), ( + "inputs should be placed on S(0)." + ) shardings = x.process_mesh.shape[0] @@ -116,9 +116,9 @@ def _split_args_helper( """ A helper function of split_args_kwargs_into_chunks. 
""" - assert len(args_dict) == len( - args_chunk_spec - ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + assert len(args_dict) == len(args_chunk_spec), ( + f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + ) shared_args_dict_flat = {} # handle args one by one @@ -129,9 +129,9 @@ def _split_args_helper( assert chunk_spec is not None chunk_spec_flat = flatten(chunk_spec) - assert len(chunk_spec_flat) == len( - arg_flat - ), f"{arg_key} {len(arg_flat)} != {len(chunk_spec_flat)}" + assert len(chunk_spec_flat) == len(arg_flat), ( + f"{arg_key} {len(arg_flat)} != {len(chunk_spec_flat)}" + ) shard_arg_flat = [] @@ -280,9 +280,9 @@ def merge_chunks( chunk_spec = flatten(chunk_spec) for chunk in chunks: chunk_flat = flatten(chunk) - assert len(chunk_flat) == len( - chunk_spec - ), f"Chunk {chunk} did not match chunk spec {chunk_spec}" + assert len(chunk_flat) == len(chunk_spec), ( + f"Chunk {chunk} did not match chunk spec {chunk_spec}" + ) chunks_flat.append(chunk_flat) def _merge_non_tensor_type_arg(chunks, idx, chunk_spec_of_arg=None): diff --git a/python/paddle/distributed/auto_parallel/pipelining/schedules.py b/python/paddle/distributed/auto_parallel/pipelining/schedules.py index 7d738edefcf4b8..bd122e232421c4 100644 --- a/python/paddle/distributed/auto_parallel/pipelining/schedules.py +++ b/python/paddle/distributed/auto_parallel/pipelining/schedules.py @@ -860,9 +860,9 @@ def _step_microbatches( computation_type = action.computation_type mb_index = action.microbatch_index stage_index = action.stage_index - assert ( - mb_index is not None - ), "All currently supported action types require valid microbatch_index" + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) if computation_type == _ActType.FORWARD: # perform forward computation stage = stage_index_to_stage[stage_index] @@ -922,9 +922,9 @@ def _step_microbatches( computation_type = prev_rank_action.computation_type mb_index = prev_rank_action.microbatch_index stage_index = prev_rank_action.stage_index - assert ( - mb_index is not None - ), "All currently supported action types require valid microbatch_index" + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) # Only handle sends for the forward from a previous rank if computation_type == _ActType.FORWARD: # If not the last stage, then receive fwd activations @@ -953,9 +953,9 @@ def _step_microbatches( computation_type = next_rank_action.computation_type mb_index = next_rank_action.microbatch_index stage_index = next_rank_action.stage_index - assert ( - mb_index is not None - ), "All currently supported action types require valid microbatch_index" + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) # Only handle receives for the backwards from a next rank if computation_type in (FORWARD, BACKWARD_WEIGHT): # Next rank doing forward or weight update has no influence for the current rank backward recv diff --git a/python/paddle/distributed/auto_parallel/pipelining/stage.py b/python/paddle/distributed/auto_parallel/pipelining/stage.py index 797ea66970aba5..1af80831cdee71 100644 --- a/python/paddle/distributed/auto_parallel/pipelining/stage.py +++ b/python/paddle/distributed/auto_parallel/pipelining/stage.py @@ -204,9 +204,9 @@ def __init__( # Forward infra self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {} 
self.act_send_info: dict[int, list] = {} - self._need_grad_indices: dict[int, list] = ( - {} - ) # record the index of output that needs to receive grad from the next stage. + self._need_grad_indices: dict[ + int, list + ] = {} # record the index of output that needs to receive grad from the next stage. # Backward infra will created lazily self.grad_recv_info: dict = {} self.grad_send_info: list | None = None @@ -260,16 +260,16 @@ def _configure_outputs_meta(self, outputs_meta: tuple[paddle.Tensor, ...]): configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches which could show up as hangs, silent corruption, or other errors. """ - assert ( - self._outputs_meta is None - ), "Attempting to reconfigure output_meta, which is not supported" + assert self._outputs_meta is None, ( + "Attempting to reconfigure output_meta, which is not supported" + ) self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] def get_outputs_meta(self) -> tuple[paddle.Tensor, ...]: """Get the output metadata (meta tensors) representing the outputs of this stage""" - assert ( - self._outputs_meta is not None - ), "Attempted to get_outputs_meta() without configuring output meta" + assert self._outputs_meta is not None, ( + "Attempted to get_outputs_meta() without configuring output meta" + ) return self._outputs_meta def _create_grad_send_info( @@ -376,12 +376,12 @@ def set_local_fwd_input( ) for info, tensor in zip(recv_infos, prev_stage_outputs): - assert isinstance( - tensor, paddle.Tensor - ), f"expected tensor values as outputs from prev stage, got {type(tensor)}" - assert isinstance( - info, _RecvInfo - ), "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + assert isinstance(tensor, paddle.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + ) info.buffer = _detach_and_requires_grad(tensor) @@ -389,9 +389,9 @@ def get_local_bwd_output(self, mb_index): """ Returns the input grad tensors for this stage, which correspond to the stage inputs during forward. """ - assert ( - self.has_backward - ), "can't steal_bwd_input if this stage doesn't have backward" + assert self.has_backward, ( + "can't steal_bwd_input if this stage doesn't have backward" + ) assert not self.is_first, "can't get bwd output if this stage is first" self._check_chunk_id(mb_index) @@ -406,22 +406,22 @@ def set_local_bwd_input( Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. Does not detach or set 'stop_gradient'. 
""" - assert isinstance( - next_stage_bwd_outputs, tuple - ), f"Expected tuple, got {type(next_stage_bwd_outputs)}" + assert isinstance(next_stage_bwd_outputs, tuple), ( + f"Expected tuple, got {type(next_stage_bwd_outputs)}" + ) - assert ( - self.has_backward - ), "can't set bwd input if this stage doesn't have backward" + assert self.has_backward, ( + "can't set bwd input if this stage doesn't have backward" + ) assert not self.is_last, "can't set bwd input if this stage is last" recv_infos = self.grad_recv_info[mb_index] for info, tensor in zip(recv_infos, next_stage_bwd_outputs): - assert isinstance( - tensor, paddle.Tensor - ), f"expected tensor values as outputs from prev stage, got {type(tensor)}" - assert isinstance( - info, _RecvInfo - ), f"Expected a recv info, got {type(info)}" + assert isinstance(tensor, paddle.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + f"Expected a recv info, got {type(info)}" + ) info.buffer = tensor def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: @@ -902,9 +902,9 @@ def __init__( else input_args ) - assert ( - output_args is not None - ), "If passing input_args, also pass output_args to override shape inference" + assert output_args is not None, ( + "If passing input_args, also pass output_args to override shape inference" + ) self._configure_outputs_meta( (output_args,) if isinstance(output_args, TensorMeta) @@ -977,28 +977,30 @@ def _sync_shared_param(self): def _validate_shared_parameter_pair(self): # Validate shared_parameters structure. - assert isinstance( - self.shared_parameters, list - ), f"Expected `shared_parameters` to return a list, but got {type(self.shared_parameters).__name__}. " + assert isinstance(self.shared_parameters, list), ( + f"Expected `shared_parameters` to return a list, but got {type(self.shared_parameters).__name__}. " + ) # Validate every pair shard parameter. for idx, a_shared_map in enumerate(self.shared_parameters): # Validate map structure. - assert isinstance( - a_shared_map, dict - ), f"Invalid shared parameter pair: expected dict, but got {type(a_shared_map).__name__}." + assert isinstance(a_shared_map, dict), ( + f"Invalid shared parameter pair: expected dict, but got {type(a_shared_map).__name__}." + ) assert len(a_shared_map) <= 3, ( f"shared_parameters['{idx}'] exceeds size limit (max 3 keys). " f"Allowed: ['params', 'group', 'shared_param'], got: {list(a_shared_map.keys())}" ) # Validate required 'params' entry. params = a_shared_map.get("params") - assert ( - params is not None - ), f"Missing shared parameter 'params' not found in shared_parameters['{idx}']. Available keys: {list(a_shared_map)}." + assert params is not None, ( + f"Missing shared parameter 'params' not found in shared_parameters['{idx}']. Available keys: {list(a_shared_map)}." + ) assert (isinstance(params, list) or tuple(params, list)) and len( params - ) == 2, f"Shared parameter only support 2 shared parameters in list or tuple, but got {len(params)}." + ) == 2, ( + f"Shared parameter only support 2 shared parameters in list or tuple, but got {len(params)}." + ) # Validate parameter types and placements. 
param_1, param_2 = params assert isinstance(param_1, EagerParamBase) and isinstance( @@ -1015,24 +1017,26 @@ def _validate_shared_parameter_pair(self): ranks_1 = param_1.process_mesh.process_ids ranks_2 = param_2.process_mesh.process_ids assert len(ranks_1) == len(ranks_2) - assert ( - ranks_1 != ranks_2 - ), f"Shared parameters must be on different stage meshes, but both are on {ranks_1}." + assert ranks_1 != ranks_2, ( + f"Shared parameters must be on different stage meshes, but both are on {ranks_1}." + ) # In VPP mode, a same shared_parameters is reused across stage builds. To avoid redundant group creation, the 'shared_param' # and 'group' attributes may already exist, as they are created during the `_init_shared_group` call. # Validate optional 'group' entry. if "group" in a_shared_map: group = a_shared_map["group"] - assert group is None or isinstance( - group, Group - ), f"Expected 'shared_parameters[{idx}][\"group\"]' is 'Group' or None, but got '{type(a_shared_map['group']).__name__}'." + assert group is None or isinstance(group, Group), ( + f"Expected 'shared_parameters[{idx}][\"group\"]' is 'Group' or None, but got '{type(a_shared_map['group']).__name__}'." + ) # Validate optional 'sync_param' entry. if "sync_param" in a_shared_map: sync_param = a_shared_map["sync_param"] assert sync_param is None or sync_param in list( param_1, param_2 - ), f"Expected 'shared_parameters[{idx}][\"sync_param\"]' is one of the two params or None." + ), ( + f"Expected 'shared_parameters[{idx}][\"sync_param\"]' is one of the two params or None." + ) def _init_shared_group(self): # Retrieve the parameters to be shared and the required communication group information for the current rank, and store them in @@ -1054,9 +1058,9 @@ def _init_shared_group(self): # In VPP mode, since `shared_parameters`` is reused across stage creations, # the 'group' may already exist, avoiding redundant group creation. if cur_rank in group_ranks: - assert group_ranks == tuple( - a_map["group"].ranks - ), f"Shared Parameter group ranks mismatch: expected {group_ranks}, but got {a_map['group'].ranks}. " + assert group_ranks == tuple(a_map["group"].ranks), ( + f"Shared Parameter group ranks mismatch: expected {group_ranks}, but got {a_map['group'].ranks}. " + ) else: if group_ranks not in get_group_from_ranks: get_group_from_ranks[group_ranks] = dist.new_group( @@ -1126,9 +1130,9 @@ def _shape_inference( ): raise NotImplementedError else: - assert ( - len(args) == 0 - ), "Can't supply input args for shape inference on non-first stage" + assert len(args) == 0, ( + "Can't supply input args for shape inference on non-first stage" + ) objects = [None] logger.debug( "Shape inference: stage %s receiving from stage %s", diff --git a/python/paddle/distributed/auto_parallel/pipelining/utils.py b/python/paddle/distributed/auto_parallel/pipelining/utils.py index 5de9c3832ec067..a23d7c08f50643 100644 --- a/python/paddle/distributed/auto_parallel/pipelining/utils.py +++ b/python/paddle/distributed/auto_parallel/pipelining/utils.py @@ -134,9 +134,9 @@ def _get_pp_mesh(pp_idx=0, pp_dim_names="pp"): Get the mesh of the {pp_idx}th PipelineStage. """ mesh = fleet.auto.get_mesh() - assert ( - mesh is not None - ), "the mesh is None, please call fleet.auto.set_mesh first." + assert mesh is not None, ( + "the mesh is None, please call fleet.auto.set_mesh first." 
+ ) if "pp" in mesh.dim_names: mesh = mesh.get_mesh_with_dim("pp", pp_idx) else: diff --git a/python/paddle/distributed/auto_parallel/placement_type.py b/python/paddle/distributed/auto_parallel/placement_type.py index b9cc1bad7a9aa2..30b975a91555c7 100644 --- a/python/paddle/distributed/auto_parallel/placement_type.py +++ b/python/paddle/distributed/auto_parallel/placement_type.py @@ -140,9 +140,9 @@ def placemetns_to_dist_status( split_factor_map[i] = cast( "Shard", placement ).get_split_factor() - assert ( - len(split_factor_map) == 1 - ), "only support to rerrange at one mesh dim." + assert len(split_factor_map) == 1, ( + "only support to rerrange at one mesh dim." + ) if placement.is_partial(): partial_status[i] = cast("Partial", placement).reduce_type() diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py index 3c968d8f6c5b02..544915ee9b5234 100644 --- a/python/paddle/distributed/auto_parallel/process_mesh.py +++ b/python/paddle/distributed/auto_parallel/process_mesh.py @@ -160,28 +160,28 @@ def __init__( self._shape = list(self._mesh.shape) self._process_ids = self._mesh.flatten().tolist() - assert all( - isinstance(p, int) for p in self._process_ids - ), "All elements of the mesh must be integer" - assert ( - min(self._process_ids) >= 0 - ), 'All elements of the mesh must be >= 0.' + assert all(isinstance(p, int) for p in self._process_ids), ( + "All elements of the mesh must be integer" + ) + assert min(self._process_ids) >= 0, ( + 'All elements of the mesh must be >= 0.' + ) unique_process_ids = set(self._process_ids) - assert len(unique_process_ids) == len( - self._process_ids - ), 'All elements of the mesh must be unique.' + assert len(unique_process_ids) == len(self._process_ids), ( + 'All elements of the mesh must be unique.' + ) if dim_names is not None: - assert len(dim_names) == len( - self._shape - ), "The length of dims_names must be same as the shape of the mesh." + assert len(dim_names) == len(self._shape), ( + "The length of dims_names must be same as the shape of the mesh." + ) self._dim_names = copy.deepcopy(dim_names) else: self._dim_names = ["d" + str(i) for i in range(len(self._shape))] unique_dim_names = set(self._dim_names) - assert len(unique_dim_names) == len( - self._dim_names - ), f'All dim_names {dim_names} must be unique.' + assert len(unique_dim_names) == len(self._dim_names), ( + f'All dim_names {dim_names} must be unique.' + ) # Follow the requirement for using pybind11 core.ProcessMesh.__init__( @@ -296,9 +296,9 @@ def get_mesh_with_dim( dim_name: str, index: slice | tuple[slice, ...] | SupportsIndex | None = None, ) -> ProcessMesh: - assert ( - dim_name in self._dim_names - ), f'{dim_name} is not a valid dim name.' + assert dim_name in self._dim_names, ( + f'{dim_name} is not a valid dim name.' + ) index_axis = self._dim_names.index(dim_name) new_order = [index_axis] + [ i for i in range(len(self._dim_names)) if i != index_axis diff --git a/python/paddle/distributed/auto_parallel/random.py b/python/paddle/distributed/auto_parallel/random.py index 7cddbc753abf0e..1e32002bb524f3 100644 --- a/python/paddle/distributed/auto_parallel/random.py +++ b/python/paddle/distributed/auto_parallel/random.py @@ -79,12 +79,12 @@ def determinate_rng( rank, dims_mapping=None, process_mesh=None, placements=None ): assert process_mesh is not None, "Must provide process mesh" - assert ( - dims_mapping is not None or placements is not None - ), "Must provide one of dims mapping or placements." 
- assert not ( - dims_mapping is not None and placements is not None - ), "Cannot provide dims mapping and placements at same time." + assert dims_mapping is not None or placements is not None, ( + "Must provide one of dims mapping or placements." + ) + assert not (dims_mapping is not None and placements is not None), ( + "Cannot provide dims mapping and placements at same time." + ) # TODO(JZ-LIANG) Support Mesh with any high rank # use a string to unique integer hashing algorithm for seed computation. # instead of using offsets to coordinate seed across devices. @@ -129,9 +129,9 @@ def determinate_rng( if sharding_expr in _rng_name_to_seed: assert _rng_name_to_seed[sharding_expr] == seed_ else: - assert ( - seed_ not in _rng_name_to_seed.values() - ), f"Seed Conflict! current seed: {seed_}, current sharding expr: {sharding_expr}, generated seed: {_rng_name_to_seed}" + assert seed_ not in _rng_name_to_seed.values(), ( + f"Seed Conflict! current seed: {seed_}, current sharding expr: {sharding_expr}, generated seed: {_rng_name_to_seed}" + ) _rng_name_to_seed[sharding_expr] = seed_ if paddle.in_dynamic_mode(): # for dygraph, just init the seed when meeting a new seed @@ -145,9 +145,9 @@ def determinate_rng( @contextlib.contextmanager def rng_state(name): global _rng_name_to_states - assert ( - name in _rng_name_to_states - ), f"The rng state name {name} haven't been init. " + assert name in _rng_name_to_states, ( + f"The rng state name {name} haven't been init. " + ) orig_rng_state = paddle.get_rng_state() paddle.set_rng_state(_rng_name_to_states[name]) try: diff --git a/python/paddle/distributed/auto_parallel/sharding.py b/python/paddle/distributed/auto_parallel/sharding.py index 863da28aa7ac00..bbbc5e62c7a2dd 100644 --- a/python/paddle/distributed/auto_parallel/sharding.py +++ b/python/paddle/distributed/auto_parallel/sharding.py @@ -55,9 +55,9 @@ def get_placement_with_sharding(param, sharding_axis, param_placements=None): if isinstance(placement, dist.Shard): # the parameter can't be shard twice with sharding on different mesh now # for example, [Shard(0), Shard(1)], assert here in case - assert ( - shard_axis == -1 - ), "The parameter can't be shard twice with sharding strategy even in different mesh now." + assert shard_axis == -1, ( + "The parameter can't be shard twice with sharding strategy even in different mesh now." + ) shard_axis = placement.get_dim() placement_with_sharding = None @@ -99,12 +99,14 @@ class ShardingOptimizerStage1(Optimizer): """ def __init__(self, optimizer, shard_fn=None, strategy=None): - assert ( - optimizer is not None - ), "The argument `optimizer` cannot be empty." + assert optimizer is not None, ( + "The argument `optimizer` cannot be empty." + ) assert isinstance( optimizer, (paddle.optimizer.AdamW, paddle.optimizer.SGD) - ), "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now." + ), ( + "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now." + ) self.__dict__["_inner_opt"] = optimizer self._shard_fn = shard_fn self._strategy = strategy or Strategy() @@ -181,15 +183,17 @@ def apply_gradients(self, params_grads): continue param_dist_attr = param.dist_attr() grad_dist_attr = grad.dist_attr() - assert ( - param_dist_attr is not None - ), f"parameter dist attribute must not None. but received {param.name} : {param}." - assert ( - grad_dist_attr is not None - ), f"gradient dist attribute must not None. but received {param.name} grad : {grad}." 
+ assert param_dist_attr is not None, ( + f"parameter dist attribute must not None. but received {param.name} : {param}." + ) + assert grad_dist_attr is not None, ( + f"gradient dist attribute must not None. but received {param.name} grad : {grad}." + ) assert ( param_dist_attr.process_mesh == grad_dist_attr.process_mesh - ), f"Parameter and grad should have same process_mesh. but received name:{param.name}, parameter:{param}, grad: {grad}." + ), ( + f"Parameter and grad should have same process_mesh. but received name:{param.name}, parameter:{param}, grad: {grad}." + ) if self._sharding_axis not in grad_dist_attr.partial_dims: new_params_grads.append((param, grad)) @@ -204,9 +208,9 @@ def apply_gradients(self, params_grads): else: param.optimize_attr["no_fusion"] = False - assert ( - param_dist_attr.process_mesh in self.pp_meshes - ), f"parameter mesh mush be in pp_meshes. but received parameter name:{param.name}, mesh:{param_dist_attr.process_mesh}, pp_meshes: {self.pp_meshes}." + assert param_dist_attr.process_mesh in self.pp_meshes, ( + f"parameter mesh mush be in pp_meshes. but received parameter name:{param.name}, mesh:{param_dist_attr.process_mesh}, pp_meshes: {self.pp_meshes}." + ) if dist.get_rank() in param_dist_attr.process_mesh.process_ids: sub_mesh = get_1D_sub_process_mesh( @@ -214,20 +218,24 @@ def apply_gradients(self, params_grads): ) assert ( sorted(sub_mesh.process_ids) == self._sharding_group.ranks - ), f" all parameter must have the same sharding group. but received {param.name} sharding group is : {sub_mesh.process_ids}, global sharding group is: {self._sharding_group.ranks}" + ), ( + f" all parameter must have the same sharding group. but received {param.name} sharding group is : {sub_mesh.process_ids}, global sharding group is: {self._sharding_group.ranks}" + ) - assert ( - param_dist_attr.partial_dims == set() - ), f"Sharding fusion do not support partial parameter. but received {param.name} : {param}." + assert param_dist_attr.partial_dims == set(), ( + f"Sharding fusion do not support partial parameter. but received {param.name} : {param}." + ) assert ( param_dist_attr.dims_mapping == grad_dist_attr.dims_mapping - ), f"Parameter and grad should have same dims_mapping. but received name:{param.name}, parameter:{param}, grad: {grad}." - assert ( - param.shape == grad.shape - ), f"Parameter and grad should have same global shape. but received name:{param.name}, parameter:{param}, grad: {grad}." - assert ( - param._local_shape == grad._local_shape - ), f"Parameter and grad should have same local shape. but received name:{param.name}, parameter:{param}, grad: {grad}." + ), ( + f"Parameter and grad should have same dims_mapping. but received name:{param.name}, parameter:{param}, grad: {grad}." + ) + assert param.shape == grad.shape, ( + f"Parameter and grad should have same global shape. but received name:{param.name}, parameter:{param}, grad: {grad}." + ) + assert param._local_shape == grad._local_shape, ( + f"Parameter and grad should have same local shape. but received name:{param.name}, parameter:{param}, grad: {grad}." 
+ ) if ( self._mp_degree > 1 @@ -501,9 +509,9 @@ def _cache_slice_param_group_info(self, parameters, group_indices): for index in indices: param = parameters[index] self._slice_param_group_info[group_idx][param.name] = {} - self._slice_param_group_info[group_idx][param.name][ - "shape" - ] = param.shape + self._slice_param_group_info[group_idx][param.name]["shape"] = ( + param.shape + ) self._slice_param_group_info[group_idx][param.name][ "param_start" ] = -1 @@ -531,14 +539,14 @@ def _cache_slice_param_range_and_size( ] = param_end for name, padded_size in padded_size_dict.items(): - self._slice_param_group_info[group_idx][name][ - "padded_size" - ] = padded_size + self._slice_param_group_info[group_idx][name]["padded_size"] = ( + padded_size + ) for name, _ in self._slice_param_group_info[group_idx].items(): - self._slice_param_group_info[group_idx][name][ - "align_size" - ] = align_size + self._slice_param_group_info[group_idx][name]["align_size"] = ( + align_size + ) def _reduce_scatter_overlap(self, group_grad_list, target_block): ''' diff --git a/python/paddle/distributed/auto_parallel/static/auto_align_tool.py b/python/paddle/distributed/auto_parallel/static/auto_align_tool.py index fc37b09b1599aa..84ba2ea510eff3 100644 --- a/python/paddle/distributed/auto_parallel/static/auto_align_tool.py +++ b/python/paddle/distributed/auto_parallel/static/auto_align_tool.py @@ -117,9 +117,9 @@ def get_loss_lr_var(self): for block in self._blocks: for op in block.ops: if is_loss_op(op): - assert ( - len(op.desc.output_arg_names()) == 1 - ), "loss op should only output loss var" + assert len(op.desc.output_arg_names()) == 1, ( + "loss op should only output loss var" + ) loss_ops.append(op) for block in self._blocks: diff --git a/python/paddle/distributed/auto_parallel/static/cluster_v2.py b/python/paddle/distributed/auto_parallel/static/cluster_v2.py index 479dbdfb57493c..8a8f54e24e65cd 100644 --- a/python/paddle/distributed/auto_parallel/static/cluster_v2.py +++ b/python/paddle/distributed/auto_parallel/static/cluster_v2.py @@ -85,21 +85,21 @@ def __init__(self, name, mesh, dim_names=None): self._shape = list(self._mesh.shape) self._device_ids = self._mesh.flatten().tolist() - assert all( - isinstance(p, int) for p in self._device_ids - ), "All elements of the mesh be integer" - assert ( - min(self._device_ids) >= 0 - ), 'All elements of the mesh must be >= 0.' + assert all(isinstance(p, int) for p in self._device_ids), ( + "All elements of the mesh be integer" + ) + assert min(self._device_ids) >= 0, ( + 'All elements of the mesh must be >= 0.' + ) unique_device_ids = set(self._device_ids) - assert len(unique_device_ids) == len( - self._device_ids - ), 'All elements of the mesh must be unique.' + assert len(unique_device_ids) == len(self._device_ids), ( + 'All elements of the mesh must be unique.' + ) if dim_names is not None: - assert len(dim_names) == len( - self._shape - ), "The length of dims_names must be same as the shape of the mesh." + assert len(dim_names) == len(self._shape), ( + "The length of dims_names must be same as the shape of the mesh." 
+ ) self._dim_names = dim_names else: self._dim_names = ["d" + str(i) for i in range(len(self._shape))] diff --git a/python/paddle/distributed/auto_parallel/static/completion.py b/python/paddle/distributed/auto_parallel/static/completion.py index d55f8e58d8b805..1ca5261bcf6227 100644 --- a/python/paddle/distributed/auto_parallel/static/completion.py +++ b/python/paddle/distributed/auto_parallel/static/completion.py @@ -1251,19 +1251,19 @@ def set_process_mesh(block, op, process_mesh, var_to_process_mesh): seg_op_deps[struct_name] = [i] seg_op_mesh[struct_name] = dist_op.dist_attr.process_mesh else: - assert ( - seg_op_deps[struct_name][-1] + 1 == i - ), "The segment's ops should be continuous." + assert seg_op_deps[struct_name][-1] + 1 == i, ( + "The segment's ops should be continuous." + ) pre_mesh = seg_op_mesh[struct_name] - assert ( - pre_mesh == dist_op.dist_attr.process_mesh - ), "The segment's ops should have same process_mesh." + assert pre_mesh == dist_op.dist_attr.process_mesh, ( + "The segment's ops should have same process_mesh." + ) seg_op_deps[struct_name].extend([i]) num_chunks = pp_degree * vpp_degree - assert ( - len(seg_op_deps) % num_chunks == 0 - ), f"The number of layers[{seg_method}] ({len(seg_op_deps)}) should be divided by part number ({num_chunks})." + assert len(seg_op_deps) % num_chunks == 0, ( + f"The number of layers[{seg_method}] ({len(seg_op_deps)}) should be divided by part number ({num_chunks})." + ) # Step2: analysis whether the pp_stage is non-decreasing among segments # 1. if non_decreasing is True, the ops' process_mesh will be changed by vpp strategy @@ -1634,9 +1634,9 @@ def _get_op_by_id(ops, id): input_name ) ) - assert ( - ref_dims_mapping is not None - ), f"[{input_name}] 's dims mapping is NONE" + assert ref_dims_mapping is not None, ( + f"[{input_name}] 's dims mapping is NONE" + ) grad_op_dist_attr.set_input_dims_mapping( input_name, ref_dims_mapping ) @@ -1671,7 +1671,9 @@ def _get_op_by_id(ops, id): output_name = grad_op.output_arg_names[0] assert ( output_name in grad_var_to_var[appended_grad_times] - ), f"sum op's output '{output_name}' has no corresponding var" + ), ( + f"sum op's output '{output_name}' has no corresponding var" + ) ref_fwd_var_name = grad_var_to_var[appended_grad_times][ output_name ] @@ -1755,9 +1757,9 @@ def _is_grad_var_name(name): return False def _get_forward_varname_from_grad_varname(grad_var_name): - assert _is_grad_var_name( - grad_var_name - ), f"[{grad_var_name}] is not a grad var name." + assert _is_grad_var_name(grad_var_name), ( + f"[{grad_var_name}] is not a grad var name." + ) return grad_var_name[: grad_var_name.find("@GRAD")] def _get_op_by_id(ops, id): @@ -1828,9 +1830,9 @@ def _complete_grad_op_with_forward_op(forward_op, grad_op, vars): input_name ) ) - assert ( - ref_dims_mapping is not None - ), f"[{input_name}] 's dims mapping is NONE" + assert ref_dims_mapping is not None, ( + f"[{input_name}] 's dims mapping is NONE" + ) grad_op_dist_attr.set_input_dims_mapping( input_name, ref_dims_mapping ) @@ -1973,9 +1975,9 @@ def infer_backward_op_partial_status( first_backward_op_idx = idx break - assert ( - first_backward_op_idx >= 0 and loss_op is not None - ), "No backward procedure found in this program." + assert first_backward_op_idx >= 0 and loss_op is not None, ( + "No backward procedure found in this program." 
+ ) ops = list(serial_main_program.global_block().ops) vars = serial_main_program.global_block().vars @@ -1989,12 +1991,12 @@ def infer_backward_op_partial_status( # complete the initial grad loss op if idx == first_backward_op_idx: assert grad_op.type == "fill_constant" - assert ( - len(grad_op.input_arg_names) == 0 - ), f"first backward op should has only ONE output, but got [{len(grad_op.input_arg_names)}]" - assert ( - len(grad_op.output_arg_names) == 1 - ), f"first backward op should has only ONE output, but got [{len(grad_op.output_arg_names)}]" + assert len(grad_op.input_arg_names) == 0, ( + f"first backward op should has only ONE output, but got [{len(grad_op.input_arg_names)}]" + ) + assert len(grad_op.output_arg_names) == 1, ( + f"first backward op should has only ONE output, but got [{len(grad_op.output_arg_names)}]" + ) loss_var = vars[loss_op.output_arg_names[0]] loss_grad_var = vars[grad_op.output_arg_names[0]] @@ -2069,9 +2071,9 @@ def infer_backward_op_partial_status( if grad_op.type in ['sum', 'grad_add']: assert all(map(_is_grad_var_name, grad_op.input_arg_names)) output_name = grad_op.output_arg_names[0] - assert ( - output_name in grad_var_to_var - ), f"sum op's output '{output_name}' has no corresponding var" + assert output_name in grad_var_to_var, ( + f"sum op's output '{output_name}' has no corresponding var" + ) ref_fwd_var_name = grad_var_to_var[output_name] ref_fwd_var = vars[ref_fwd_var_name] ref_fwd_dist_attr = ( @@ -2297,12 +2299,12 @@ def complete_update_annotation(self, serial_main_program): ) if "Grad" in op.input_names and "Param" in ops[idx].input_names: - assert ( - len(op.input("Param")) == 1 - ), "Only support one-to-one now." - assert ( - len(op.input("Grad")) == 1 - ), "Only support one-to-one now." + assert len(op.input("Param")) == 1, ( + "Only support one-to-one now." + ) + assert len(op.input("Grad")) == 1, ( + "Only support one-to-one now." + ) param = vars[op.input("Param")[0]] grad_var = vars[op.input("Grad")[0]] diff --git a/python/paddle/distributed/auto_parallel/static/cost/base_cost.py b/python/paddle/distributed/auto_parallel/static/cost/base_cost.py index 6383ca0fcb6b60..8fff701042872a 100644 --- a/python/paddle/distributed/auto_parallel/static/cost/base_cost.py +++ b/python/paddle/distributed/auto_parallel/static/cost/base_cost.py @@ -629,14 +629,14 @@ def _check_time(self, val): assert val >= 0, "Time must be greater than or equal to 0." def _check_memory(self, val): - assert ( - isinstance(val, int) and val >= 0 - ), "Memory must be int and greater than equal to 0." + assert isinstance(val, int) and val >= 0, ( + "Memory must be int and greater than equal to 0." + ) def _check_flops(self, val): - assert ( - isinstance(val, int) and val >= 0 - ), "FLOPs must be int and greater than equal to 0." + assert isinstance(val, int) and val >= 0, ( + "FLOPs must be int and greater than equal to 0." + ) @property def time(self): @@ -987,9 +987,9 @@ def calc_time_by_cost_model(op, cluster=None): var_name = op.output_arg_names[0] dtype = op.block._var_recursive(var_name).dtype device = cluster.get_device(0) - assert ( - device.type == DeviceType.GPU - ), "Only GPU device is supported currently." + assert device.type == DeviceType.GPU, ( + "Only GPU device is supported currently." 
+ ) gflops = 0.0 if dtype == paddle.float64: diff --git a/python/paddle/distributed/auto_parallel/static/cost/estimate_cost.py b/python/paddle/distributed/auto_parallel/static/cost/estimate_cost.py index 95bd033f79c72e..c4552a38a88e41 100644 --- a/python/paddle/distributed/auto_parallel/static/cost/estimate_cost.py +++ b/python/paddle/distributed/auto_parallel/static/cost/estimate_cost.py @@ -37,9 +37,7 @@ def __init__( self._loop_count = loop_count self._global_cost = Cost() self._local_cost_mapping = {} - self._detailed_cost = ( - OrderedDict() - ) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}} + self._detailed_cost = OrderedDict() # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}} self._bubble_time_mapping = {} self._ordered_ops = [] self.max_memories = {} @@ -286,9 +284,7 @@ def _convert_pm_and_dm_to_str(process_mesh, dims_mapping): memories = {} self.max_memories = {} - var_info = ( - {} - ) # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]} + var_info = {} # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]} for block in self.program.blocks: for op in block.ops: diff --git a/python/paddle/distributed/auto_parallel/static/cost/op_runtime_cost.py b/python/paddle/distributed/auto_parallel/static/cost/op_runtime_cost.py index 7561970f0a2538..2cbe7b9a44799e 100644 --- a/python/paddle/distributed/auto_parallel/static/cost/op_runtime_cost.py +++ b/python/paddle/distributed/auto_parallel/static/cost/op_runtime_cost.py @@ -271,19 +271,21 @@ def measure_program_real_op_cost( >>> measure_program_real_op_cost(program, verbose_level=1) ''' - assert isinstance( - program, Program - ), f'"program" should be a instance of "paddle.base.framework.Program" but got type "{type(program).__name__}".' + assert isinstance(program, Program), ( + f'"program" should be a instance of "paddle.base.framework.Program" but got type "{type(program).__name__}".' + ) supported_places = [ paddle.CUDAPlace, ] assert any( isinstance(place, supported_place) for supported_place in supported_places - ), f'Current place ({place}) does not support runtime profiling. "place" should be one of the following: {supported_places}.' - assert ( - isinstance(run_iters, int) and run_iters >= 1 - ), 'Invalid parameter run_iters set. run_iters should be an integer >= 1.' + ), ( + f'Current place ({place}) does not support runtime profiling. "place" should be one of the following: {supported_places}.' + ) + assert isinstance(run_iters, int) and run_iters >= 1, ( + 'Invalid parameter run_iters set. run_iters should be an integer >= 1.' + ) if run_iters == 1: warnings.warn( 'run_iters was set to 1, profiling results might be inaccurate due to outliers.' diff --git a/python/paddle/distributed/auto_parallel/static/cost_model.py b/python/paddle/distributed/auto_parallel/static/cost_model.py index d261b75b0d422c..1048d0b85bed9e 100644 --- a/python/paddle/distributed/auto_parallel/static/cost_model.py +++ b/python/paddle/distributed/auto_parallel/static/cost_model.py @@ -223,9 +223,9 @@ def __init__( self.optim_time = [] def _parse_sub_program(self, program, nodes, graph, cost_data, sub_idx): - assert ( - len(program.blocks) == 1 - ), "Program more than 1 block not supported." + assert len(program.blocks) == 1, ( + "Program more than 1 block not supported." 
+ ) block = program.blocks[0] var_id = "lod_tensor_blocking_queue_0" diff --git a/python/paddle/distributed/auto_parallel/static/dist_context.py b/python/paddle/distributed/auto_parallel/static/dist_context.py index 9beeb11b0cb895..9ae5dbbd9c6559 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_context.py +++ b/python/paddle/distributed/auto_parallel/static/dist_context.py @@ -478,9 +478,9 @@ def initialize(self, with_graph=True, with_cpp=False, no_default=False): self.copy_dist_attr_from_program_to_graph() def add_process_mesh(self, process_mesh): - assert isinstance( - process_mesh, (ProcessMesh, core.ProcessMesh) - ), 'The type of dim_mapping must be ProcessMesh.' + assert isinstance(process_mesh, (ProcessMesh, core.ProcessMesh)), ( + 'The type of dim_mapping must be ProcessMesh.' + ) if process_mesh not in self.process_meshes: self._process_meshes.append(process_mesh) @@ -787,9 +787,9 @@ def _init_dist_attr_for_graph(self): ) dist_tensor = cur_dist_tensor self._node_id_to_tensor_id[_node_id(node)] = cur_tensor_id - assert ( - dist_tensor is not None - ), "Tensor must have a distributed tensor after the initialization for program." + assert dist_tensor is not None, ( + "Tensor must have a distributed tensor after the initialization for program." + ) serial_tensor_node_id = _node_id(node) new_dist_tensor = DistributedTensor( dist_tensor.serial_tensor, dist_tensor.dist_attr @@ -810,9 +810,9 @@ def _init_dist_attr_for_graph(self): ) dist_op = cur_dist_op self._node_id_to_op_id[_node_id(node)] = cur_op_id - assert ( - dist_op is not None - ), "Operator must have a distributed operator after the initialization for program." + assert dist_op is not None, ( + "Operator must have a distributed operator after the initialization for program." + ) serial_op_node_id = _node_id(node) new_dist_op = DistributedOperator( dist_op.serial_op, dist_op.dist_attr @@ -843,9 +843,9 @@ def copy_dist_attr_from_program_to_graph(self): cur_tensor_id, None ) dist_tensor = cur_dist_tensor - assert ( - dist_tensor is not None - ), "Tensor must have a distributed tensor after the initialization for program." + assert dist_tensor is not None, ( + "Tensor must have a distributed tensor after the initialization for program." + ) serial_tensor_node_id = _node_id(node) new_dist_tensor = DistributedTensor( dist_tensor.serial_tensor, dist_tensor.dist_attr @@ -865,9 +865,9 @@ def copy_dist_attr_from_program_to_graph(self): cur_op_id, None ) dist_op = cur_dist_op - assert ( - dist_op is not None - ), "Operator must have a distributed operator after the initialization for program." + assert dist_op is not None, ( + "Operator must have a distributed operator after the initialization for program." + ) serial_op_node_id = _node_id(node) new_dist_op = DistributedOperator( dist_op.serial_op, dist_op.dist_attr @@ -875,9 +875,9 @@ def copy_dist_attr_from_program_to_graph(self): self._dist_ops_for_graph[serial_op_node_id] = new_dist_op def copy_dist_attr_from_graph_to_program(self): - assert ( - self._is_initialized - ), "Both program and graph must be initialized." + assert self._is_initialized, ( + "Both program and graph must be initialized." 
+ ) updated_tensors = {} all_nodes = self._serial_ordered_nodes process_meshes = [self.process_meshes[0]] @@ -1023,9 +1023,9 @@ def validate_dist_attr_for_program(self): for block in self.serial_main_program.blocks: for tensor in block.vars.values(): dist_tensor = self.get_dist_tensor_for_program(tensor) - assert ( - dist_tensor is not None - ), f"Tensor {dist_tensor.serial_tensor.name} does not have a distributed attribute." + assert dist_tensor is not None, ( + f"Tensor {dist_tensor.serial_tensor.name} does not have a distributed attribute." + ) if (dist_tensor is not None) and ( not dist_tensor.validate_dist_attr() ): @@ -1034,9 +1034,9 @@ def validate_dist_attr_for_program(self): ) for op in block.ops: dist_op = self.get_dist_op_for_program(op) - assert ( - dist_op is not None - ), f"Operator {dist_op.serial_op.type} does not have a distributed attribute." + assert dist_op is not None, ( + f"Operator {dist_op.serial_op.type} does not have a distributed attribute." + ) if (dist_op is not None) and (not dist_op.validate_dist_attr()): raise AssertionError( f"Operator {dist_op.serial_op.type} (id: {dist_op.serial_op.desc.id()}, original_id: {dist_op.serial_op.desc.original_id()}) has a wrong distributed attributes {dist_op.dist_attr} ." @@ -1214,18 +1214,18 @@ def parse_forward_blocks(self, program): for idx, block in enumerate(program.blocks): assert idx == block.idx, "index doesn't match" - assert ( - block.forward_block_idx == -1 - ), f"forward_block_idx of forward block [{idx}] is not [{block.forward_block_idx}]" + assert block.forward_block_idx == -1, ( + f"forward_block_idx of forward block [{idx}] is not [{block.forward_block_idx}]" + ) self.forward_indices.append(idx) self.nblock += 1 assert self.nblock >= 1 def parse_backward_blocks(self, program): - assert ( - 0 in self.forward_indices - ), f"forward block idx are{self.forward_indices}" + assert 0 in self.forward_indices, ( + f"forward block idx are{self.forward_indices}" + ) self.backward_to_forward_index_map[0] = 0 for idx, block in enumerate(program.blocks): diff --git a/python/paddle/distributed/auto_parallel/static/dist_loader.py b/python/paddle/distributed/auto_parallel/static/dist_loader.py index ce42ac68e7e064..06fb5fff919483 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_loader.py +++ b/python/paddle/distributed/auto_parallel/static/dist_loader.py @@ -186,9 +186,9 @@ def data_generator(): continue batch_size = array.shape[0] - assert ( - batch_size % self.dp_world_sizes[i] == 0 - ), f"batch_size [{batch_size}] is not divisible by dp_world_size [{self.dp_world_sizes[i]}]" + assert batch_size % self.dp_world_sizes[i] == 0, ( + f"batch_size [{batch_size}] is not divisible by dp_world_size [{self.dp_world_sizes[i]}]" + ) partial_data.append( np.split(array, self.dp_world_sizes[i])[ self.dp_ranks[i] diff --git a/python/paddle/distributed/auto_parallel/static/dist_op.py b/python/paddle/distributed/auto_parallel/static/dist_op.py index 8733a95b25d47e..af473eadc09d9f 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_op.py +++ b/python/paddle/distributed/auto_parallel/static/dist_op.py @@ -217,9 +217,9 @@ def __call__(self, *args, **kwargs): tensor_to_dims_mapping = {} index = 0 if self._in_dims_mappings: - assert len(args) + len(kwargs) == len( - self._in_dims_mappings - ), f"The length of dims_mapping {len(self._in_dims_mappings)} does not matching the length output {len(args) + len(kwargs)}." 
+ assert len(args) + len(kwargs) == len(self._in_dims_mappings), ( + f"The length of dims_mapping {len(self._in_dims_mappings)} does not matching the length output {len(args) + len(kwargs)}." + ) for arg in args: if isinstance(arg, Variable) and self._in_dims_mappings: tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index] @@ -248,9 +248,9 @@ def __call__(self, *args, **kwargs): raise ValueError("Unrecognized output.") if self._out_dims_mappings: - assert len(new_output) == len( - self._out_dims_mappings - ), f"The length of dims_mapping {len(self._out_dims_mappings)} does not matching the length output {len(new_output)}." + assert len(new_output) == len(self._out_dims_mappings), ( + f"The length of dims_mapping {len(self._out_dims_mappings)} does not matching the length output {len(new_output)}." + ) for i, item in enumerate(new_output): if isinstance(item, Variable) and self._out_dims_mappings: tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i] @@ -282,7 +282,9 @@ def __call__(self, *args, **kwargs): ) assert verify_shard_spec( shard_spec, tensor_shape, self._process_mesh - ), f"For tensor {name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {self._process_mesh}." + ), ( + f"For tensor {name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {self._process_mesh}." + ) tensor_dist_attr.dims_mapping = dims_mapping tensor_dist_attr.mark_annotated("dims_mapping") for name in dist_op.serial_op.output_arg_names: @@ -306,7 +308,9 @@ def __call__(self, *args, **kwargs): ) assert verify_shard_spec( shard_spec, tensor_shape, self._process_mesh - ), f"For tensor {name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {self._process_mesh}." + ), ( + f"For tensor {name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {self._process_mesh}." + ) tensor_dist_attr.dims_mapping = dims_mapping tensor_dist_attr.mark_annotated("dims_mapping") dist_op.dist_attr.process_mesh = self._process_mesh diff --git a/python/paddle/distributed/auto_parallel/static/dist_tensor.py b/python/paddle/distributed/auto_parallel/static/dist_tensor.py index 7420ad1f014f9f..179dd08f858c4c 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_tensor.py +++ b/python/paddle/distributed/auto_parallel/static/dist_tensor.py @@ -148,9 +148,9 @@ def get_local_shard( local_sizes = DistributedTensor.get_local_sizes( global_sizes, dims_mapping, topology, processes, rank, shard_sizes ) - assert len(local_sizes) == len( - local_offsets - ), f"The length of local_sizes must be equal to local_offsets, but got {len(local_sizes)} and {len(local_offsets)}." + assert len(local_sizes) == len(local_offsets), ( + f"The length of local_sizes must be equal to local_offsets, but got {len(local_sizes)} and {len(local_offsets)}." + ) local_end_offsets = [ x[0] + x[1] for x in zip(local_offsets, local_sizes) @@ -359,9 +359,9 @@ def _copy_kwargs(serial_tensor): def local_tensor(self, rank=None): rank = paddle.distributed.get_rank() if rank is None else rank - assert ( - rank in self._local_tensor_map - ), f"The rank {rank} local tensor has not been created." + assert rank in self._local_tensor_map, ( + f"The rank {rank} local tensor has not been created." 
+ ) return self._local_tensor_map[rank] def __deepcopy__(self, memo): diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index 42b040c349ba5c..27b26c133c9dbb 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -284,9 +284,9 @@ def __init__( self._strategy.pipeline.enable and self._strategy.pipeline.schedule_mode == "1F1B" ): - assert ( - os.getenv("CUDA_MODULE_LOADING") != "LAZY" - ), "EXP_CUDA_MODULE_LOADING_LAZY not supported in 1F1B pipeline." + assert os.getenv("CUDA_MODULE_LOADING") != "LAZY", ( + "EXP_CUDA_MODULE_LOADING_LAZY not supported in 1F1B pipeline." + ) self.history = None @@ -471,28 +471,28 @@ def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels): raise ValueError("Only support static graph mode.") if inputs_spec: - assert isinstance( - inputs_spec, list - ), f"inputs should be list, but received {type(inputs_spec)}" - assert isinstance( - inputs, list - ), f"inputs should be list, but received {type(inputs)}" - assert len(inputs_spec) == len( - inputs - ), "the number of `inputs_spec` should be equal to `inputs`'s." + assert isinstance(inputs_spec, list), ( + f"inputs should be list, but received {type(inputs_spec)}" + ) + assert isinstance(inputs, list), ( + f"inputs should be list, but received {type(inputs)}" + ) + assert len(inputs_spec) == len(inputs), ( + "the number of `inputs_spec` should be equal to `inputs`'s." + ) for input_spec, input in zip(inputs_spec, inputs): if input_spec.shape != input.shape: input.desc.set_shape(input_spec.shape) if labels_spec: - assert isinstance( - labels_spec, list - ), f"labels should be list, but received {type(labels_spec)}" - assert isinstance( - labels, list - ), f"labels should be list, but received {type(labels)}" - assert len(labels_spec) == len( - labels - ), "the number of `labels_spec` should be equal to `labels`'s." + assert isinstance(labels_spec, list), ( + f"labels should be list, but received {type(labels_spec)}" + ) + assert isinstance(labels, list), ( + f"labels should be list, but received {type(labels)}" + ) + assert len(labels_spec) == len(labels), ( + "the number of `labels_spec` should be equal to `labels`'s." + ) for label_spec, label in zip(labels_spec, labels): if label_spec.shape != label.shape: label.desc.set_shape(label_spec.shape) @@ -562,18 +562,18 @@ def _prepare_feed(self, data, user_feeds, mode): else: raise ValueError(f"Unsupported data {data}") if user_feeds is not None: - assert isinstance( - user_feeds, dict - ), f"user_feeds must be a dict, but receive {type(user_feeds).__name__}" + assert isinstance(user_feeds, dict), ( + f"user_feeds must be a dict, but receive {type(user_feeds).__name__}" + ) for name, data in user_feeds.items(): feeds[name] = data return feeds def _prepare_fetch(self, user_fetches, mode): if user_fetches is not None: - assert isinstance( - user_fetches, list - ), f"user_fetches must be a list, but receive {type(user_fetches).__name__}" + assert isinstance(user_fetches, list), ( + f"user_fetches must be a list, but receive {type(user_fetches).__name__}" + ) fetch_names = [] fetch_indices = [] @@ -1149,9 +1149,9 @@ def _build(self, mode): if mode != "predict" and self._loss: assert isinstance( self._loss, paddle.nn.Layer - ) or callable( - self._loss - ), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function." 
+ ) or callable(self._loss), ( + "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function." + ) self._losses = auto_utils.to_list( self._loss(*(outputs + self._labels)) ) @@ -1164,9 +1164,9 @@ def _build(self, mode): ) ) elif mode == "train": - assert isinstance( - self._loss, Variable - ), "the type of `loss` of the Engine arguments should be Variable." + assert isinstance(self._loss, Variable), ( + "the type of `loss` of the Engine arguments should be Variable." + ) self._losses = auto_utils.to_list(self._loss) # TODO(zhiqiu): distributed_context is no longer used in pir_program @@ -1237,7 +1237,9 @@ def _build(self, mode): self._json_config, ) self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale - self._dist_contexts[mode].gradient_scale_using_allreduce_avg = ( + self._dist_contexts[ + mode + ].gradient_scale_using_allreduce_avg = ( self._strategy.gradient_scale_using_allreduce_avg ) self._fwd_main_progs[mode] = serial_main_prog.clone() @@ -1270,9 +1272,9 @@ def _optimization_tuning(self, mode, dataset, batch_size): if self._tuning.run_after_tuning: # update the strategy - self._dist_contexts[mode]._strategy = ( - self._optimization_tuner.get_best_config() - ) + self._dist_contexts[ + mode + ]._strategy = self._optimization_tuner.get_best_config() def _plan(self, mode): if self._planned_mode is None: @@ -1333,9 +1335,9 @@ def _init_dist_context(self, mode): for ib, block in enumerate(origin_main_prog.blocks): for iop, op in enumerate(block.ops): ref_op = ref_blocks[ib].ops[iop] - assert ( - op.type == ref_op.type - ), f"'{mode}' mode op '{op.type}' is different with '{ref_mode}' op '{ref_op.type}'. " + assert op.type == ref_op.type, ( + f"'{mode}' mode op '{op.type}' is different with '{ref_mode}' op '{ref_op.type}'. " + ) ref_op_dist_attr = ( ref_dist_context.get_op_dist_attr_for_program(ref_op) ) @@ -1412,9 +1414,9 @@ def _initialize(self, mode, init_parameters=True): for op in dist_main_prog.global_block().ops: if op.name() == "pd_op.data": var_name = op.str_attr("name") - assert ( - var_name not in name_map_value - ), f"The value {var_name} in {op} is already exist" + assert var_name not in name_map_value, ( + f"The value {var_name} in {op} is already exist" + ) name_map_value[var_name] = op.result(0) del_ops = [] block = startup_prog.global_block() @@ -2078,9 +2080,9 @@ def prepare( if self._orig_startup_prog is None: self._orig_startup_prog = static.default_startup_program() else: - assert ( - self._inputs_spec and self._labels_spec - ), "Please call the dataloader(...) before calling prepare(...)" + assert self._inputs_spec and self._labels_spec, ( + "Please call the dataloader(...) 
before calling prepare(...)" + ) self._inputs_spec, self._labels_spec = inputs_spec, labels_spec self._inputs, self._labels = inputs, labels @@ -2265,12 +2267,12 @@ def _validate_batch_size(self, batch_size): if batch_size is None: return None - assert ( - len(set(self._dp_world_sizes)) == 1 - ), f"DistributedBatchSampler only support one data parallel group, but got [{len(set(self._dp_world_sizes))}] different data parallel groups" - assert ( - batch_size % self._dp_world_sizes[0] == 0 - ), f"batch_size [{batch_size}] is not divisible by dp_world_size [{self._dp_world_sizes[0]}]" + assert len(set(self._dp_world_sizes)) == 1, ( + f"DistributedBatchSampler only support one data parallel group, but got [{len(set(self._dp_world_sizes))}] different data parallel groups" + ) + assert batch_size % self._dp_world_sizes[0] == 0, ( + f"batch_size [{batch_size}] is not divisible by dp_world_size [{self._dp_world_sizes[0]}]" + ) return batch_size // self._dp_world_sizes[0] def _validate_batch(self, batch): @@ -2311,9 +2313,9 @@ def _validate_spec(self, specs): ) if self._acc_steps > 1: shape = list(spec.shape) - assert ( - shape[0] % self._acc_steps == 0 - ), f"Requires batch_size[{spec.shape[0]}] to be divisible by k_steps[{self._acc_steps}]." + assert shape[0] % self._acc_steps == 0, ( + f"Requires batch_size[{spec.shape[0]}] to be divisible by k_steps[{self._acc_steps}]." + ) shape[0] //= self._acc_steps spec.shape = shape return specs or [] @@ -2341,9 +2343,9 @@ def _metrics_name(self): return metrics_name def _switch_mode(self, mode): - assert ( - mode in self._dist_contexts - ), f"{mode} model is not ready, please call `prepare()` first." + assert mode in self._dist_contexts, ( + f"{mode} model is not ready, please call `prepare()` first." + ) self.to_mode(mode) def to_mode(self, mode: _Mode) -> None: diff --git a/python/paddle/distributed/auto_parallel/static/helper.py b/python/paddle/distributed/auto_parallel/static/helper.py index e4d7592096a813..95d5e66a983f06 100644 --- a/python/paddle/distributed/auto_parallel/static/helper.py +++ b/python/paddle/distributed/auto_parallel/static/helper.py @@ -337,15 +337,15 @@ def apply_optimizer(self, optimizer): def _verify_optimizer(self, optimizer): assert optimizer is not None - assert hasattr( - optimizer, "minimize" - ), "Optimizer must have minimize() method." - assert ( - self.proxy_layer.mode == 'train' - ), f"Required mode == 'train', but received '{self.proxy_layer.mode}'" - assert ( - len(self.loss_vars) == 1 - ), f"Required len(loss_vars) == 1, but received len(loss_vars) = {len(self.loss_vars)}" + assert hasattr(optimizer, "minimize"), ( + "Optimizer must have minimize() method." + ) + assert self.proxy_layer.mode == 'train', ( + f"Required mode == 'train', but received '{self.proxy_layer.mode}'" + ) + assert len(self.loss_vars) == 1, ( + f"Required len(loss_vars) == 1, but received len(loss_vars) = {len(self.loss_vars)}" + ) def to(self, mode): """ @@ -353,9 +353,9 @@ def to(self, mode): """ assert mode in ['train', 'eval', 'predict'] func = getattr(self.proxy_layer, '_' + mode) - assert isinstance( - func, StaticFunction - ), "Please call build_program(mode) firstly." + assert isinstance(func, StaticFunction), ( + "Please call build_program(mode) firstly." 
+ ) self.proxy_layer.set_mode(mode) def static_func(self): @@ -419,9 +419,9 @@ def init_pir(self, main_program, place): value_name = dy_param_name_to_pir_param_name[param.name] value = value_name_to_value[value_name] # get param_var's dist_attr - assert ( - value.is_dist_dense_tensor_type() - ), f"param [{value.name}] is not dist tensor type" + assert value.is_dist_dense_tensor_type(), ( + f"param [{value.name}] is not dist tensor type" + ) dist_attr = { "dims_mapping": value.dist_attr().dims_mapping, "process_shape": value.dist_attr().process_mesh.shape, @@ -536,9 +536,9 @@ def init(self, main_program, place, dist_context): if param.dtype in [paddle.float16, paddle.bfloat16]: continue scope_tensor = global_scope().var(param.name).get_tensor() - assert ( - scope_var and scope_tensor._is_initialized() - ), f"Parameter: {param.name} is not put into global_scope or not initialized." + assert scope_var and scope_tensor._is_initialized(), ( + f"Parameter: {param.name} is not put into global_scope or not initialized." + ) param_used = param # For the params without dist_attr. # NOTE(lizhiyu): In principle, each param should have dist_attr. diff --git a/python/paddle/distributed/auto_parallel/static/mapper.py b/python/paddle/distributed/auto_parallel/static/mapper.py index 7e9e1db86428ca..ba233de544a18f 100644 --- a/python/paddle/distributed/auto_parallel/static/mapper.py +++ b/python/paddle/distributed/auto_parallel/static/mapper.py @@ -142,9 +142,9 @@ def analyze_comm_requirements_from_op(op, rank, g_process_group_map): comm_volume = get_comm_volume(op, rank, tgt_rank) if comm_volume is not None: comm_requirements_to_ranks[tgt_rank] = {} - comm_requirements_to_ranks[tgt_rank][ - "comm_volume" - ] = comm_volume + comm_requirements_to_ranks[tgt_rank]["comm_volume"] = ( + comm_volume + ) elif is_p2p_comm_op(op): tgt_rank = op.attr("peer") comm_volume = get_comm_volume(op, rank, tgt_rank) @@ -170,9 +170,9 @@ def analyze_requirements_for_program(src_info, rank): ) for tgt_rank, link_info in cur_comm_requirements_to_ranks.items(): if tgt_rank in comm_requirements_to_ranks: - comm_requirements_to_ranks[tgt_rank][ - "comm_volume" - ] += link_info["comm_volume"] + comm_requirements_to_ranks[tgt_rank]["comm_volume"] += ( + link_info["comm_volume"] + ) else: comm_requirements_to_ranks[tgt_rank] = {} comm_requirements_to_ranks[tgt_rank]["comm_volume"] = ( @@ -266,9 +266,9 @@ def select_unvisited_rank_node(rank_node_list): cur_rank_node["device"] = device_node["device"] cur_device_node = device_node break - assert ( - cur_device_node - ), "Cannot find a device to satisfy the requirement." + assert cur_device_node, ( + "Cannot find a device to satisfy the requirement." + ) nbr_rank_edges = [] for nbr_rank_node_id, nbr_rank_edge in process_graph.adjs[ diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py index 4a30d36528ca33..c209c091f142ee 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/common.py +++ b/python/paddle/distributed/auto_parallel/static/operators/common.py @@ -107,9 +107,9 @@ def impls(self): return self._impls def register_impl(self, dist_impl): - assert ( - self.type == dist_impl.type - ), "Op type of container must be same as that of the implementation." + assert self.type == dist_impl.type, ( + "Op type of container must be same as that of the implementation." 
+ ) impl_idx = len(self.impls) dist_impl.idx = impl_idx self._impls.append(dist_impl) @@ -353,9 +353,9 @@ def is_parameter_related(varname, block, dist_context=None): varname = varname[: varname.index(".cast_bf")] if ".quantized" in varname: varname = varname[: varname.index(".quantized")] - assert block._find_var_recursive( - varname - ), f"cannot find var {varname} in cur block" + assert block._find_var_recursive(varname), ( + f"cannot find var {varname} in cur block" + ) var = block._var_recursive(varname) # NOTE(hack method): to find the param which is resharded if dist_context and "@RESHARD" in varname: @@ -551,9 +551,9 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names): added_ops.append(scale_op) dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name) - assert ( - dims_mapping is not None - ), f"Unexpected: dims_mapping of output [{grad_var.name}] of op [{op_dist_attr.op_type}] is None" + assert dims_mapping is not None, ( + f"Unexpected: dims_mapping of output [{grad_var.name}] of op [{op_dist_attr.op_type}] is None" + ) # NOTE auxiliary op's dist attr should follow dist_op not dist_tensor for new_op in added_ops: new_op_attr = OperatorDistAttr() @@ -586,9 +586,9 @@ def get_partial_groups(dist_ctx, op, out_grad_names, rank): if partial_dims is None: partial_dims = var_dist_attr._partial_dims() else: - assert ( - partial_dims == var_dist_attr._partial_dims() - ), f"Partial dims of outputs {out_grad_names} of op [{op.type}] is not consistent" + assert partial_dims == var_dist_attr._partial_dims(), ( + f"Partial dims of outputs {out_grad_names} of op [{op.type}] is not consistent" + ) partial_dims = list(partial_dims) partial_dims.sort() diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/static/operators/dist_check_finite_and_unscale.py index 8198643130aa94..8165b2f8526f9d 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_check_finite_and_unscale.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_check_finite_and_unscale.py @@ -84,9 +84,9 @@ def backward(ctx, *args, **kwargs): backward_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert ( - dist_attr is not None - ), f"backward op [{backward_op}] don't have dist attribute !" + assert dist_attr is not None, ( + f"backward op [{backward_op}] don't have dist attribute !" 
+ ) assert rank_id in dist_attr.process_mesh.process_ids @@ -97,20 +97,20 @@ def backward(ctx, *args, **kwargs): 'FoundInfinite' ) - assert ( - len(kwargs['Scale']) == 1 - ), "check_finite_and_unscale input Scale take 1 variable but got {}".format( - kwargs['Scale'] + assert len(kwargs['Scale']) == 1, ( + "check_finite_and_unscale input Scale take 1 variable but got {}".format( + kwargs['Scale'] + ) ) - assert ( - len(kwargs['FoundInfinite']) == 1 - ), "check_finite_and_unscale input FoundInfinite take 1 variable but got {}".format( - kwargs['FoundInfinite'] + assert len(kwargs['FoundInfinite']) == 1, ( + "check_finite_and_unscale input FoundInfinite take 1 variable but got {}".format( + kwargs['FoundInfinite'] + ) ) - assert len(kwargs['X']) == len( - kwargs['Out'] - ), "check_finite_and_unscale got [{}] X and [{}] Out, which are supposed to be equal".format( - len(kwargs['X']), len(kwargs['Out']) + assert len(kwargs['X']) == len(kwargs['Out']), ( + "check_finite_and_unscale got [{}] X and [{}] Out, which are supposed to be equal".format( + len(kwargs['X']), len(kwargs['Out']) + ) ) filter_vars = [] diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_concat.py b/python/paddle/distributed/auto_parallel/static/operators/dist_concat.py index 1f4754ca22c5bb..6dd63d5c348f74 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_concat.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_concat.py @@ -32,9 +32,9 @@ def update_dims_mapping(dist_op): op_desc = dist_op.serial_op.desc axis_tensor = op_desc.input('AxisTensor') - assert ( - len(axis_tensor) == 0 - ), "Please use axis attr instead of AxisTensor" + assert len(axis_tensor) == 0, ( + "Please use axis attr instead of AxisTensor" + ) input_arg_names = op_desc.input_arg_names() output_arg_names = op_desc.output_arg_names() diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_cross_entropy.py b/python/paddle/distributed/auto_parallel/static/operators/dist_cross_entropy.py index 5e1660dbcdfcd2..9ec98e56d9ec96 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_cross_entropy.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_cross_entropy.py @@ -116,12 +116,12 @@ def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): axis = axis + logits_ndim if axis < 0 else axis if is_dim_shard(logits_dims_mapping[axis]): - assert ( - soft_label is False - ), "parallel_cross_entropy does not support soft_label now." - assert ( - axis == logits_ndim - 1 - ), "parallel_cross_entropy can only support shard on the last dim now." + assert soft_label is False, ( + "parallel_cross_entropy does not support soft_label now." + ) + assert axis == logits_ndim - 1, ( + "parallel_cross_entropy can only support shard on the last dim now." + ) op_dist_attr.impl_idx = 1 else: op_dist_attr.impl_idx = 0 @@ -162,9 +162,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"forward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"forward op [{src_op}] don't have dist attribute !" 
+ ) # check validation of inputs / outputs assert 'Logits' in kwargs, "input [Logits] is not given" @@ -172,12 +172,12 @@ def forward(ctx, *args, **kwargs): assert 'Loss' in kwargs, "output [Loss] is not given" assert 'Softmax' in kwargs, "output [Softmax] is not given" - assert ( - len(kwargs['Logits']) == 1 - ), "input [Logits] take 1 variable but got {}".format(kwargs['Logits']) - assert ( - len(kwargs['Label']) == 1 - ), "input [Label] take 1 variable but got {}".format(kwargs['Label']) + assert len(kwargs['Logits']) == 1, ( + "input [Logits] take 1 variable but got {}".format(kwargs['Logits']) + ) + assert len(kwargs['Label']) == 1, ( + "input [Label] take 1 variable but got {}".format(kwargs['Label']) + ) logits_var = main_block._var_recursive(kwargs['Logits'][0]) label_var = main_block._var_recursive(kwargs['Label'][0]) @@ -228,9 +228,9 @@ def backward(ctx, *args, **kwargs): rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert ( - op_dist_attr is not None - ), f"backward op [{backward_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{backward_op}] don't have dist attribute !" + ) # check validation of inputs / outputs assert 'Softmax' in kwargs, "input [Logits] is not given" @@ -238,21 +238,21 @@ def backward(ctx, *args, **kwargs): assert 'Loss@GRAD' in kwargs, "input [Loss@GRAD] is not given" assert 'Logits@GRAD' in kwargs, "output [Logits@GRAD] is not given" - assert ( - len(kwargs['Softmax']) == 1 - ), "input [Softmax] take 1 variable but got {}".format( - kwargs['Softmax'] - ) - assert ( - len(kwargs['Label']) == 1 - ), "input [Label] take 1 variable but got {}".format(kwargs['Label']) - assert ( - len(kwargs['Loss@GRAD']) == 1 - ), "input [Loss@GRAD] take 1 variable but got {}".format(kwargs['Out']) - assert ( - len(kwargs['Logits@GRAD']) == 1 - ), "output [Logits@GRAD] take 1 variable but got {}".format( - kwargs['Logits@GRAD'] + assert len(kwargs['Softmax']) == 1, ( + "input [Softmax] take 1 variable but got {}".format( + kwargs['Softmax'] + ) + ) + assert len(kwargs['Label']) == 1, ( + "input [Label] take 1 variable but got {}".format(kwargs['Label']) + ) + assert len(kwargs['Loss@GRAD']) == 1, ( + "input [Loss@GRAD] take 1 variable but got {}".format(kwargs['Out']) + ) + assert len(kwargs['Logits@GRAD']) == 1, ( + "output [Logits@GRAD] take 1 variable but got {}".format( + kwargs['Logits@GRAD'] + ) ) # replicate op in dist program copy_op_without_infer_shape(backward_op, main_block, ctx, kwargs) @@ -285,9 +285,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"forward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"forward op [{src_op}] don't have dist attribute !" 
+ ) # check validation of inputs / outputs assert 'Logits' in kwargs, "input [Logits] is not given" @@ -295,12 +295,12 @@ def forward(ctx, *args, **kwargs): assert 'Loss' in kwargs, "output [Loss] is not given" assert 'Softmax' in kwargs, "output [Softmax] is not given" - assert ( - len(kwargs['Logits']) == 1 - ), "input [Logits] take 1 variable but got {}".format(kwargs['Logits']) - assert ( - len(kwargs['Label']) == 1 - ), "input [Label] take 1 variable but got {}".format(kwargs['Label']) + assert len(kwargs['Logits']) == 1, ( + "input [Logits] take 1 variable but got {}".format(kwargs['Logits']) + ) + assert len(kwargs['Label']) == 1, ( + "input [Label] take 1 variable but got {}".format(kwargs['Label']) + ) logits_var = main_block._var_recursive(kwargs['Logits'][0]) label_var = main_block._var_recursive(kwargs['Label'][0]) @@ -395,9 +395,9 @@ def backward(ctx, *args, **kwargs): rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert ( - op_dist_attr is not None - ), f"backward op [{backward_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{backward_op}] don't have dist attribute !" + ) # check validation of inputs / outputs assert 'Softmax' in kwargs, "input [Softmax] is not given" @@ -405,23 +405,23 @@ def backward(ctx, *args, **kwargs): assert 'Loss@GRAD' in kwargs, "input [Loss@GRAD] is not given" assert 'Logits@GRAD' in kwargs, "output [Logits@GRAD] is not given" - assert ( - len(kwargs['Softmax']) == 1 - ), "input [Softmax] take 1 variable but got {}".format( - kwargs['Softmax'] - ) - assert ( - len(kwargs['Label']) == 1 - ), "input [Label] take 1 variable but got {}".format(kwargs['Label']) - assert ( - len(kwargs['Loss@GRAD']) == 1 - ), "input [Loss@GRAD] take 1 variable but got {}".format( - kwargs['Loss@GRAD'] - ) - assert ( - len(kwargs['Logits@GRAD']) == 1 - ), "output [Logits@GRAD] take 1 variable but got {}".format( - kwargs['Logits@GRAD'] + assert len(kwargs['Softmax']) == 1, ( + "input [Softmax] take 1 variable but got {}".format( + kwargs['Softmax'] + ) + ) + assert len(kwargs['Label']) == 1, ( + "input [Label] take 1 variable but got {}".format(kwargs['Label']) + ) + assert len(kwargs['Loss@GRAD']) == 1, ( + "input [Loss@GRAD] take 1 variable but got {}".format( + kwargs['Loss@GRAD'] + ) + ) + assert len(kwargs['Logits@GRAD']) == 1, ( + "output [Logits@GRAD] take 1 variable but got {}".format( + kwargs['Logits@GRAD'] + ) ) # got dist attribute info diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_default.py b/python/paddle/distributed/auto_parallel/static/operators/dist_default.py index 793b037b10389f..9e3f3200d47af0 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_default.py @@ -60,9 +60,9 @@ def prim_operator_data_parallel_functor(ctx, src_op): var_name = src_op.output_arg_names[0] if var_name in ctx.grads_params: - assert ( - var_name not in ctx.synced_gradient - ), f"in primitive mode, grad is already {var_name} synced" + assert var_name not in ctx.synced_gradient, ( + f"in primitive mode, grad is already {var_name} synced" + ) ctx.synced_gradient.add(var_name) sync_group = new_process_group(ctx.data_parallel_group) @@ -119,18 +119,18 @@ def update_dims_mapping(dist_op): num_inputs = len(input_arg_names) input_specs = [] for i in range(num_inputs): - assert not is_parameter_related( - input_arg_names[i], main_block - ), f"input {input_arg_names[i]} of 
op {dist_op.serial_op} is parameter, op should not use default rule." + assert not is_parameter_related(input_arg_names[i], main_block), ( + f"input {input_arg_names[i]} of op {dist_op.serial_op} is parameter, op should not use default rule." + ) input_specs.append( get_dist_tensor_spec(dist_op, input_arg_names[i]) ) num_outputs = len(output_arg_names) output_specs = [] for i in range(num_outputs): - assert not is_parameter_related( - output_arg_names[i], main_block - ), f"output {output_arg_names[i]} of op {dist_op.serial_op} is parameter, op should not use default rule." + assert not is_parameter_related(output_arg_names[i], main_block), ( + f"output {output_arg_names[i]} of op {dist_op.serial_op} is parameter, op should not use default rule." + ) output_specs.append( get_dist_tensor_spec(dist_op, output_arg_names[i], False) ) @@ -632,9 +632,9 @@ def backward(ctx, *args, **kwargs): main_block = dist_op_context.work_block backward_op = dist_op_context.cur_src_op dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert ( - dist_attr is not None - ), f"backward op [{backward_op}] don't have dist attribute !" + assert dist_attr is not None, ( + f"backward op [{backward_op}] don't have dist attribute !" + ) rank_id = dist_op_context.rank_id # check validation of inputs / outputs diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py index dc6affc766f647..374154ab2a6897 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py @@ -109,17 +109,17 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"forward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"forward op [{src_op}] don't have dist attribute !" 
+ ) if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute: # check validation of inputs / outputs assert 'X' in kwargs, "input [{}] is not given".format('X') - assert ( - len(kwargs['X']) == 1 - ), "input X should be only one tensor but got {}".format( - kwargs['X'] + assert len(kwargs['X']) == 1, ( + "input X should be only one tensor but got {}".format( + kwargs['X'] + ) ) assert 'Seed' in kwargs, "input [{}] is not given".format('Seed') diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py index 810e88a7e22bba..04b09b62f9200f 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py @@ -47,13 +47,13 @@ def __init__(self, op_type): def update_dims_mapping(dist_op): # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) op_desc = dist_op.serial_op.desc - assert ( - len(op_desc.input_arg_names()) >= 1 - ), f"elementwise op [{op_desc.type}] has [{len(op_desc.input_arg_names())}] inputs" + assert len(op_desc.input_arg_names()) >= 1, ( + f"elementwise op [{op_desc.type}] has [{len(op_desc.input_arg_names())}] inputs" + ) input_arg_names = op_desc.input_arg_names() - assert ( - len(op_desc.output_arg_names()) == 1 - ), f"elementwise op [{dist_op.serial_op}] has [{len(op_desc.output_arg_names())}] outputs" + assert len(op_desc.output_arg_names()) == 1, ( + f"elementwise op [{dist_op.serial_op}] has [{len(op_desc.output_arg_names())}] outputs" + ) output_arg_name = op_desc.output_arg_names()[0] num_inputs = len(input_arg_names) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py index 7bd7b222ed760a..438a384f0e0565 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py @@ -66,9 +66,9 @@ def __init__(self, op_type): def update_dims_mapping(dist_op): # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) op_desc = dist_op.serial_op.desc - assert ( - dist_op.serial_op.type == "lookup_table_v2" - ), f"{dist_op.serial_op.type} is not supported by dist embedding yet." + assert dist_op.serial_op.type == "lookup_table_v2", ( + f"{dist_op.serial_op.type} is not supported by dist embedding yet." + ) x_name = op_desc.input('Ids')[0] w_name = op_desc.input('W')[0] @@ -129,9 +129,9 @@ def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): - assert ( - len(Ids_var.shape) == 3 - ), f"input Ids to lookup_table should have 3 dimensions but got [{Ids_var.name}] with shape [{Ids_var.shape}]" + assert len(Ids_var.shape) == 3, ( + f"input Ids to lookup_table should have 3 dimensions but got [{Ids_var.name}] with shape [{Ids_var.shape}]" + ) if not Ids_var.stop_gradient: raise NotImplementedError( 'Requiring the gradient of Ids of lookup_table(v1) dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).' 
@@ -421,29 +421,29 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"forward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"forward op [{src_op}] don't have dist attribute !" + ) # check validation of inputs / outputs assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'Out' in kwargs, "output [{}] is not given".format('Out') - assert ( - len(kwargs['Ids']) == 1 - ), "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['Ids'] + assert len(kwargs['Ids']) == 1, ( + "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Ids'] + ) ) - assert ( - len(kwargs['W']) == 1 - ), "row_parallel_embedding input W take 1 variable but got {}".format( - kwargs['W'] + assert len(kwargs['W']) == 1, ( + "row_parallel_embedding input W take 1 variable but got {}".format( + kwargs['W'] + ) ) - assert ( - len(kwargs['Out']) == 1 - ), "row_parallel_embedding output Out take 1 variable but got {}".format( - kwargs['Out'] + assert len(kwargs['Out']) == 1, ( + "row_parallel_embedding output Out take 1 variable but got {}".format( + kwargs['Out'] + ) ) Ids_var = main_block._var_recursive(kwargs['Ids'][0]) @@ -458,9 +458,9 @@ def forward(ctx, *args, **kwargs): embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name )[0] - assert ( - embedding_row_dim_mapping >= 0 - ), f"row_parallel_embedding's row should be divided by a specific mesh axis, but got [{embedding_row_dim_mapping}]" + assert embedding_row_dim_mapping >= 0, ( + f"row_parallel_embedding's row should be divided by a specific mesh axis, but got [{embedding_row_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids @@ -576,9 +576,9 @@ def backward(ctx, *args, **kwargs): backward_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert ( - dist_attr is not None - ), f"backward op [{backward_op}] don't have dist attribute !" + assert dist_attr is not None, ( + f"backward op [{backward_op}] don't have dist attribute !" 
+ ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in dist_attr.process_mesh.process_ids: @@ -591,25 +591,25 @@ def backward(ctx, *args, **kwargs): assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out') assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD') - assert ( - len(kwargs['Ids']) == 1 - ), "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['Ids'] + assert len(kwargs['Ids']) == 1, ( + "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Ids'] + ) ) - assert ( - len(kwargs['W']) == 1 - ), "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['W'] + assert len(kwargs['W']) == 1, ( + "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['W'] + ) ) - assert ( - len(kwargs['Out@GRAD']) == 1 - ), "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['Out'] + assert len(kwargs['Out@GRAD']) == 1, ( + "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Out'] + ) ) - assert ( - len(kwargs['W@GRAD']) == 1 - ), "row_parallel_embedding output Ids take 1 variable but got {}".format( - kwargs['W@GRAD'] + assert len(kwargs['W@GRAD']) == 1, ( + "row_parallel_embedding output Ids take 1 variable but got {}".format( + kwargs['W@GRAD'] + ) ) Ids_var = main_block._var_recursive(kwargs['Ids'][0]) @@ -620,9 +620,9 @@ def backward(ctx, *args, **kwargs): embedding_row_dim_mapping = dist_attr.get_input_dims_mapping( Weight_var.name )[0] - assert ( - embedding_row_dim_mapping >= 0 - ), f"row_parallel_embedding's row should be divided by a specific mesh axis, but got [{embedding_row_dim_mapping}]" + assert embedding_row_dim_mapping >= 0, ( + f"row_parallel_embedding's row should be divided by a specific mesh axis, but got [{embedding_row_dim_mapping}]" + ) process_mesh_shape = dist_attr.process_mesh.shape process_mesh_group = dist_attr.process_mesh.process_ids diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py b/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py index 10d58ed678ae28..ac77b725dae737 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py @@ -60,9 +60,9 @@ def forward(ctx, *args, **kwargs): and not op_dist_attr.is_recompute and rank_id in op_dist_attr.process_mesh.process_ids ): - assert ( - op_dist_attr is not None - ), f"forward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"forward op [{src_op}] don't have dist attribute !" 
+ ) if ( len(kwargs.get('fixed_seed_offset', [])) > 0 diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_attention.py b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_attention.py index 6c7ba951980a76..87ed3a6773c433 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_attention.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_attention.py @@ -172,9 +172,9 @@ def forward(ctx, *args, **kwargs): qkv_w_col_dim_mapping = op_dist_attr.get_input_dims_mapping(qkv_w)[ head_axis ] - assert ( - qkv_w_col_dim_mapping >= 0 - ), f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{qkv_w_col_dim_mapping}]" + assert qkv_w_col_dim_mapping >= 0, ( + f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{qkv_w_col_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids @@ -209,9 +209,9 @@ def backward(ctx, *args, **kwargs): # infer logic comm presentation out_w = src_op.input('OutLinearW')[0] out_w_col_dim_mapping = op_dist_attr.get_input_dims_mapping(out_w)[-1] - assert ( - out_w_col_dim_mapping >= 0 - ), f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{out_w_col_dim_mapping}]" + assert out_w_col_dim_mapping >= 0, ( + f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{out_w_col_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py index 37d99553d85d18..57d735277415cc 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py @@ -72,9 +72,9 @@ def forward(ctx, *args, **kwargs): op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute: - assert ( - op_dist_attr is not None - ), f"forward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"forward op [{src_op}] don't have dist attribute !" 
+ ) assert 'seed_tensor' in kwargs, "input [{}] is not given".format( 'seed_tensor' diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_feedforward.py b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_feedforward.py index 1df1bf88490267..369045870299ae 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_feedforward.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_feedforward.py @@ -163,9 +163,9 @@ def forward(ctx, *args, **kwargs): linear1_weight_col_dim_mapping = op_dist_attr.get_input_dims_mapping( linear1_weight )[-1] - assert ( - linear1_weight_col_dim_mapping >= 0 - ), f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{linear1_weight_col_dim_mapping}]" + assert linear1_weight_col_dim_mapping >= 0, ( + f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{linear1_weight_col_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids @@ -202,9 +202,9 @@ def backward(ctx, *args, **kwargs): linear2_weight_col_dim_mapping = op_dist_attr.get_input_dims_mapping( linear2_weight )[-1] - assert ( - linear2_weight_col_dim_mapping >= 0 - ), f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{linear2_weight_col_dim_mapping}]" + assert linear2_weight_col_dim_mapping >= 0, ( + f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{linear2_weight_col_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py index 12408c282a8ceb..49c39bb759c2e0 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py @@ -315,9 +315,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): backward_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert ( - dist_attr is not None - ), f"backward op [{backward_op}] don't have dist attribute !" + assert dist_attr is not None, ( + f"backward op [{backward_op}] don't have dist attribute !" 
+ ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in dist_attr.process_mesh.process_ids: @@ -328,25 +328,25 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD') assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD') assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD') - assert ( - len(kwargs['Y']) == 1 - ), "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['Y'] + assert len(kwargs['Y']) == 1, ( + "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Y'] + ) ) - assert ( - len(kwargs['X']) == 1 - ), "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['X'] + assert len(kwargs['X']) == 1, ( + "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['X'] + ) ) - assert ( - len(kwargs['Out@GRAD']) == 1 - ), "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['Out'] + assert len(kwargs['Out@GRAD']) == 1, ( + "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Out'] + ) ) - assert ( - len(kwargs['Y@GRAD']) == 1 - ), "row_parallel_embedding output Ids take 1 variable but got {}".format( - kwargs['Y@GRAD'] + assert len(kwargs['Y@GRAD']) == 1, ( + "row_parallel_embedding output Ids take 1 variable but got {}".format( + kwargs['Y@GRAD'] + ) ) X_var = main_block._var_recursive(kwargs['X'][0]) @@ -354,9 +354,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0]) Y_grad = main_block._var_recursive(kwargs['Y@GRAD'][0]) - assert not is_parameter_related( - X_var.name, main_block - ), f"left operand(X) [{X_var.name}] of dist matmul should not be parameter" + assert not is_parameter_related(X_var.name, main_block), ( + f"left operand(X) [{X_var.name}] of dist matmul should not be parameter" + ) X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name) Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) @@ -781,9 +781,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" + ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.process_ids: @@ -817,9 +817,9 @@ def forward(ctx, *args, **kwargs): matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name )[-2] - assert ( - matmul_col_dim_mapping >= 0 - ), f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_col_dim_mapping}]" + assert matmul_col_dim_mapping >= 0, ( + f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_col_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids @@ -1036,9 +1036,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" 
+ assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" + ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.process_ids: @@ -1072,9 +1072,9 @@ def forward(ctx, *args, **kwargs): matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name )[-1] - assert ( - matmul_row_dim_mapping >= 0 - ), f"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_row_dim_mapping}]" + assert matmul_row_dim_mapping >= 0, ( + f"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_row_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids @@ -1474,9 +1474,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" + ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.process_ids: @@ -1510,9 +1510,9 @@ def forward(ctx, *args, **kwargs): matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name )[-2] - assert ( - matmul_col_dim_mapping >= 0 - ), f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_col_dim_mapping}]" + assert matmul_col_dim_mapping >= 0, ( + f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_col_dim_mapping}]" + ) # infer new var shape with op dist attr x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var) @@ -1723,9 +1723,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" + ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.process_ids: @@ -1759,9 +1759,9 @@ def forward(ctx, *args, **kwargs): matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name )[-1] - assert ( - matmul_row_dim_mapping >= 0 - ), f"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_row_dim_mapping}]" + assert matmul_row_dim_mapping >= 0, ( + f"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_row_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids @@ -2153,9 +2153,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" 
+ ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.process_ids: @@ -2183,9 +2183,9 @@ def forward(ctx, *args, **kwargs): matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name )[-1] - assert ( - matmul_col_dim_mapping >= 0 - ), f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_col_dim_mapping}]" + assert matmul_col_dim_mapping >= 0, ( + f"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_col_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids @@ -2396,9 +2396,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" + ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.process_ids: @@ -2426,9 +2426,9 @@ def forward(ctx, *args, **kwargs): matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name )[-2] - assert ( - matmul_row_dim_mapping >= 0 - ), f"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_row_dim_mapping}]" + assert matmul_row_dim_mapping >= 0, ( + f"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{matmul_row_dim_mapping}]" + ) process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_group = op_dist_attr.process_mesh.process_ids diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_reduce_sum_p.py b/python/paddle/distributed/auto_parallel/static/operators/dist_reduce_sum_p.py index 9faa879c61e2b4..ca9217c892d321 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_reduce_sum_p.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_reduce_sum_p.py @@ -44,13 +44,13 @@ def update_dims_mapping(dist_op): # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) op_desc = dist_op.serial_op.desc - assert ( - len(op_desc.input_arg_names()) == 1 - ), f"reduce_sum op [{op_desc.type}] has [{len(op_desc.input_arg_names())}] inputs" + assert len(op_desc.input_arg_names()) == 1, ( + f"reduce_sum op [{op_desc.type}] has [{len(op_desc.input_arg_names())}] inputs" + ) input_arg_name = op_desc.input_arg_names()[0] - assert ( - len(op_desc.output_arg_names()) == 1 - ), f"reduce_sum op [{op_desc.type}] has [{len(op_desc.output_arg_names())}] outputs" + assert len(op_desc.output_arg_names()) == 1, ( + f"reduce_sum op [{op_desc.type}] has [{len(op_desc.output_arg_names())}] outputs" + ) output_arg_name = op_desc.output_arg_names()[0] keep_dim = op_desc.attr('keep_dim') dims = op_desc.attr('dim') diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/static/operators/dist_reshape.py index 6a8a5caa808093..74d8f8fc96da37 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_reshape.py @@ -48,9 +48,9 @@ def __init__(self, op_type): def update_dims_mapping(dist_op): # step1: prepare inputs need for rule (order 
args as PHI definition and filter out unnecessary args) op_desc = dist_op.serial_op.desc - assert ( - dist_op.serial_op.type == "reshape2" - ), f"{dist_op.serial_op.type} is not supported by dist reshape yet." + assert dist_op.serial_op.type == "reshape2", ( + f"{dist_op.serial_op.type} is not supported by dist reshape yet." + ) x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] @@ -293,9 +293,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): @@ -549,9 +549,9 @@ def forward(ctx, *args, **kwargs): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): @@ -798,9 +798,9 @@ def forward(ctx, *args, **kwargs): main_block = dist_op_context.work_block src_op = dist_op_context.cur_src_op op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert ( - op_dist_attr is not None - ), f"backward op [{src_op}] don't have dist attribute !" + assert op_dist_attr is not None, ( + f"backward op [{src_op}] don't have dist attribute !" + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_split.py b/python/paddle/distributed/auto_parallel/static/operators/dist_split.py index 25e3a776fe4d42..830dcace18bc81 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_split.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_split.py @@ -39,26 +39,26 @@ def update_dims_mapping(dist_op): op_desc = dist_op.serial_op.desc x_name = op_desc.input('X')[0] - assert ( - len(op_desc.input('AxisTensor')) == 0 - ), "Attribute AxisTensor is not supported by dist split." - assert ( - len(op_desc.input('SectionsTensorList')) == 0 - ), "Attribute SectionsTensorList is not supported by dist split." + assert len(op_desc.input('AxisTensor')) == 0, ( + "Attribute AxisTensor is not supported by dist split." + ) + assert len(op_desc.input('SectionsTensorList')) == 0, ( + "Attribute SectionsTensorList is not supported by dist split." + ) output_arg_names = op_desc.output('Out') num = op_desc.attr('num') sections = op_desc.attr('sections') if num: - assert (sections is None) or ( - len(sections) == 0 - ), f"Both Attributes of num: {num} and sections: {sections} are specified." + assert (sections is None) or (len(sections) == 0), ( + f"Both Attributes of num: {num} and sections: {sections} are specified." + ) first_attr = num rule_type = "split_with_num" else: - assert ( - not num - ), f"Both Attributes of num: {num} and sections: {sections} are specified." + assert not num, ( + f"Both Attributes of num: {num} and sections: {sections} are specified." 
+ ) first_attr = sections rule_type = "split" axis = op_desc.attr('axis') diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_tile.py b/python/paddle/distributed/auto_parallel/static/operators/dist_tile.py index 45371797e16878..7eaf534e3f9038 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_tile.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_tile.py @@ -33,9 +33,9 @@ def __init__(self, op_type): def update_dims_mapping(dist_op): # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) op_desc = dist_op.serial_op.desc - assert ( - dist_op.serial_op.type == "tile" - ), f"{dist_op.serial_op.type} is not supported by dist transpose yet." + assert dist_op.serial_op.type == "tile", ( + f"{dist_op.serial_op.type} is not supported by dist transpose yet." + ) x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/static/operators/dist_transpose.py index 571415edf616ac..38f99d9deec80b 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_transpose.py @@ -47,9 +47,9 @@ def __init__(self, op_type): def update_dims_mapping(dist_op): # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) op_desc = dist_op.serial_op.desc - assert ( - dist_op.serial_op.type == "transpose2" - ), f"{dist_op.serial_op.type} is not supported by dist transpose yet." + assert dist_op.serial_op.type == "transpose2", ( + f"{dist_op.serial_op.type} is not supported by dist transpose yet." + ) x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/static/operators/dist_update_loss_scaling.py index 39d4fdfef974a7..9b2eefa50519f6 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_update_loss_scaling.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_update_loss_scaling.py @@ -72,9 +72,9 @@ def backward(ctx, *args, **kwargs): backward_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert ( - dist_attr is not None - ), f"backward op [{backward_op}] don't have dist attribute !" + assert dist_attr is not None, ( + f"backward op [{backward_op}] don't have dist attribute !" 
+ ) assert rank_id in dist_attr.process_mesh.process_ids @@ -103,46 +103,46 @@ def backward(ctx, *args, **kwargs): 'OutBadSteps' ) - assert ( - len(kwargs['FoundInfinite']) == 1 - ), "update_loss_scaling input FoundInfinite take 1 variable but got {}".format( - kwargs['FoundInfinite'] + assert len(kwargs['FoundInfinite']) == 1, ( + "update_loss_scaling input FoundInfinite take 1 variable but got {}".format( + kwargs['FoundInfinite'] + ) ) - assert ( - len(kwargs['PrevLossScaling']) == 1 - ), "update_loss_scaling input PrevLossScaling take 1 variable but got {}".format( - kwargs['PrevLossScaling'] + assert len(kwargs['PrevLossScaling']) == 1, ( + "update_loss_scaling input PrevLossScaling take 1 variable but got {}".format( + kwargs['PrevLossScaling'] + ) ) - assert ( - len(kwargs['InGoodSteps']) == 1 - ), "update_loss_scaling input InGoodSteps take 1 variable but got {}".format( - kwargs['InGoodSteps'] + assert len(kwargs['InGoodSteps']) == 1, ( + "update_loss_scaling input InGoodSteps take 1 variable but got {}".format( + kwargs['InGoodSteps'] + ) ) - assert ( - len(kwargs['InBadSteps']) == 1 - ), "update_loss_scaling input InBadSteps take 1 variable but got {}".format( - kwargs['InBadSteps'] + assert len(kwargs['InBadSteps']) == 1, ( + "update_loss_scaling input InBadSteps take 1 variable but got {}".format( + kwargs['InBadSteps'] + ) ) - assert ( - len(kwargs['LossScaling']) == 1 - ), "update_loss_scaling output LossScaling take 1 variable but got {}".format( - kwargs['LossScaling'] + assert len(kwargs['LossScaling']) == 1, ( + "update_loss_scaling output LossScaling take 1 variable but got {}".format( + kwargs['LossScaling'] + ) ) - assert ( - len(kwargs['OutGoodSteps']) == 1 - ), "update_loss_scaling output OutGoodSteps take 1 variable but got {}".format( - kwargs['OutGoodSteps'] + assert len(kwargs['OutGoodSteps']) == 1, ( + "update_loss_scaling output OutGoodSteps take 1 variable but got {}".format( + kwargs['OutGoodSteps'] + ) ) - assert ( - len(kwargs['OutBadSteps']) == 1 - ), "update_loss_scaling output OutBadSteps take 1 variable but got {}".format( - kwargs['OutBadSteps'] + assert len(kwargs['OutBadSteps']) == 1, ( + "update_loss_scaling output OutBadSteps take 1 variable but got {}".format( + kwargs['OutBadSteps'] + ) ) - assert len(kwargs['X']) == len( - kwargs['Out'] - ), "update_loss_scaling got [{}] X and [{}] Out, which are supposed to be equal".format( - len(kwargs['X']), len(kwargs['Out']) + assert len(kwargs['X']) == len(kwargs['Out']), ( + "update_loss_scaling got [{}] X and [{}] Out, which are supposed to be equal".format( + len(kwargs['X']), len(kwargs['Out']) + ) ) filter_vars = [] diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer.py b/python/paddle/distributed/auto_parallel/static/parallelizer.py index 907faac4931bc2..27177fae849cea 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer.py @@ -307,9 +307,9 @@ def parallelize( if self._enable_auto_mapping and self._need_rank_mapping: # Do the mapping pass before parallelization - assert ( - self._cluster is not None - ), "The cluster must not be none when using auto mapping." + assert self._cluster is not None, ( + "The cluster must not be none when using auto mapping." 
+ ) dist_programs = {} world_process_group = get_world_process_group() dist_context = None @@ -417,9 +417,9 @@ def parallelize( ] new_process = subprocess.Popen(new_cmd) new_process.wait() - assert ( - new_process.returncode == 0 - ), "Launch failed with rank mapping" + assert new_process.returncode == 0, ( + "Launch failed with rank mapping" + ) print("Successfully do the second launch for auto mapping!") sys.exit(0) else: diff --git a/python/paddle/distributed/auto_parallel/static/partitioner.py b/python/paddle/distributed/auto_parallel/static/partitioner.py index a6fae901e76c3c..ec25b69a256a40 100644 --- a/python/paddle/distributed/auto_parallel/static/partitioner.py +++ b/python/paddle/distributed/auto_parallel/static/partitioner.py @@ -142,12 +142,12 @@ def partition_startup_program( for op in serial_startup_program.global_block().ops: # TODO if var not belong to this rank, should be filtered output_vars = op.desc.output_arg_names() - assert ( - len(output_vars) == 1 - ), f"initializer should output only ONE variable, but got [{op.desc}]" - assert ( - temp_varname_map[output_vars[0]] in var2shape - ), f"try to initialize [{output_vars[0]}] which is not a persistable var" + assert len(output_vars) == 1, ( + f"initializer should output only ONE variable, but got [{op.desc}]" + ) + assert temp_varname_map[output_vars[0]] in var2shape, ( + f"try to initialize [{output_vars[0]}] which is not a persistable var" + ) new_op_desc = target_block.desc.append_op() new_op_desc.copy_from(op.desc) new_op_desc._rename_output( @@ -398,17 +398,17 @@ def _get_dist_shape(var, dist_attr): if mapping == []: return var_shape - assert len(var_shape) == len( - mapping - ), f"variable shape [{var_shape}] and dim_mapping [{mapping}] is NOT match !" + assert len(var_shape) == len(mapping), ( + f"variable shape [{var_shape}] and dim_mapping [{mapping}] is NOT match !" + ) new_shape = [] for idx in range(len(var_shape)): if var_shape[idx] == -1 or mapping[idx] == -1: new_shape.append(var_shape[idx]) else: - assert ( - var_shape[idx] % mesh[mapping[idx]] == 0 - ), f"un-event partition: var_shape[idx]=[{var_shape[idx]}], mesh[{mesh[mapping[idx]]}], {var.name}, {var_shape}, {mesh}, {mapping}" + assert var_shape[idx] % mesh[mapping[idx]] == 0, ( + f"un-event partition: var_shape[idx]=[{var_shape[idx]}], mesh[{mesh[mapping[idx]]}], {var.name}, {var_shape}, {mesh}, {mapping}" + ) new_shape.append(var_shape[idx] // mesh[mapping[idx]]) return new_shape diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py index c5517dd72040ba..5317f28aca1f39 100644 --- a/python/paddle/distributed/auto_parallel/static/pir_pass.py +++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py @@ -86,16 +86,16 @@ def reshard_single_value(program, op, operand, attr): def reshard_combine_value(program, op, operand, attr): prev_var = operand.source() - assert ( - prev_var.get_defining_op().name() == 'builtin.combine' - ), f"TensorList must be defined by builtin.combine op, but is {prev_var.get_defining_op().name()}." + assert prev_var.get_defining_op().name() == 'builtin.combine', ( + f"TensorList must be defined by builtin.combine op, but is {prev_var.get_defining_op().name()}." 
+ ) combine_op = prev_var.get_defining_op() array_attr = attr.as_array_attr() - assert len(combine_op.operands()) == len( - array_attr - ), "The number of combine op operands and the number of dist array_attr are not equal in op" + assert len(combine_op.operands()) == len(array_attr), ( + "The number of combine op operands and the number of dist array_attr are not equal in op" + ) reshard_vars = [] for inner_operand, inner_attr in zip(combine_op.operands(), array_attr): @@ -121,12 +121,12 @@ def apply_partition_pass(program, block=None): if op.name() in partition_skip_op_list: continue - assert len(op.operands()) == len( - op.dist_attr.operands() - ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}" - assert len(op.results()) == len( - op.dist_attr.results() - ), f"The number of results and the number of op_dist_attr's results are not equal in op: {op}" + assert len(op.operands()) == len(op.dist_attr.operands()), ( + f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}" + ) + assert len(op.results()) == len(op.dist_attr.results()), ( + f"The number of results and the number of op_dist_attr's results are not equal in op: {op}" + ) # deal with inplace value for out_idx, in_idx in paddle.core.pir.get_op_inplace_info(op).items(): @@ -142,9 +142,9 @@ def apply_partition_pass(program, block=None): ): continue - assert ( - not prev_var.is_combine() - ), f"The current partition pass not support inplace value of {op} is tensor list." + assert not prev_var.is_combine(), ( + f"The current partition pass not support inplace value of {op} is tensor list." + ) operand_attr = operand_attr.as_tensor_dist_attr() @@ -156,9 +156,9 @@ def apply_partition_pass(program, block=None): result = op.result(out_idx) result_attr = op.dist_attr.result(out_idx).as_tensor_dist_attr() - assert ( - operand_attr == result_attr - ), f"For inplace value, The operend dist attr should be equal to result dist attr , please check your infer_spmd func of {op}" + assert operand_attr == result_attr, ( + f"For inplace value, The operend dist attr should be equal to result dist attr , please check your infer_spmd func of {op}" + ) # reshard output paddle.pir.set_insertion_point_after(op) @@ -245,9 +245,13 @@ def decompose_reshard_pass(dist_program): # split the reshard compose p2p and collective into one p2p reshard and one collective reshard. 
# avoid global to sub mesh case if ( - input.dist_attr().process_mesh - != result.dist_attr().process_mesh - ) and input.dist_attr().process_mesh.ndim == result.dist_attr().process_mesh.ndim: + ( + input.dist_attr().process_mesh + != result.dist_attr().process_mesh + ) + and input.dist_attr().process_mesh.ndim + == result.dist_attr().process_mesh.ndim + ): if ( input.dist_attr().placements != result.dist_attr().placements @@ -321,7 +325,9 @@ def reshard_op_pass(dist_program, global_params_grads=None, block=None): assert ( not var.initialized() or var.dist_attr() == src_dist_attr - ), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}" + ), ( + f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}" + ) if src_dist_attr == dst_dist_attr: op.result(0).replace_all_uses_with(var) @@ -358,9 +364,9 @@ def reshard_op_pass(dist_program, global_params_grads=None, block=None): reshard_func = choose_reshard_func( src_dist_attr, dst_dist_attr ) - assert ( - reshard_func is not None - ), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}, {var.get_defining_op()}' + assert reshard_func is not None, ( + f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}, {var.get_defining_op()}' + ) with pir_op_role_guard(ref_op_role): out_value = reshard_func.reshard( @@ -407,9 +413,9 @@ def replace_moe_sub_mesh_tensors(op): for idx, val in enumerate(op.results()): val_mesh = val.dist_attr().process_mesh if cur_rank in val_mesh.process_ids: - assert ( - out_value is None - ), f'{op} has more than one results on rank {cur_rank}' + assert out_value is None, ( + f'{op} has more than one results on rank {cur_rank}' + ) out_value = val out_idx = idx @@ -522,9 +528,9 @@ def prune_op(block): ): op.erase() elif op.name() == "dist_op.reshard": - assert op.result( - 0 - ).use_empty(), f'There should not have useful dist.reshard op in remove_other_rank_op_pass. but find : {op}' + assert op.result(0).use_empty(), ( + f'There should not have useful dist.reshard op in remove_other_rank_op_pass. 
but find : {op}' + ) op.erase() prune_op(dist_program.global_block()) @@ -673,9 +679,9 @@ def replace_moe_global_mesh_tensor(op): val_mesh = val.dist_attr().process_mesh if cur_rank not in val_mesh.process_ids: continue - assert ( - in_value is None - ), f'{op} has more than one inputs on rank {cur_rank}' + assert in_value is None, ( + f'{op} has more than one inputs on rank {cur_rank}' + ) in_value = val in_idx = idx @@ -766,9 +772,9 @@ def eliminate_transpose_by_reshape(program): def complete_op_role(main_program, op_role_scope: list): - assert ( - len(op_role_scope) == 3 and len(op_role_scope[0]) == 2 - ), "op_role_scope should has the shape[3, 2]" + assert len(op_role_scope) == 3 and len(op_role_scope[0]) == 2, ( + "op_role_scope should has the shape[3, 2]" + ) forward_op_start = op_role_scope[0][0] forward_op_end = op_role_scope[0][1] @@ -810,7 +816,9 @@ def pipeline_pass(dense_main_program, dense_startup_program, pipeline_strategy): "FThenB", "1F1B", "VPP", - ], f"pipeline scheduler only support FThenB, 1F1B and VPP now, but receive {pass_name}" + ], ( + f"pipeline scheduler only support FThenB, 1F1B and VPP now, but receive {pass_name}" + ) pass_attr = {} pass_attr["num_micro_batches"] = pipeline_strategy.accumulate_steps @@ -1159,9 +1167,9 @@ def complete_chunk_id(dist_program, startup_program, pipeline_strategy): pp_stage_layer_nums = [0] * pp_degree for i in stage_ids: pp_stage_layer_nums[i] = pp_stage_layer_nums[i] + 1 - assert all( - value >= vpp_degree for value in pp_stage_layer_nums - ), "The number of layers on each pp_stage must not be less than the vpp_degree in the pp_stage to ensure that each chunk contains at least one layer." + assert all(value >= vpp_degree for value in pp_stage_layer_nums), ( + "The number of layers on each pp_stage must not be less than the vpp_degree in the pp_stage to ensure that each chunk contains at least one layer." + ) seg_layer_num = [0] * num_chunks for pp_stage in range( @@ -1855,9 +1863,9 @@ def fuse_attention_ffn_qkv_pass( with paddle.base.dygraph.guard(): dyparam_dtype = concated_dy_param_list[0].dtype for param in concated_dy_param_list: - assert ( - dyparam_dtype == param.dtype - ), "The dtypes of dy parameters to be fused are not the same." + assert dyparam_dtype == param.dtype, ( + "The dtypes of dy parameters to be fused are not the same." + ) dtensor = paddle.zeros( shape=name2pir_param_map[pir_param].shape, diff --git a/python/paddle/distributed/auto_parallel/static/planner.py b/python/paddle/distributed/auto_parallel/static/planner.py index eaa8db218dd3cf..c6a9148ebce4de 100755 --- a/python/paddle/distributed/auto_parallel/static/planner.py +++ b/python/paddle/distributed/auto_parallel/static/planner.py @@ -159,9 +159,9 @@ def _enum_dims_mapping( @staticmethod def enum_process_mesh_topology(processes): """Enumerate all process meshes with the given processes.""" - assert ( - processes >= 1 - ), "The processes must be number and greater than 0." + assert processes >= 1, ( + "The processes must be number and greater than 0." + ) # compute divisors divisors = [] for i in range(1, processes + 1): @@ -352,8 +352,7 @@ def enum_valid_dist_attr_for_program( auto.ProcessMesh( mesh=np.array( global_group[ - i - * per_process_mesh_group : (i + 1) + i * per_process_mesh_group : (i + 1) * per_process_mesh_group ] ) @@ -418,9 +417,9 @@ def enum_valid_dist_attr_for_program( program, op, op_process_mesh ) - assert ( - op_valid_dist_attrs is not None - ), f"Enumerate {op} valid distributed attribute failed." 
+ assert op_valid_dist_attrs is not None, ( + f"Enumerate {op} valid distributed attribute failed." + ) valid_dist_attr_dict[op.desc.id()] = [ op_valid_dist_attrs, pipeline_stage, @@ -645,9 +644,9 @@ def set_tensor_dist_attr(self, op, op_dist_attr, vars, dist_context): ) def change_process_mesh(self, op, changed_process_mesh, vars, dist_context): - dist_context.get_op_dist_attr_for_program(op).process_mesh = ( - changed_process_mesh - ) + dist_context.get_op_dist_attr_for_program( + op + ).process_mesh = changed_process_mesh for var_name in op.output_arg_names: dist_context.get_tensor_dist_attr_for_program( vars[var_name] @@ -748,9 +747,9 @@ def search_once( ) # change the selected op stage and output dist attr - new_valid_dist_attr_dict[selected_op.desc.id()][ - 1 - ] = changed_stage + new_valid_dist_attr_dict[selected_op.desc.id()][1] = ( + changed_stage + ) new_process_mesh = pipeline_process_meshes[changed_stage] selected_op_dist_attr.process_mesh = new_process_mesh for op_dist_attr in new_valid_dist_attr_dict[ @@ -778,9 +777,9 @@ def search_once( changed_stage ] if stage == changed_stage + 1: - new_valid_dist_attr_dict[ops[idx].desc.id()][ - 1 - ] = changed_stage + new_valid_dist_attr_dict[ops[idx].desc.id()][1] = ( + changed_stage + ) for op_dist_attr in valid_dist_attr_list: op_dist_attr.process_mesh = new_process_mesh new_dist_context.get_op_dist_attr_for_program( @@ -843,9 +842,9 @@ def search_once( ) # change the selected op stage and output tensor dist attr - new_valid_dist_attr_dict[selected_op.desc.id()][ - 1 - ] = changed_stage + new_valid_dist_attr_dict[selected_op.desc.id()][1] = ( + changed_stage + ) new_process_mesh = pipeline_process_meshes[changed_stage] selected_op_dist_attr.process_mesh = new_process_mesh for op_dist_attr in new_valid_dist_attr_dict[ @@ -872,9 +871,9 @@ def search_once( changed_stage ] if stage == changed_stage - 1: - new_valid_dist_attr_dict[ops[idx].desc.id()][ - 1 - ] = changed_stage + new_valid_dist_attr_dict[ops[idx].desc.id()][1] = ( + changed_stage + ) for op_dist_attr in valid_dist_attr_list: op_dist_attr.process_mesh = new_process_mesh diff --git a/python/paddle/distributed/auto_parallel/static/process_group.py b/python/paddle/distributed/auto_parallel/static/process_group.py index 085a0c813988d1..8e7e682ec367d1 100644 --- a/python/paddle/distributed/auto_parallel/static/process_group.py +++ b/python/paddle/distributed/auto_parallel/static/process_group.py @@ -89,9 +89,9 @@ def new_process_group( class ProcessGroup: def __init__(self, group_id, ranks, group_type=None): if group_id == 0 and get_process_group(0) is not None: - assert ( - group_id != 0 - ), "Process group id 0 is reserved for all ranks." + assert group_id != 0, ( + "Process group id 0 is reserved for all ranks." 
+ ) self._group_id = group_id self._ranks = ranks # Add the current ranks into group 0 @@ -121,9 +121,9 @@ def add_ranks(self, new_ranks): if set(new_ranks) <= set(self.ranks): return else: - assert ( - not self.is_instantiate() - ), "Cannot add new ranks after instantiating the process group" + assert not self.is_instantiate(), ( + "Cannot add new ranks after instantiating the process group" + ) self._ranks.extend(new_ranks) self._ranks = list(set(self.ranks)) diff --git a/python/paddle/distributed/auto_parallel/static/process_mesh_v2.py b/python/paddle/distributed/auto_parallel/static/process_mesh_v2.py index 7a58f12836b432..d055328ed7ad8d 100644 --- a/python/paddle/distributed/auto_parallel/static/process_mesh_v2.py +++ b/python/paddle/distributed/auto_parallel/static/process_mesh_v2.py @@ -56,21 +56,21 @@ def __init__(self, mesh, dim_names=None): self._shape = list(self._mesh.shape) self._process_ids = self._mesh.flatten().tolist() - assert all( - isinstance(p, int) for p in self._process_ids - ), "All elements of the mesh must be integer" - assert ( - min(self._process_ids) >= 0 - ), 'All elements of the mesh must be >= 0.' + assert all(isinstance(p, int) for p in self._process_ids), ( + "All elements of the mesh must be integer" + ) + assert min(self._process_ids) >= 0, ( + 'All elements of the mesh must be >= 0.' + ) unique_process_ids = set(self._process_ids) - assert len(unique_process_ids) == len( - self._process_ids - ), 'All elements of the mesh must be unique.' + assert len(unique_process_ids) == len(self._process_ids), ( + 'All elements of the mesh must be unique.' + ) if dim_names is not None: - assert len(dim_names) == len( - self._shape - ), "The length of dims_names must be same as the shape of the mesh." + assert len(dim_names) == len(self._shape), ( + "The length of dims_names must be same as the shape of the mesh." + ) self._dim_names = dim_names else: self._dim_names = ["d" + str(i) for i in range(len(self._shape))] diff --git a/python/paddle/distributed/auto_parallel/static/reshard.py b/python/paddle/distributed/auto_parallel/static/reshard.py index f29840fe6736e7..c9e4fd017635c7 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard.py +++ b/python/paddle/distributed/auto_parallel/static/reshard.py @@ -1096,9 +1096,9 @@ def __init__( "The type of auto_parallel_startup_prog should be Program or None, " f"but got {type(auto_parallel_startup_prog)}." ) - assert isinstance( - rank_id, int - ), f"The type of rank_id should be int, but got {type(rank_id)}." + assert isinstance(rank_id, int), ( + f"The type of rank_id should be int, but got {type(rank_id)}." + ) assert isinstance(dist_context, DistributedContext), ( "The type of dist_context should be DistributedContext, " f"but got {type(dist_context)}." @@ -1631,9 +1631,9 @@ def find_op_desc_seq( has_used = [False for x in has_used] to_send_process = process_list[0] has_used[0] = True - assert ( - to_send_process is not None - ), "Failed to find the send process." + assert to_send_process is not None, ( + "Failed to find the send process." + ) if to_send_process not in op_desc_seq.keys(): op_desc_seq[to_send_process] = [] @@ -1904,9 +1904,9 @@ def parse_op_desc( if op.desc.id == reshard_op.desc.id: idx = index break - assert ( - idx is not None - ), f"The op for reshard cannot be found in the rank {self.rank_id} program." + assert idx is not None, ( + f"The op for reshard cannot be found in the rank {self.rank_id} program." 
+ ) src_name = src_tensor.name @@ -2012,9 +2012,9 @@ def is_grad(name): for var_name in item[1] ] break - assert ( - tensor_list - ), "The result of parsing allgather op should not be None." + assert tensor_list, ( + "The result of parsing allgather op should not be None." + ) elif isinstance(op_desc, SendOpDesc): if src_name not in self.has_sent.keys(): @@ -2154,9 +2154,9 @@ def is_grad(name): ) tensor_list.append(reset_lod_out) idx += 2 - self.has_recv[src_name][ - op_desc.src - ] = reset_lod_out + self.has_recv[src_name][op_desc.src] = ( + reset_lod_out + ) set_lod = True break if set_lod: @@ -2461,9 +2461,9 @@ def get_op_input_attrs(self, op, var_name): else: op_input_attrs = self._get_common_op_input_attrs(op, var_name) - assert ( - op_input_attrs - ), f"The input '{op.name}' of op '{var_name}' has no distributed attributes in subblock" + assert op_input_attrs, ( + f"The input '{op.name}' of op '{var_name}' has no distributed attributes in subblock" + ) return op_input_attrs @@ -2874,11 +2874,7 @@ def _is_special_op(op): -1 ) != len( dist_tensor.dist_attr.dims_mapping - ) or output_attr[ - 1 - ].count( - -1 - ) != len( + ) or output_attr[1].count(-1) != len( output_attr[1] ): raise ValueError( diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/nd_mesh_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/nd_mesh_reshard_func.py index b7950f7c82f146..60b818638d03af 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/nd_mesh_reshard_func.py +++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/nd_mesh_reshard_func.py @@ -357,9 +357,9 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type): ) nd_mesh_func = NdMeshReshardFunction() - assert nd_mesh_func.is_suitable( - tmp_dist_attr, dst_dist_attr - ), f"Invoke the p to r reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}" + assert nd_mesh_func.is_suitable(tmp_dist_attr, dst_dist_attr), ( + f"Invoke the p to r reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}" + ) return nd_mesh_func.reshard( tmp_dist_attr, dst_dist_attr, src_value, dst_type ) diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py index a5f7d0089e2842..8f4194d98f105b 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py +++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py @@ -105,9 +105,9 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type): ) p_to_r_func = PToRReshardFunction() - assert p_to_r_func.is_suitable( - tmp_dist_attr, dst_dist_attr - ), f"Invoke the p to r reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}" + assert p_to_r_func.is_suitable(tmp_dist_attr, dst_dist_attr), ( + f"Invoke the p to r reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}" + ) return p_to_r_func.reshard( tmp_dist_attr, dst_dist_attr, src_value, dst_type ) diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_s_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_s_reshard_func.py index e2a3bb6dd61c7d..ed50a016f0b4ea 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_s_reshard_func.py +++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_s_reshard_func.py @@ -47,9 +47,9 @@ def is_suitable(self, 
src_dist_attr, dst_dist_attr): def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type): src_mesh = src_dist_attr.process_mesh src_reduce_type = src_dist_attr.partial_status[0] - assert ( - src_reduce_type == paddle.base.core.ReduceType.kRedSum - ), f"The p to s reshard func only support sum op, but received {src_reduce_type}" + assert src_reduce_type == paddle.base.core.ReduceType.kRedSum, ( + f"The p to s reshard func only support sum op, but received {src_reduce_type}" + ) chunk_id = -1 if src_value.get_defining_op().dist_attr: diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/r_to_s_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/r_to_s_reshard_func.py index 44e78cb5e84a12..2bca9cac7be832 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/r_to_s_reshard_func.py +++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/r_to_s_reshard_func.py @@ -133,9 +133,9 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type): curr_global_rank = paddle.distributed.get_rank() if curr_global_rank in dst_dist_attr.process_mesh.process_ids: r_to_s_func = RToSReshardFunction() - assert r_to_s_func.is_suitable( - tmp_dist_attr, dst_dist_attr - ), f"Invoke the r to s reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}" + assert r_to_s_func.is_suitable(tmp_dist_attr, dst_dist_attr), ( + f"Invoke the r to s reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}" + ) return r_to_s_func.reshard( tmp_dist_attr, dst_dist_attr, out_value, dst_type ) diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py index a25d735d90bb7a..73b42f5199ba72 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py +++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py @@ -355,9 +355,9 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type): ) s_to_r_func = SToRReshardFunction() - assert s_to_r_func.is_suitable( - tmp_dist_attr, dst_dist_attr - ), f"Invoke the p to r reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}" + assert s_to_r_func.is_suitable(tmp_dist_attr, dst_dist_attr), ( + f"Invoke the p to r reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}" + ) return s_to_r_func.reshard( tmp_dist_attr, dst_dist_attr, out_value, dst_type ) diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py index 71a38e63d14ef5..47d7a2b5dda6b7 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py +++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py @@ -123,9 +123,9 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type): if var.dist_attr().process_mesh == dst_mesh: chunk_id = find_var_used_op_chunk_id(var) - assert ( - -1 not in dst_type.shape - ), "dynamic shape is not supported by pir-auto parallel yet." + assert -1 not in dst_type.shape, ( + "dynamic shape is not supported by pir-auto parallel yet." 
+ ) comm_group = new_process_group([src, dst], group_type="p2p") recv_value = paddle._C_ops.recv_v2( diff --git a/python/paddle/distributed/auto_parallel/static/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/static/tuner/algorithms.py index fcaa325c9ab994..8df82e5c0e3cc9 100644 --- a/python/paddle/distributed/auto_parallel/static/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/static/tuner/algorithms.py @@ -119,9 +119,9 @@ def _init_spaces(self): stage_range = self._config.sharding.get("tuning_range", None) if stage_range: - assert set(stage_range).issubset( - {0, 1, 2, 3} - ), f"Sharding Stage should belong into range within 0 - 3 but got {stage_range}." + assert set(stage_range).issubset({0, 1, 2, 3}), ( + f"Sharding Stage should belong into range within 0 - 3 but got {stage_range}." + ) stage_range.sort(reverse=True) else: stage_range = list(range(self._max_stage + 1)).sort(reverse=True) diff --git a/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py index 24a60d1b2cc786..7c38e134a7cd48 100644 --- a/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py @@ -85,9 +85,9 @@ def parse_process_groups(): def get_metric(results): - assert isinstance( - results, dict - ), f"results should be type of dictionary, but got {type(results)}." + assert isinstance(results, dict), ( + f"results should be type of dictionary, but got {type(results)}." + ) if 'Throughput' in results and isinstance(results['Throughput'], float): return float(results['Throughput']) else: diff --git a/python/paddle/distributed/auto_parallel/static/tuner/rule_based_tuner.py b/python/paddle/distributed/auto_parallel/static/tuner/rule_based_tuner.py index 077d243fa2a0e8..53107957a8950c 100644 --- a/python/paddle/distributed/auto_parallel/static/tuner/rule_based_tuner.py +++ b/python/paddle/distributed/auto_parallel/static/tuner/rule_based_tuner.py @@ -511,9 +511,9 @@ def convert_to_graph(block): else: var_node.attrs["type"] = "var" graph.attrs["var_to_id"][var_name] = var_node.id - graph.attrs["id_to_var_desc_id"][ - var_node.id - ] = var.desc.original_id() + graph.attrs["id_to_var_desc_id"][var_node.id] = ( + var.desc.original_id() + ) graph.attrs["id_to_var_name"][var_node.id] = var_name else: var_node_id = graph.attrs["var_to_id"][var_name] @@ -539,12 +539,12 @@ def convert_to_graph(block): else: var_node.attrs["type"] = "var" graph.attrs["var_to_id"][var_name] = var_node.id - graph.attrs["id_to_var_desc_id"][ - var_node.id - ] = var.desc.original_id() - graph.attrs["id_to_var_name"][ - var_node.id - ] = var_name + graph.attrs["id_to_var_desc_id"][var_node.id] = ( + var.desc.original_id() + ) + graph.attrs["id_to_var_name"][var_node.id] = ( + var_name + ) else: var_node_id = graph.attrs["var_to_id"][var_name] var_node = graph._nodes[var_node_id] @@ -1176,9 +1176,7 @@ def gen_full_program(self): self.op_original_id_to_op[op.desc.original_id()] = op self.op_original_id_to_idx[op.desc.original_id()] = idx - grad_op_id_to_op_id = ( - self.full_main_program_dist_context.dist_op_context.grad_op_id_to_op_id - ) + grad_op_id_to_op_id = self.full_main_program_dist_context.dist_op_context.grad_op_id_to_op_id for grad_op_original_id in grad_op_id_to_op_id: op_id = grad_op_id_to_op_id[grad_op_original_id] @@ -1408,9 +1406,9 @@ def _complete_sub_fwd_program(self, idx, sub_fwd_program, process_mesh): if 
parallelism not in self.sub_programs_dist_context[idx]: self.sub_programs_dist_context[idx][parallelism] = {} key = self.convert_process_mesh_to_key(process_mesh) - self.sub_programs_dist_context[idx][parallelism][ - key - ] = dist_context + self.sub_programs_dist_context[idx][parallelism][key] = ( + dist_context + ) else: self._logger.info( f"No pattern has be matched under {parallelism} parallelism when sub program is {sub_fwd_program}." @@ -1534,9 +1532,9 @@ def _is_grad_var_name(name): ref_dims_mapping = ( fwd_op_dist_attr.get_output_dims_mapping(input_name) ) - assert ( - ref_dims_mapping is not None - ), f"[{input_name}] 's dims mapping is NONE" + assert ref_dims_mapping is not None, ( + f"[{input_name}] 's dims mapping is NONE" + ) grad_op_dist_attr.set_input_dims_mapping( input_name, ref_dims_mapping ) @@ -1574,9 +1572,9 @@ def _is_grad_var_name(name): map(_is_grad_var_name, grad_op_next_op.input_arg_names) ) output_name = grad_op_next_op.output_arg_names[0] - assert ( - output_name in grad_var_to_var - ), f"sum op's output '{output_name}' has no corresponding var" + assert output_name in grad_var_to_var, ( + f"sum op's output '{output_name}' has no corresponding var" + ) ref_fwd_var_name = grad_var_to_var[output_name] ref_fwd_var = vars[ref_fwd_var_name] ref_fwd_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program( @@ -1756,12 +1754,12 @@ def _complete_sub_update_program(self, sub_program_dist_context): continue if "Grad" in op.input_names and "Param" in ops[idx].input_names: - assert ( - len(op.input("Param")) == 1 - ), "Only support one-to-one now." - assert ( - len(op.input("Grad")) == 1 - ), "Only support one-to-one now." + assert len(op.input("Param")) == 1, ( + "Only support one-to-one now." + ) + assert len(op.input("Grad")) == 1, ( + "Only support one-to-one now." 
+ ) param = vars[op.input("Param")[0]] grad_var = vars[op.input("Grad")[0]] if param.desc.original_id() in dist_tensors: @@ -1968,20 +1966,18 @@ def _local_stage_pass(self, start, end, process_mesh): 1 ] = self.stage_best_cost_of_pm[start][end][key][ "dist_context" - ][ - 0 - ] - self.stage_best_cost_of_pm[start][end][key]["cost"][ - 0 - ] = cost + ][0] + self.stage_best_cost_of_pm[start][end][key]["cost"][0] = ( + cost + ) self.stage_best_cost_of_pm[start][end][key]["dist_context"][ 0 ] = dist_context elif index == 1: - self.stage_best_cost_of_pm[start][end][key]["cost"][ - 1 - ] = cost + self.stage_best_cost_of_pm[start][end][key]["cost"][1] = ( + cost + ) self.stage_best_cost_of_pm[start][end][key]["dist_context"][ 1 ] = dist_context @@ -2045,9 +2041,9 @@ def local_stage_pass(self, start, end, device_mesh): best_cost = self.stage_best_cost_of_pm[start][end][key][ "best_cost" ] - self.stage_best_cost_of_dm[start][end][dm_key][ - "cost" - ] = best_cost + self.stage_best_cost_of_dm[start][end][dm_key]["cost"] = ( + best_cost + ) self.stage_best_cost_of_dm[start][end][dm_key][ "dist_context" ] = self.stage_best_cost_of_pm[start][end][key][ @@ -2103,12 +2099,12 @@ def get_best_process_mesh(self, start, end, device_mesh): ) if cost < best_cost: best_cost = cost - self.stage_best_cost_of_dm[start][end][dm_key][ - "cost" - ] = cost - self.stage_best_cost_of_dm[start][end][dm_key][ - "memory" - ] = local_stage_memory + self.stage_best_cost_of_dm[start][end][dm_key]["cost"] = ( + cost + ) + self.stage_best_cost_of_dm[start][end][dm_key]["memory"] = ( + local_stage_memory + ) self.stage_best_cost_of_dm[start][end][dm_key][ "dist_context" ] = dist_context @@ -2156,12 +2152,12 @@ def local_stage_pass_new(self, start, end, device_mesh): if (start <= 1 and end <= 2) or end == len(self.layers) - 1: cost, local_stage_memory = self._get_sub_program_cost(dist_context) self.stage_best_cost_of_dm[start][end][dm_key]["cost"] = cost - self.stage_best_cost_of_dm[start][end][dm_key][ - "memory" - ] = local_stage_memory - self.stage_best_cost_of_dm[start][end][dm_key][ - "dist_context" - ] = dist_context + self.stage_best_cost_of_dm[start][end][dm_key]["memory"] = ( + local_stage_memory + ) + self.stage_best_cost_of_dm[start][end][dm_key]["dist_context"] = ( + dist_context + ) # some cache is used to speed up because the layer 1~end is same, for example: # stage_best_cost_of_dm[0][2] = stage_best_cost_of_dm[0][1] + stage_best_cost_of_dm[0][1] - stage_best_cost_of_pm[0][0] @@ -2180,9 +2176,9 @@ def local_stage_pass_new(self, start, end, device_mesh): end - 1 ][dm_key]["memory"] self.stage_best_cost_of_dm[start][end][dm_key]["cost"] = cost - self.stage_best_cost_of_dm[start][end][dm_key][ - "memory" - ] = local_stage_memory + self.stage_best_cost_of_dm[start][end][dm_key]["memory"] = ( + local_stage_memory + ) self.stage_best_cost_of_dm[start][end][dm_key][ "dist_context" ] = dist_context @@ -2207,9 +2203,9 @@ def local_stage_pass_new(self, start, end, device_mesh): local_stage_memory_former_1 - local_stage_memory_former_2 ) self.stage_best_cost_of_dm[start][end][dm_key]["cost"] = cost - self.stage_best_cost_of_dm[start][end][dm_key][ - "memory" - ] = local_stage_memory + self.stage_best_cost_of_dm[start][end][dm_key]["memory"] = ( + local_stage_memory + ) self.stage_best_cost_of_dm[start][end][dm_key][ "dist_context" ] = dist_context @@ -2672,9 +2668,9 @@ def save_strategy(self, best_dist_context, path): for key in best_dist_context._dist_tensors_for_program: if key in 
self._dist_context._dist_tensors_for_program: dist_tensor = best_dist_context._dist_tensors_for_program[key] - dist_attrs["tensor"][ - key - ] = dist_tensor.dist_attr.serialize_to_string() + dist_attrs["tensor"][key] = ( + dist_tensor.dist_attr.serialize_to_string() + ) assert dist_attrs["tensor"], "Tensor dist attrs must not be None." for key in best_dist_context._dist_ops_for_program: @@ -2756,9 +2752,9 @@ def tune(self): else: best_dist_context = self.tune_o1() - assert ( - best_dist_context is not None - ), "can not find a parallel strategy to run, please use passes such as recompute, amp or sharding." + assert best_dist_context is not None, ( + "can not find a parallel strategy to run, please use passes such as recompute, amp or sharding." + ) for key in best_dist_context._dist_tensors_for_program: if key in self._dist_context._dist_tensors_for_program: diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index 9cb8734720d777..52d8f61fad57cd 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -183,12 +183,12 @@ def compute_compatible_dims_mapping(dims_mapping_list): return None length = len(dims_mapping_list[0]) for dims_mapping in dims_mapping_list: - assert ( - dims_mapping is not None - ), "Dims mapping must not be None for compatible computation" - assert ( - len(dims_mapping) == length - ), "The length of dims_mapping in list must be same for compatible computation." + assert dims_mapping is not None, ( + "Dims mapping must not be None for compatible computation" + ) + assert len(dims_mapping) == length, ( + "The length of dims_mapping in list must be same for compatible computation." + ) compatible_result = [] for dim_mappings in zip(*dims_mapping_list): compatible_dim_mapping = compute_compatible_dim_mapping( @@ -252,9 +252,9 @@ def check_distributed_attr_for_program(program, dist_context=None): if dist_context is None: dist_context = get_default_distributed_context() - assert ( - dist_context.is_initialized_for_program() - ), "Distributed attributes must be initialized before check." + assert dist_context.is_initialized_for_program(), ( + "Distributed attributes must be initialized before check." + ) for block in program.blocks: for tensor in block.vars.values(): dist_tensor = dist_context.get_dist_tensor_for_graph(tensor) @@ -309,9 +309,9 @@ def _get_comm_group(processes, shape, axis, rank): # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous # tricks to support processes mesh when it is not start with 0 or continuous - assert ( - rank in processes - ), f"rank [{rank}] is NOT in processes group {processes}" + assert rank in processes, ( + f"rank [{rank}] is NOT in processes group {processes}" + ) rank_relative = processes.index(rank) coordinate = _linear_idx2coordinate(shape, rank_relative) coordinates_in_group = [coordinate[:] for i in range(shape[axis])] @@ -377,16 +377,16 @@ def _coordinate2linear_idx(mesh_shape, coordinate): # e.g. 
process_mesh = { process_groups = [7, 8, 9,10, 12, 13, 14, 15], mesh = [2, 4]} # if you want a more general mapping, you should use cartesian product - assert len(mesh_shape) == len( - coordinate - ), f"coordinate should have the same size as mesh shape, but got shape: {mesh_shape}, coordinate: {coordinate}" + assert len(mesh_shape) == len(coordinate), ( + f"coordinate should have the same size as mesh shape, but got shape: {mesh_shape}, coordinate: {coordinate}" + ) for i in range(len(mesh_shape)): - assert ( - coordinate[i] >= 0 - ), f"index in dimension [{i}] is least than zero. coordinate: {coordinate}" - assert ( - coordinate[i] < mesh_shape[i] - ), f"index beyond extent in dimension [{i}]. shape: {mesh_shape}, coordinate: {coordinate}" + assert coordinate[i] >= 0, ( + f"index in dimension [{i}] is least than zero. coordinate: {coordinate}" + ) + assert coordinate[i] < mesh_shape[i], ( + f"index beyond extent in dimension [{i}]. shape: {mesh_shape}, coordinate: {coordinate}" + ) base = mesh_shape[-1] linear_idx = coordinate[-1] @@ -419,9 +419,9 @@ def _linear_idx2coordinate(mesh_shape, linear_idx): """ assert linear_idx >= 0, f"linear index [{linear_idx}] is least than zero" - assert linear_idx < np.prod( - mesh_shape - ), f"linear index beyond the extent of mesh shape. shape: {mesh_shape}, linear index: {linear_idx}" + assert linear_idx < np.prod(mesh_shape), ( + f"linear index beyond the extent of mesh shape. shape: {mesh_shape}, linear index: {linear_idx}" + ) base = 1 coordinate = [-1] * len(mesh_shape) @@ -462,9 +462,9 @@ def _get_unshard_dist_shape(var, dist_attr): var_shape = var.shape mapping = dist_attr.dims_mapping mesh = dist_attr.process_mesh.shape - assert len(var_shape) == len( - mapping - ), f"variable shape [{var_shape}] and dim_mapping [{mapping}] is NOT match !" + assert len(var_shape) == len(mapping), ( + f"variable shape [{var_shape}] and dim_mapping [{mapping}] is NOT match !" + ) new_shape = [] for idx in range(len(var_shape)): if var_shape[idx] == -1 or mapping[idx] == -1: @@ -689,9 +689,9 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path): ... ] >>> param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path) """ - assert _check_valid_path( - checkpoint_path - ), "'checkpoint_path' cannot be None." + assert _check_valid_path(checkpoint_path), ( + "'checkpoint_path' cannot be None." + ) assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None." state_dict_info = _load_distributed_state_dict(checkpoint_path) @@ -739,9 +739,9 @@ def load_checkpoint_into_program( from .dist_context import get_default_distributed_context assert isinstance(program, paddle.static.Program) - assert _check_valid_path( - checkpoint_path - ), "'checkpoint_path' cannot be None." + assert _check_valid_path(checkpoint_path), ( + "'checkpoint_path' cannot be None." + ) assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None." if dist_context is None: dist_context = get_default_distributed_context() @@ -794,9 +794,9 @@ def _load_distributed_attribute(dist_attr_path): for dist_attr_file in dist_attr_path: dist_attr = paddle.load(dist_attr_file) pre_world_size = dist_attr["world_size"] - assert pre_world_size == len( - dist_attr_path - ), "The number of 'dist_attr_path' must be equal to the last training world size." + assert pre_world_size == len(dist_attr_path), ( + "The number of 'dist_attr_path' must be equal to the last training world size." 
+ ) for name, attr in dist_attr["model"].items(): if name not in total_dist_attr: total_dist_attr[name] = attr @@ -825,9 +825,9 @@ def _load_distributed_state_dict(checkpoint_path): for idx, ckpt_file in enumerate(checkpoint_path): state_dict_info = paddle.load(ckpt_file, return_numpy=True) pre_world_size = state_dict_info["world_size"] - assert pre_world_size == len( - checkpoint_path - ), "The number of 'checkpoint_path' must be equal to the last training world size." + assert pre_world_size == len(checkpoint_path), ( + "The number of 'checkpoint_path' must be equal to the last training world size." + ) if idx == 0: addition_info = state_dict_info["addition_info"] for name, value in state_dict_info["model"].items(): @@ -909,9 +909,9 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): dist_param_dict(dict): parameters' value of current rank. """ assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None." - assert isinstance( - dist_param_dict, dict - ), f"The type of 'dist_param_dict' should be 'dict', but got {type(dist_param_dict)}." + assert isinstance(dist_param_dict, dict), ( + f"The type of 'dist_param_dict' should be 'dict', but got {type(dist_param_dict)}." + ) for name, value in dist_param_dict.items(): if not isinstance(name, str): raise TypeError( @@ -1010,9 +1010,9 @@ def _merge_parameter_with_dist_attr(param_list, dist_attr): complete_shape, ) - assert ( - len(partition_param_list) == 1 or not partition_param_list - ), "Fail to merge parameter" + assert len(partition_param_list) == 1 or not partition_param_list, ( + "Fail to merge parameter" + ) complete_param = partition_param_list[0][0] return complete_param @@ -1356,9 +1356,9 @@ def get_loss_op(block): loss_ops = [] for op in block.ops: if is_loss_op(op): - assert ( - len(op.desc.output_arg_names()) == 1 - ), "loss op should only output loss var" + assert len(op.desc.output_arg_names()) == 1, ( + "loss op should only output loss var" + ) loss_ops.append(op) assert len(loss_ops) == 1, "num of loss op is not equal to one" @@ -1448,9 +1448,9 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op): dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if len(dims_mapping) > 1: for idx, mapping in enumerate(dims_mapping[1:]): - assert ( - mapping == -1 - ), f"{op_desc.type()} only the batch dimension (0-dim) can be sharded, but the dimension {idx} is sharded by {mapping} part." + assert mapping == -1, ( + f"{op_desc.type()} only the batch dimension (0-dim) can be sharded, but the dimension {idx} is sharded by {mapping} part." + ) if len(dims_mapping) >= 1: batch_dim_mappings.append(dims_mapping[0]) for arg_name in op_desc.output_arg_names(): @@ -1461,26 +1461,26 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op): if arg_name not in xshape_arg_names: if len(dims_mapping) > 1: for idx, mapping in enumerate(dims_mapping[1:]): - assert ( - mapping == -1 - ), f"{op_desc.type()} only the batch dimension (0-dim) can be sharded, but the dimension {idx} is sharded by {mapping} part." + assert mapping == -1, ( + f"{op_desc.type()} only the batch dimension (0-dim) can be sharded, but the dimension {idx} is sharded by {mapping} part." + ) if len(dims_mapping) >= 1: batch_dim_mappings.append(dims_mapping[0]) else: - assert ( - dims_mapping[0] == -1 - ), f"{op_desc.type()} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {mapping} part." 
+ assert dims_mapping[0] == -1, ( + f"{op_desc.type()} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {mapping} part." + ) if len(dims_mapping) > 2: for idx, mapping in enumerate(dims_mapping[2:]): - assert ( - mapping == -1 - ), f"{op_desc.type()} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {idx} is sharded by {mapping} part." + assert mapping == -1, ( + f"{op_desc.type()} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {idx} is sharded by {mapping} part." + ) batch_dim_mappings.append(dims_mapping[1]) compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) - assert ( - compatible_dim_mapping is not None - ), "There is no compatible dim mapping." + assert compatible_dim_mapping is not None, ( + "There is no compatible dim mapping." + ) for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) if serial_tensor.is_parameter: @@ -1543,9 +1543,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): dims_mapping_list.append(dims_mapping) compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) - assert ( - compatible_dims_mapping is not None - ), "There is no compatible dim mapping." + assert compatible_dims_mapping is not None, ( + "There is no compatible dim mapping." + ) for arg_name in input_arg_names: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: @@ -1681,9 +1681,9 @@ def _compute_runtime(op_cost, op, vars): lambda x, y: x * y, var.shape ) break - assert ( - total_static_input_size > 0 and total_actual_input_size > 0 - ), "Get input size failed." + assert total_static_input_size > 0 and total_actual_input_size > 0, ( + "Get input size failed." + ) actual_runtime = ( total_actual_input_size / total_static_input_size * runtime @@ -2196,21 +2196,21 @@ def insert_dependencies_for_two_ops( if is_sequential_run(): return - assert ( - len(prior_op.output_arg_names) >= 1 - ), f"first op of dependency should at least have one output. [{prior_op}]" - assert ( - len(posterior_op.input_arg_names) >= 1 - ), f"second op of dependency should at least have one input. [{posterior_op}]" + assert len(prior_op.output_arg_names) >= 1, ( + f"first op of dependency should at least have one output. [{prior_op}]" + ) + assert len(posterior_op.input_arg_names) >= 1, ( + f"second op of dependency should at least have one input. 
[{posterior_op}]" + ) prior_op_mesh = dist_context.get_op_dist_attr_for_program( prior_op ).process_mesh posterior_mesh = dist_context.get_op_dist_attr_for_program( posterior_op ).process_mesh - assert ( - prior_op_mesh == posterior_mesh - ), f"two ops of dependency should have same mesh but got [{prior_op_mesh}] and [{posterior_mesh}]" + assert prior_op_mesh == posterior_mesh, ( + f"two ops of dependency should have same mesh but got [{prior_op_mesh}] and [{posterior_mesh}]" + ) def _select_best_depend_var(vars): # parameter should not be dep var since it maybe partition in sharding pass @@ -2431,9 +2431,9 @@ def get_pp_stage_by_process_mesh(process_mesh, pp_degree): if pp_stage_for_process_mesh is not None: if pp_stage != pp_stage_for_process_mesh: return None - assert ( - pp_stage == pp_stage_for_process_mesh - ), f"Can't get pp_stage by process_mesh with different pp_stage {pp_stage} and {pp_stage_for_process_mesh}" + assert pp_stage == pp_stage_for_process_mesh, ( + f"Can't get pp_stage by process_mesh with different pp_stage {pp_stage} and {pp_stage_for_process_mesh}" + ) pp_stage_for_process_mesh = pp_stage return pp_stage_for_process_mesh @@ -2643,15 +2643,15 @@ def fuse_param_func( if is_qkv: # fuse_attention_qkv - assert ( - num_heads - ), f"num_heads should be number of heads for Q, but got {num_heads}" - assert ( - num_key_value_heads - ), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}" - assert ( - len(fuse_params) == 3 - ), f"fuse_params length is not equal 3, it should be Q K V list. but got length {len(fuse_params)}" + assert num_heads, ( + f"num_heads should be number of heads for Q, but got {num_heads}" + ) + assert num_key_value_heads, ( + f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}" + ) + assert len(fuse_params) == 3, ( + f"fuse_params length is not equal 3, it should be Q K V list. but got length {len(fuse_params)}" + ) num_query_groups = num_heads // num_key_value_heads q_list = split_fn(fuse_params[0], num_heads, axis=-1) k_list = split_fn(fuse_params[1], num_key_value_heads, axis=-1) @@ -2705,12 +2705,12 @@ def split_param_func( if is_qkv: # fuse_attention_qkv - assert ( - num_heads - ), f"num_heads should be number of heads for Q, but got {num_heads}" - assert ( - num_key_value_heads - ), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}" + assert num_heads, ( + f"num_heads should be number of heads for Q, but got {num_heads}" + ) + assert num_key_value_heads, ( + f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}" + ) num_query_groups = num_heads // num_key_value_heads q_list, k_list, v_list = [], [], [] split_heads = split_fn( diff --git a/python/paddle/distributed/auto_tuner/recorder.py b/python/paddle/distributed/auto_tuner/recorder.py index 3eb60257522971..c0232e68f66060 100644 --- a/python/paddle/distributed/auto_tuner/recorder.py +++ b/python/paddle/distributed/auto_tuner/recorder.py @@ -69,9 +69,9 @@ def get_best( if buffer is not None: if buffer < 0: raise ValueError("The buffer should be not less than 0.") - assert ( - max_mem_usage is not None - ), "max_mem_usage cannot be None when buffer is greater than 0." + assert max_mem_usage is not None, ( + "max_mem_usage cannot be None when buffer is greater than 0." 
+ ) if max_mem_usage <= 0: raise ValueError("max_mem_usage should be greater than 0.") diff --git a/python/paddle/distributed/auto_tuner/search.py b/python/paddle/distributed/auto_tuner/search.py index c4eeb7c493100f..03e6b03433fa76 100644 --- a/python/paddle/distributed/auto_tuner/search.py +++ b/python/paddle/distributed/auto_tuner/search.py @@ -103,9 +103,9 @@ def __init__(self, tuner_cfg): ) tuner_cfg["candidates"]["dp_degree"] = [1] self.all_tasks = search_by_dp_estimation(tuner_cfg) - assert ( - len(self.all_tasks) > 0 - ), "Unable to perform single dp estimation search." + assert len(self.all_tasks) > 0, ( + "Unable to perform single dp estimation search." + ) def search_once(self, history_cfgs): new_cfg = None @@ -146,9 +146,9 @@ def __init__(self, tuner_cfg): super().__init__(tuner_cfg) self.idx = 0 self.configs_csv = tuner_cfg.get("configs_csv", None) - assert os.path.exists( - self.configs_csv - ), "configs_csv file is necessary in CustomizeSearch mode." + assert os.path.exists(self.configs_csv), ( + "configs_csv file is necessary in CustomizeSearch mode." + ) self.all_tasks = load_configs_from_csv(self.configs_csv) def search_once(self, history_cfgs): diff --git a/python/paddle/distributed/auto_tuner/utils.py b/python/paddle/distributed/auto_tuner/utils.py index 50ea755e933d14..bc9cf2c8436504 100644 --- a/python/paddle/distributed/auto_tuner/utils.py +++ b/python/paddle/distributed/auto_tuner/utils.py @@ -1820,7 +1820,9 @@ def load_configs_from_csv(configs_csv): recompute_granularity == "" or recompute_granularity.lower() in __SUPPORTED_RECOMPUTE_GRANULARITY__ - ), f"{recompute_granularity} must be one of {__SUPPORTED_RECOMPUTE_GRANULARITY__}, but got {recompute_granularity}." + ), ( + f"{recompute_granularity} must be one of {__SUPPORTED_RECOMPUTE_GRANULARITY__}, but got {recompute_granularity}." + ) config["recompute_granularity"] = ( recompute_granularity if recompute_granularity != "" else None ) From d6a02d56da85484cffc82b70dbcbfc55f4b74882 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 19 Aug 2025 02:29:11 +0800 Subject: [PATCH 2/2] revert python/paddle/_paddle_docs.py changes --- python/paddle/_paddle_docs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/_paddle_docs.py b/python/paddle/_paddle_docs.py index 9c54b16d825b8f..a5b76559dce62d 100644 --- a/python/paddle/_paddle_docs.py +++ b/python/paddle/_paddle_docs.py @@ -74,6 +74,7 @@ def _parse_function_signature( if func_def.args.defaults and len(func_def.args.defaults) > ( len(func_def.args.args) - len(func_def.args.defaults) ): + idx = count - ( len(func_def.args.args) - len(func_def.args.defaults) )
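
For readers skimming the hunks above: every assert rewrite in this series follows the same mechanical pattern. The sketch below illustrates it on one assert copied from the compute_compatible_dims_mapping hunk in static/utils.py; the dims_mapping assignment is a made-up placeholder (not part of the patch) so the snippet runs on its own, and the two formatted variants are taken from that hunk's `-` and `+` lines respectively.

    # Placeholder value, only so both asserts below pass when run standalone.
    dims_mapping = [0, -1, -1]

    # Style on the removed (`-`) lines, as emitted by black: the condition is
    # wrapped in parentheses and the message sits on the closing line.
    assert (
        dims_mapping is not None
    ), "Dims mapping must not be None for compatible computation"

    # Style on the added (`+`) lines, as emitted by ruff format: the condition
    # stays on the assert line and the message is wrapped in parentheses instead.
    assert dims_mapping is not None, (
        "Dims mapping must not be None for compatible computation"
    )

Both variants evaluate the same condition and raise the same AssertionError message, so the series changes formatting only, not runtime behavior.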