Skip to content

Commit e06a9e4

Browse files
gouzilLuckycheng222
authored andcommitted
[CodeStyle] black -> ruff format migration - part 31 (PaddlePaddle#74745)
1 parent 9a35802 commit e06a9e4

32 files changed

+281
-271
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ repos:
8787
8888
| python/paddle/[e-i].+
8989
90-
# | python/paddle/j.+
90+
| python/paddle/j.+
9191
9292
| python/paddle/[k-n].+
9393
@@ -143,7 +143,7 @@ repos:
143143
144144
# | python/paddle/[e-i].+
145145
146-
| python/paddle/j.+
146+
# | python/paddle/j.+
147147
148148
# | python/paddle/[k-n].+
149149

python/paddle/jit/dy2static/convert_operators.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -756,13 +756,17 @@ def convert_var_dtype(var, dtype):
756756
'int32',
757757
'int64',
758758
'uint8',
759-
], f"The dtype of var {var.name} is {src_dtype}, which is not supported in the cast op."
759+
], (
760+
f"The dtype of var {var.name} is {src_dtype}, which is not supported in the cast op."
761+
)
760762
assert dtype in [
761763
'bool',
762764
'int',
763765
'float',
764766
'complex',
765-
], f"The casted target dtype is {dtype}, which is not supported in type casting."
767+
], (
768+
f"The casted target dtype is {dtype}, which is not supported in type casting."
769+
)
766770
cast_map = {
767771
'bool': 'bool',
768772
'int': 'int32',
@@ -776,7 +780,9 @@ def convert_var_dtype(var, dtype):
776780
'int',
777781
'float',
778782
'complex',
779-
], f"The casted target dtype is {dtype}, which is not supported in type casting."
783+
], (
784+
f"The casted target dtype is {dtype}, which is not supported in type casting."
785+
)
780786
return eval(dtype)(var)
781787

782788

python/paddle/jit/dy2static/origin_info.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ def create_and_update_origin_info_map(
155155
static_node = attach_origin_info(static_node, static_func)
156156

157157
for t_node, s_node in ast_walk(transformed_node, static_node):
158-
assert type(t_node) == type(
159-
s_node
160-
), f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
158+
assert type(t_node) == type(s_node), (
159+
f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
160+
)
161161
dygraph_info = getattr(t_node, ORIGIN_INFO, None)
162162
static_info = getattr(s_node, ORIGIN_INFO, None)
163163

@@ -232,9 +232,9 @@ def _as_list(x):
232232
):
233233
continue
234234

235-
assert type(t_node) == type(
236-
s_node
237-
), f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
235+
assert type(t_node) == type(s_node), (
236+
f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
237+
)
238238

239239
yield t_node, s_node
240240

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,15 @@ def __init__(
218218
forward_range=None,
219219
backward_range=None,
220220
):
221-
assert isinstance(
222-
in_out_values, tuple
223-
), "in_out_values must be tuple with len == 3"
224-
assert (
225-
len(in_out_values) == 3
226-
), "in_out_values must be tuple with len == 3"
227-
assert isinstance(
228-
in_out_values[0], list
229-
), "in_out_values must be tuple with len == 3"
221+
assert isinstance(in_out_values, tuple), (
222+
"in_out_values must be tuple with len == 3"
223+
)
224+
assert len(in_out_values) == 3, (
225+
"in_out_values must be tuple with len == 3"
226+
)
227+
assert isinstance(in_out_values[0], list), (
228+
"in_out_values must be tuple with len == 3"
229+
)
230230
self.program = program
231231
self.x_names = self.convert_name(in_out_values[0])
232232
self.param_names = self.convert_name(in_out_values[1])
@@ -310,9 +310,9 @@ def clone(self):
310310
)
311311

