
Commit 3e27f68

gouzil authored and Luckycheng222 committed
[CodeStyle] black -> ruff format migration - part 29 (PaddlePaddle#74743)
1 parent 127e1fd commit 3e27f68

39 files changed, +352 −350 lines changed
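Nearly all of the Python hunks in this commit apply one mechanical change: where black wrapped a long assert condition in parentheses and left the message after the closing parenthesis, ruff format keeps the condition on the `assert` line and wraps the message instead. A minimal runnable sketch of the two styles, modeled on the controller.py hunk below (placeholder lists stand in for `self.pod.containers`; in the real files the wrapping is triggered by line length):

    containers = ["worker"]
    init_containers = []

    # black style (before)
    assert (
        len(containers) + len(init_containers) > 0
    ), "No container in the pod"

    # ruff format style (after)
    assert len(containers) + len(init_containers) > 0, (
        "No container in the pod"
    )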

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -83,7 +83,7 @@ repos:
             # | python/paddle/distributed/f.+
-            # | python/paddle/distributed/[g-z].+
+            | python/paddle/distributed/[g-z].+
             # | python/paddle/[e-i].+
@@ -139,7 +139,7 @@ repos:
             | python/paddle/distributed/f.+
-            | python/paddle/distributed/[g-z].+
+            # | python/paddle/distributed/[g-z].+
             | python/paddle/[e-i].+
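Consistent with the commit title, the two hunks above move the python/paddle/distributed/[g-z].+ pattern out of black's file list and into ruff format's, so those files are now formatted by ruff. The result of such a migration can be reproduced locally by running ruff's formatter over an affected file; a minimal sketch, assuming ruff is installed and using one path from this diff as the example target:

    import subprocess

    # Reformat a single migrated file in place with ruff's formatter.
    subprocess.run(
        ["ruff", "format", "python/paddle/distributed/launch/controllers/controller.py"],
        check=True,
    )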

python/paddle/distributed/launch/controllers/controller.py

Lines changed: 6 additions & 6 deletions
@@ -62,9 +62,9 @@ def __init__(self, ctx):
         self.join_server = None
 
     def deploy_pod(self):
-        assert (
-            len(self.pod.containers) + len(self.pod.init_containers) > 0
-        ), "No container in the pod"
+        assert len(self.pod.containers) + len(self.pod.init_containers) > 0, (
+            "No container in the pod"
+        )
 
         self.ctx.logger.info(f"Run {self.pod}")
         if len(self.pod.init_containers) > 0:
@@ -309,9 +309,9 @@ def save_pod_log(self, info):
             self.ctx.logger.error(f"save log failed because {e}")
 
     def save_pod_env(self):
-        assert (
-            len(self.pod.containers) + len(self.pod.init_containers) > 0
-        ), "No container in the pod"
+        assert len(self.pod.containers) + len(self.pod.init_containers) > 0, (
+            "No container in the pod"
+        )
 
         if not self.ctx.args.log_dir:
             return

python/paddle/distributed/launch/controllers/ipu_controller.py

Lines changed: 6 additions & 6 deletions
@@ -69,19 +69,19 @@ def replace_training_script(self):
 
         num_ipus = int(self.ctx.args.devices)
         # The number of replicas for data parallel
-        assert (
-            num_ipus % poprun_args.ipus_per_replica
-        ) == 0, f"The number of IPUs:{num_ipus} mod the number of IPUs per replica:{poprun_args.ipus_per_replica} must == 0"
+        assert (num_ipus % poprun_args.ipus_per_replica) == 0, (
+            f"The number of IPUs:{num_ipus} mod the number of IPUs per replica:{poprun_args.ipus_per_replica} must == 0"
+        )
         num_replicas = num_ipus // poprun_args.ipus_per_replica
         self.ctx.logger.info(f"The number of total replicas is {num_replicas}.")
 
         # The number of processes
         num_nodes = len(poprun_args.hosts.split(','))
         num_procs = num_nodes * poprun_args.nproc_per_host
         self.ctx.logger.info(f"The number of total processes is {num_procs}.")
-        assert (
-            num_replicas % num_procs
-        ) == 0, f"The number of replicas:{num_replicas} mod the number of processes:{num_procs} must == 0"
+        assert (num_replicas % num_procs) == 0, (
+            f"The number of replicas:{num_replicas} mod the number of processes:{num_procs} must == 0"
+        )
 
         # hosts and endpoints
         hosts = poprun_args.hosts.replace(' ', '').split(',')

python/paddle/distributed/launch/controllers/rpc.py

Lines changed: 3 additions & 3 deletions
@@ -27,9 +27,9 @@ def enable(cls, ctx):
         return False
 
     def build_pod(self):
-        assert (
-            self.ctx.args.master is not None
-        ), "Master is None, Please set master address!"
+        assert self.ctx.args.master is not None, (
+            "Master is None, Please set master address!"
+        )
         self._build_pod_with_master()
 
     def _build_pod_with_master(self):

python/paddle/distributed/launch/job/container.py

Lines changed: 3 additions & 3 deletions
@@ -94,9 +94,9 @@ def update_env(self, env={}, **kwargs):
 
     def _validate_env(self):
         for k, v in self._env.items():
-            assert isinstance(k, str) and isinstance(
-                v, str
-            ), f'env {k}:{v} must be str'
+            assert isinstance(k, str) and isinstance(v, str), (
+                f'env {k}:{v} must be str'
+            )
 
     def _get_fd(self, pth):
         if not pth:

python/paddle/distributed/parallel.py

Lines changed: 9 additions & 9 deletions
@@ -391,9 +391,9 @@ def __init__(
     ) -> None:
         super().__init__(layers.full_name() + "_data_parallel")
 
-        assert (
-            in_dynamic_mode()
-        ), "It's not supported to construct DataParallel in static graph mode."
+        assert in_dynamic_mode(), (
+            "It's not supported to construct DataParallel in static graph mode."
+        )
 
         self._layers = layers
         self.find_unused_parameters = find_unused_parameters
@@ -756,12 +756,12 @@ def __init__(self):
         ).split(",")
         self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
         self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
-        assert (
-            self._nrings > 0
-        ), "nccl_nrings must be an integer greater than 0."
-        assert (
-            self._nrings < 9
-        ), "nccl_nrings should be less than 9, which is enough in most scenarios."
+        assert self._nrings > 0, (
+            "nccl_nrings must be an integer greater than 0."
+        )
+        assert self._nrings < 9, (
+            "nccl_nrings should be less than 9, which is enough in most scenarios."
+        )
 
     @property
     def rank(self) -> int:

python/paddle/distributed/parallel_helper.py

Lines changed: 6 additions & 6 deletions
@@ -33,17 +33,17 @@ def _is_parallel_ctx_initialized():
 
 def _set_parallel_ctx(ccl_parallel_context):
     global __parallel_ctx__clz__
-    assert (
-        __parallel_ctx__clz__ is None
-    ), "ParallelContext can only be initialized once."
+    assert __parallel_ctx__clz__ is None, (
+        "ParallelContext can only be initialized once."
+    )
     __parallel_ctx__clz__ = ccl_parallel_context
 
 
 def _init_parallel_ctx():
     global __parallel_ctx__clz__
-    assert (
-        __parallel_ctx__clz__ is not None
-    ), "ParallelContext should be initialized."
+    assert __parallel_ctx__clz__ is not None, (
+        "ParallelContext should be initialized."
+    )
     __parallel_ctx__clz__.init()
python/paddle/distributed/parallel_with_gloo.py

Lines changed: 3 additions & 3 deletions
@@ -96,9 +96,9 @@ def gloo_init_parallel_env(
             ... test_gloo_init_with_multiprocess(2)
     """
 
-    assert (
-        rank_num < 2
-    ) is False, "rank_num should greater than or equal to 2 for parallel environment initialization."
+    assert (rank_num < 2) is False, (
+        "rank_num should greater than or equal to 2 for parallel environment initialization."
+    )
 
     # init gloo context
     manager = Manager()

python/paddle/distributed/passes/auto_parallel_amp.py

Lines changed: 12 additions & 12 deletions
@@ -340,9 +340,9 @@ def _cast_block(self, block):
                     out_var = block.var(out_var_name)
                     in_var = block._find_var_recursive(in_var_name)
                     for in_var_name in op.input_arg_names:
-                        assert (
-                            in_var.dtype == block.var(in_var_name).dtype
-                        ), f"{in_var}, {block.var(in_var_name)}, {op}"
+                        assert in_var.dtype == block.var(in_var_name).dtype, (
+                            f"{in_var}, {block.var(in_var_name)}, {op}"
+                        )
                     out_var.desc.set_dtype(in_var.dtype)
                 elif int(op.attr('op_role')) == 257:
                     pass
@@ -545,9 +545,9 @@ def _keep_fp32_output(op, out_name):
                         cast_name, in_var_dist_attr
                     )
                 else:
-                    assert (
-                        in_var.dtype == dst_dtype
-                    ), f"op [{op.type}] expect input [{in_name}] to be dtype [{dst_dtype}] BUT got [{in_var.dtype}]. {op}"
+                    assert in_var.dtype == dst_dtype, (
+                        f"op [{op.type}] expect input [{in_name}] to be dtype [{dst_dtype}] BUT got [{in_var.dtype}]. {op}"
+                    )
 
         for out_name in op.output_names:
             if src_dtype == paddle.float32 and _keep_fp32_output(op, out_name):
@@ -1158,13 +1158,13 @@ def _update_loss_scaling(self, grads, found_inf):
                 e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
             )
             if e.dtype == paddle.float16:
-                assert (
-                    self._loss_scaling.dtype == paddle.float32
-                ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
+                assert self._loss_scaling.dtype == paddle.float32, (
+                    "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
+                )
             else:
-                assert (
-                    self._loss_scaling.dtype == e.dtype
-                ), "The dtype of prev_loss_scaling should be equal to the dtype of x."
+                assert self._loss_scaling.dtype == e.dtype, (
+                    "The dtype of prev_loss_scaling should be equal to the dtype of x."
+                )
 
         inputs = {
             'X': grads,

python/paddle/distributed/passes/auto_parallel_c_embedding.py

Lines changed: 3 additions & 3 deletions
@@ -173,9 +173,9 @@ def _update_before_dims_mapping(self, new_op):
             results.append(dist_attr_new)
             sub_name = op.name().split('.')[1]
             if op.num_operands() > 0:
-                assert (
-                    sub_name != "cast"
-                ), "Need to add support for {sub_name}."
+                assert sub_name != "cast", (
+                    "Need to add support for {sub_name}."
+                )
                 operands.append(dist_attr_new)
                 next_op = op.operand(0).source().get_defining_op()
                 stack.append(next_op)