312312
def split_forward_backward(self):
313-
assert (
314-
self.has_splited is False
315-
), "Please ensure only split once! don't call split_forward_backward manually."
313+
assert self.has_splited is False, (
314+
"Please ensure only split once! don't call split_forward_backward manually."
315+
)
316316
self.has_splited = True
317317
self.update_op_range()
318318
(
@@ -406,9 +406,9 @@ def _forward_backward_program(self):
406406

407407
@cached_property # shouldn't changed when call this once.
408408
def program_attr(self):
409-
assert (
410-
self.finish_pass is False
411-
), "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead."
409+
assert self.finish_pass is False, (
410+
"program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead."
411+
)
412412
# can't apply pass after call this function.
413413
self.finish_pass = True
414414
fwd_map = RunnableProgram._get_name_value_map_from_program(
@@ -445,9 +445,9 @@ def program_attr(self):
445445
program_attr[f"{k}_names"] = ns
446446

447447
# Restore stop_gradient for output values
448-
assert len(program_attr["fo_values"]) == len(
449-
self.out_stop_gradients
450-
), "Output values and stop gradients length mismatch"
448+
assert len(program_attr["fo_values"]) == len(self.out_stop_gradients), (
449+
"Output values and stop gradients length mismatch"
450+
)
451451
for v, stop_gradient in zip(
452452
program_attr["fo_values"], self.out_stop_gradients
453453
):
@@ -474,9 +474,9 @@ def unify_value_names(
474474
# Get all values again because some values has been erased.
475475
for value in RunnableProgram._get_program_all_values(program):
476476
if value.has_name:
477-
assert (
478-
value._has_only_one_name()
479-
), f"Expected all values in Program have only one name, but {value} has multiple names: {value._names}"
477+
assert value._has_only_one_name(), (
478+
f"Expected all values in Program have only one name, but {value} has multiple names: {value._names}"
479+
)
480480
return rename_mapping
481481

482482
@staticmethod

python/paddle/jit/dy2static/program_translator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,9 @@ def rollback(self) -> Callable[_InputT, _RetT]:
672672
if self._patched_name is not None
673673
else self._dygraph_function.__name__
674674
)
675-
assert (
676-
fn_name in self.class_instance._original_funcs
677-
), f"Not Found function '{fn_name}' in class '{self.class_instance.__class__}'."
675+
assert fn_name in self.class_instance._original_funcs, (
676+
f"Not Found function '{fn_name}' in class '{self.class_instance.__class__}'."
677+
)
678678
func = self.class_instance._original_funcs[fn_name]
679679
setattr(self.class_instance, fn_name, func.__get__(self.class_instance))
680680
return getattr(self.class_instance, fn_name)
@@ -1733,9 +1733,9 @@ def get_program(self, item):
17331733
return self._caches[item_id]
17341734

17351735
def last(self):
1736-
assert (
1737-
len(self._caches) >= 1
1738-
), "No valid cached program in ProgramCache."
1736+
assert len(self._caches) >= 1, (
1737+
"No valid cached program in ProgramCache."
1738+
)
17391739
assert self._recent_key is not None
17401740
return self._recent_key, self._caches[self._recent_key]
17411741

python/paddle/jit/dy2static/transformers/base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ class ForNodeVisitor:
184184
"""
185185

186186
def __init__(self, for_node):
187-
assert isinstance(
188-
for_node, gast.For
189-
), "Input node for the initialization of ForNodeVisitor is not gast.For node."
187+
assert isinstance(for_node, gast.For), (
188+
"Input node for the initialization of ForNodeVisitor is not gast.For node."
189+
)
190190
# 1. original for node
191191
self.node = for_node
192192

@@ -276,14 +276,14 @@ def is_for_enumerate_iter(self):
276276
def _args_check(self):
277277
if self.is_for_range_iter():
278278
self.args_length = len(self.iter_args)
279-
assert (
280-
self.args_length >= 1 and self.args_length <= 3
281-
), "range() function takes 1 to 3 arguments"
279+
assert self.args_length >= 1 and self.args_length <= 3, (
280+
"range() function takes 1 to 3 arguments"
281+
)
282282
elif self.is_for_enumerate_iter():
283283
self.args_length = len(self.iter_args)
284-
assert (
285-
self.args_length >= 1 and self.args_length <= 2
286-
), "enumerate() function takes 1 to 2 arguments"
284+
assert self.args_length >= 1 and self.args_length <= 2, (
285+
"enumerate() function takes 1 to 2 arguments"
286+
)
287287
else:
288288
self.args_length = None
289289

python/paddle/jit/dy2static/transformers/break_continue_transformer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ class ForToWhileTransformer(BaseTransformer):
3131
"""
3232

3333
def __init__(self, parent_node, loop_node, condition_node):
34-
assert isinstance(
35-
loop_node, gast.For
36-
), "loop_node is not gast.For in ForToWhileTransformer"
34+
assert isinstance(loop_node, gast.For), (
35+
"loop_node is not gast.For in ForToWhileTransformer"
36+
)
3737
self.parent_node = parent_node
3838
self.loop_node = loop_node
3939
self.condition_node = condition_node
@@ -60,9 +60,9 @@ def transform(self):
6060
)
6161

6262
def get_for_stmt_nodes(self, node):
63-
assert isinstance(
64-
node, gast.For
65-
), "Input node is NOT gast.For in get_for_stmt_nodes"
63+
assert isinstance(node, gast.For), (
64+
"Input node is NOT gast.For in get_for_stmt_nodes"
65+
)
6666

6767
# 1. parse current gast.For node
6868
current_for_node_parser = ForNodeVisitor(node)

python/paddle/jit/dy2static/transformers/early_return_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def transform(self):
3434
self.visit(self.root)
3535

3636
def is_define_return_in_if(self, node):
37-
assert isinstance(
38-
node, gast.If
39-
), f"Type of input node should be gast.If, but received {type(node)}."
37+
assert isinstance(node, gast.If), (
38+
f"Type of input node should be gast.If, but received {type(node)}."
39+
)
4040
for child in node.body:
4141
if isinstance(child, gast.Return):
4242
return True

python/paddle/jit/dy2static/transformers/logical_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ def _create_bool_op_node(self, nodes, api_type):
8383
according to the actual order. In `convert_logical_and(lambda:x>1, lambda:y<1)`, `lambda:y<1`
8484
must be run after `lambda:x>1`, If `x>1` is False, `y<1` should NOT be run.
8585
'''
86-
assert (
87-
len(nodes) > 1
88-
), f"The length of BoolOp should be at least 2, but received {len(nodes)}."
86+
assert len(nodes) > 1, (
87+
f"The length of BoolOp should be at least 2, but received {len(nodes)}."
88+
)
8989
if len(nodes) > 2:
9090
# Creates logic_and/logic_or node recursively.
9191
pre_logic_node = self._create_bool_op_node(nodes[:2], api_type)

python/paddle/jit/dy2static/transformers/loop_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ def __init__(self, root_node):
134134
self.visit(root_node)
135135

136136
def get_loop_var_names(self, node):
137-
assert isinstance(
138-
node, (gast.While, gast.For)
139-
), "Input node is not gast loop node"
137+
assert isinstance(node, (gast.While, gast.For)), (
138+
"Input node is not gast loop node"
139+
)
140140
loop_var_names = set()
141141
create_var_names = set()
142142
read_context = {type(gast.Load()), type(gast.AugLoad())}

0 commit comments

Comments
 (0)