4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -97,7 +97,7 @@ repos:

 # | python/_.+

-# | test/a.+
+| test/a.+

 # | test/[b-h].+

@@ -153,7 +153,7 @@ repos:

 | python/_.+

-| test/a.+
+# | test/a.+

 | test/[b-h].+

4 changes: 1 addition & 3 deletions test/auto_parallel/PP_Schedules_demo.py
@@ -508,9 +508,7 @@ def test_FthenB_align_mode_of_GradientClipByGlobalNorm(self):
             parameters=self.model.parameters(),
             grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
         )
-        if (
-            dist.in_auto_parallel_align_mode()
-        ):  # When in auto parallel align mode, patching the optimizer step function
+        if dist.in_auto_parallel_align_mode():  # When in auto parallel align mode, patching the optimizer step function
             orig_step = (
                 opt.step.__func__ if hasattr(opt.step, "__func__") else opt.step
             )
Changes to another file (file name not captured)
@@ -40,9 +40,9 @@ def __init__(self):
         self._seed = eval(os.getenv("seed"))

     def check_placements(self, output, expected_placements):
-        assert (
-            output.placements == expected_placements
-        ), f"{output.placements} vs {expected_placements}"
+        assert output.placements == expected_placements, (
+            f"{output.placements} vs {expected_placements}"
+        )

     def test_custom_relu(self):
         shapes = [16, 4, 4]
12 changes: 6 additions & 6 deletions test/auto_parallel/dtensor_from_local_api.py
@@ -63,12 +63,12 @@ def _check_mesh(grad):
             if mesh is None and placements is None:
                 assert not grad.is_dist(), "grad.is_dist() is not False"
             else:
-                assert (
-                    grad.process_mesh == mesh
-                ), "grad.process_mesh is not equal to mesh"
-                assert (
-                    grad.placements == placements
-                ), "grad.placements is not equal to placements"
+                assert grad.process_mesh == mesh, (
+                    "grad.process_mesh is not equal to mesh"
+                )
+                assert grad.placements == placements, (
+                    "grad.placements is not equal to placements"
+                )

         return _check_mesh

10 changes: 6 additions & 4 deletions test/auto_parallel/hybrid_strategy/parallel_api.py
@@ -178,7 +178,9 @@ def __init__(self):
         ) or (
             self.config.context_parallel is False
             and self.config.sep_parallel is True
-        ), "when sep > 1, either context_parallel or sep_parallel should be true"
+        ), (
+            "when sep > 1, either context_parallel or sep_parallel should be true"
+        )
         num_hidden_layers = os.getenv("num_hidden_layers")
         if num_hidden_layers:
             self.config.num_hidden_layers = int(num_hidden_layers)
@@ -299,9 +301,9 @@ def check_lora(self, layer):
             ) and not self.share_embedding:
                 assert sub_layer.weight.stop_gradient
             if 'o_proj' in name:
-                assert (
-                    sub_layer.weight.stop_gradient
-                ), f'{name} , {sub_layer.weight.name} , {sub_layer.weight}'
+                assert sub_layer.weight.stop_gradient, (
+                    f'{name} , {sub_layer.weight.name} , {sub_layer.weight}'
+                )
                 assert not sub_layer.lora_A.stop_gradient
                 assert not sub_layer.lora_B.stop_gradient
                 # assert sub_layer.bias.stop_gradient is None
4 changes: 3 additions & 1 deletion test/auto_parallel/hybrid_strategy/semi_auto_llama.py
@@ -137,7 +137,9 @@ def __init__(self):
         assert (
             self.config.sep_parallel_degree
             != self.config.context_parallel_degree
-        ), f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        ), (
+            f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        )

         self.init_dist_env()

Changes to another file (file name not captured)
@@ -159,7 +159,9 @@ def __init__(self):
         assert (
             self.config.sep_parallel_degree
             != self.config.context_parallel_degree
-        ), f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        ), (
+            f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        )

         self.run_step = 10
         self.run_step_dy2static = (
Changes to another file (file name not captured)
@@ -152,7 +152,9 @@ def __init__(self):
         assert (
             self.config.sep_parallel_degree
             != self.config.context_parallel_degree
-        ), f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        ), (
+            f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        )

         self.init_dist_env()

Changes to another file (file name not captured)
@@ -133,7 +133,9 @@ def __init__(self):
         assert (
             self.config.sep_parallel_degree
             != self.config.context_parallel_degree
-        ), f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        ), (
+            f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        )

         self.init_dist_env()

44 changes: 24 additions & 20 deletions test/auto_parallel/hybrid_strategy/semi_auto_llama_save_load.py
@@ -111,7 +111,9 @@ def __init__(self):
         assert (
             self.config.sep_parallel_degree
             != self.config.context_parallel_degree
-        ), f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        ), (
+            f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        )

         self.init_dist_env()

@@ -136,41 +138,43 @@ def init_dist_env(self):
         random.seed(1024)

     def check_program_equal(self, program_a, program_b):
-        assert (
-            program_a.num_ops() == program_b.num_ops()
-        ), f'The number of ops between two programs is different: {program_a.num_ops()} vs {program_b.num_ops()}.'
+        assert program_a.num_ops() == program_b.num_ops(), (
+            f'The number of ops between two programs is different: {program_a.num_ops()} vs {program_b.num_ops()}.'
+        )
         for i in range(program_a.num_ops()):
             a_op = program_a.global_block().ops[i]
             b_op = program_a.global_block().ops[i]
             # check op name
-            assert (
-                a_op.name() == b_op.name()
-            ), f'The name of {i} op in program is different: {a_op.name()} vs {b_op.name()}.'
+            assert a_op.name() == b_op.name(), (
+                f'The name of {i} op in program is different: {a_op.name()} vs {b_op.name()}.'
+            )
             # check op inputs
             for index in range(a_op.num_operands()):
                 assert (
                     a_op.operand(index)
                     .source()
                     .is_same(b_op.operand(index).source())
-                ), f'The type of {index} operand is different: {a_op.operand(index).source()} vs {b_op.operand(index).source()}'
+                ), (
+                    f'The type of {index} operand is different: {a_op.operand(index).source()} vs {b_op.operand(index).source()}'
+                )
             # check op outputs
             for index in range(a_op.num_results()):
-                assert a_op.result(index).is_same(
-                    b_op.result(index)
-                ), f'The type of {index} result is different: {a_op.result(index)} vs {b_op.result(index)}'
+                assert a_op.result(index).is_same(b_op.result(index)), (
+                    f'The type of {index} result is different: {a_op.result(index)} vs {b_op.result(index)}'
+                )
             # check op attrs
             for k, v in a_op.attrs().items():
-                assert (
-                    k in b_op.attrs()
-                ), f'Can not find key of {k} attribute in other program'
+                assert k in b_op.attrs(), (
+                    f'Can not find key of {k} attribute in other program'
+                )
                 if k == 'place':
-                    assert type(v) == type(
-                        b_op.attrs()[k]
-                    ), f'The attribute of {k} is different: {type(v)} vs {type(b_op.attrs()[k])}'
+                    assert type(v) == type(b_op.attrs()[k]), (
+                        f'The attribute of {k} is different: {type(v)} vs {type(b_op.attrs()[k])}'
+                    )
                 else:
-                    assert (
-                        v == b_op.attrs()[k]
-                    ), f'The attribute of {k} is different: {v} vs {b_op.attrs()[k]}'
+                    assert v == b_op.attrs()[k], (
+                        f'The attribute of {k} is different: {v} vs {b_op.attrs()[k]}'
+                    )

     def run_dy2static(self, tmp_ckpt_path):
         model = LlamaForCausalLMAuto(self.config)
Changes to another file (file name not captured)
@@ -229,9 +229,9 @@ def check_tensor_eq(self, a, b, rtol=1e-05, atol=0, verbose=True):
         )

     def check_placements(self, output, expected_placements):
-        assert (
-            output.placements == expected_placements
-        ), f"{output.placements} vs {expected_placements}"
+        assert output.placements == expected_placements, (
+            f"{output.placements} vs {expected_placements}"
+        )

     def get_shard_check_hook(self, dims_mapping, check_input=False):
         def check_func(layer, input, output=None):
Changes to another file (file name not captured)
@@ -990,13 +990,13 @@ def split_sequence_dim(inputs):

             if sep_degree > 1:
                 assert inputs.is_dist(), "Input tensor must be a distributed tensor."
-                assert (
-                    len(inputs.shape) == 2
-                ), f"input_ids should be [batch_size, seq_len], but got {inputs.shape}"
+                assert len(inputs.shape) == 2, (
+                    f"input_ids should be [batch_size, seq_len], but got {inputs.shape}"
+                )
                 _, seq_len = inputs.shape
-                assert (
-                    seq_len % sep_degree == 0
-                ), f"sequence length {seq_len} must be divisible by cp degree {sep_degree}"
+                assert seq_len % sep_degree == 0, (
+                    f"sequence length {seq_len} must be divisible by cp degree {sep_degree}"
+                )
                 # split sequence dim
                 placements[sep_index] = dist.Shard(1)
                 split_input = dist.reshard(inputs, process_mesh, placements)
Changes to another file (file name not captured)
@@ -130,25 +130,27 @@ def test_dygraph_save_static_load(self):
         state_dict_to_load = dist_model.state_dict(mode="param")
         assert len(state_dict_to_load) == len(expected_state_dict)
         for k, v in state_dict_to_load.items():
-            assert (
-                k in expected_state_dict
-            ), f"key {k} not in expected_state_dict:{expected_state_dict}"
+            assert k in expected_state_dict, (
+                f"key {k} not in expected_state_dict:{expected_state_dict}"
+            )
             assert np.any(
                 np.not_equal(
                     v._local_value().numpy(),
                     expected_state_dict[k].numpy(),
                 )
-            ), f"key:{k}, v:{v}, expected_state_dict[k]:{expected_state_dict[k]}"
+            ), (
+                f"key:{k}, v:{v}, expected_state_dict[k]:{expected_state_dict[k]}"
+            )

         dist.load_state_dict(state_dict_to_load, ckpt_path)
         dist_model.set_state_dict(state_dict_to_load)

         program_state_dict = dist_model.state_dict(mode="param")
         assert len(expected_state_dict) == len(program_state_dict)
         for k, v in program_state_dict.items():
-            assert (
-                k in expected_state_dict
-            ), f"key {k} not in expected_state_dict:{expected_state_dict}"
+            assert k in expected_state_dict, (
+                f"key {k} not in expected_state_dict:{expected_state_dict}"
+            )
             np.testing.assert_equal(
                 v._local_value().numpy(),
                 expected_state_dict[k].numpy(),
@@ -189,25 +191,27 @@ def test_static_save_dynamic_load(self):
         state_dict_to_load = dy_layer.state_dict()
         assert len(state_dict_to_load) == len(expected_state_dict)
         for k, v in state_dict_to_load.items():
-            assert (
-                k in expected_state_dict
-            ), f"key {k} not in expected_state_dict:{expected_state_dict}"
+            assert k in expected_state_dict, (
+                f"key {k} not in expected_state_dict:{expected_state_dict}"
+            )
             assert np.any(
                 np.not_equal(
                     v._local_value().numpy(),
                     expected_state_dict[k].numpy(),
                 )
-            ), f"key:{k}, v:{v}, expected_state_dict[k]:{expected_state_dict[k]}"
+            ), (
+                f"key:{k}, v:{v}, expected_state_dict[k]:{expected_state_dict[k]}"
+            )

         dist.load_state_dict(state_dict_to_load, ckpt_path)
         dy_layer.set_state_dict(state_dict_to_load)

         state_dict = dy_layer.state_dict()
         assert len(expected_state_dict) == len(state_dict)
         for k, v in state_dict.items():
-            assert (
-                k in expected_state_dict
-            ), f"key {k} not in expected_state_dict:{expected_state_dict}"
+            assert k in expected_state_dict, (
+                f"key {k} not in expected_state_dict:{expected_state_dict}"
+            )
             np.testing.assert_equal(
                 v._local_value().numpy(),
                 expected_state_dict[k].numpy(),
Changes to another file (file name not captured)
@@ -75,9 +75,9 @@ def test_dp_mp_demo_net(self):
         for k, v in state_dict.items():
             assert v.numpy().sum() == 0.0, f"state_dict {k} is not zero"
             assert k in need_load_state_dict, f"state_dict {k} is not found"
-            assert (
-                need_load_state_dict[k].numpy().sum() == 0.0
-            ), f"state_dict {k} is not zero"
+            assert need_load_state_dict[k].numpy().sum() == 0.0, (
+                f"state_dict {k} is not zero"
+            )

         paddle.distributed.load_state_dict(
             need_load_state_dict, self._ckpt_path
30 changes: 15 additions & 15 deletions test/auto_parallel/hybrid_strategy/semi_auto_save_state_dict.py
@@ -34,27 +34,27 @@ def check_structure_name_mapping(ckpt_path, state_dict):
     data_file_path = os.path.join(
         ckpt_path, f"{paddle.distributed.get_rank()}_0.distcp"
     )
-    assert os.path.exists(
-        metadata_file_path
-    ), f"metadata file {metadata_file_path} is not found"
-    assert os.path.exists(
-        data_file_path
-    ), f"data file {data_file_path} is not found"
+    assert os.path.exists(metadata_file_path), (
+        f"metadata file {metadata_file_path} is not found"
+    )
+    assert os.path.exists(data_file_path), (
+        f"data file {data_file_path} is not found"
+    )
     metadata = paddle.load(metadata_file_path)
     cur_rank_state_dict = paddle.load(data_file_path, keep_name_table=True)
     local_structure_name_mapping = cur_rank_state_dict.pop(
         "StructuredToParameterName@@"
     )
-    assert isinstance(
-        local_structure_name_mapping, dict
-    ), f"local_structure_name_mapping:{local_structure_name_mapping} is not dict type"
+    assert isinstance(local_structure_name_mapping, dict), (
+        f"local_structure_name_mapping:{local_structure_name_mapping} is not dict type"
+    )
     for structure_name, param_name in local_structure_name_mapping.items():
-        assert (
-            structure_name in state_dict
-        ), f"tensor key:{structure_name} is not found in state dict:{state_dict}"
-        assert (
-            param_name == state_dict[structure_name].name
-        ), f"param name:{param_name} is not equal to param name in state_dict:{state_dict[structure_name].name}"
+        assert structure_name in state_dict, (
+            f"tensor key:{structure_name} is not found in state dict:{state_dict}"
+        )
+        assert param_name == state_dict[structure_name].name, (
+            f"param name:{param_name} is not equal to param name in state_dict:{state_dict[structure_name].name}"
+        )


 class TestSaveStateDict:
Changes to another file (file name not captured)
@@ -200,8 +200,7 @@ def test_single_schedule(self, sing_schedule="FThenB"):
         cur_rank = dist.get_rank()
         stage_layers = SingleStage(
             self.model.linears[
-                cur_rank
-                * num_layers_per_card : (cur_rank + 1)
+                cur_rank * num_layers_per_card : (cur_rank + 1)
                 * num_layers_per_card
             ]
         )
Changes to another file (file name not captured)
@@ -109,7 +109,9 @@ def __init__(self):
         assert (
             self.config.sep_parallel_degree
             != self.config.context_parallel_degree
-        ), f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        ), (
+            f"only one of the context_parallel and sep_parallel can be True, but get context_parallel_degree = {self.config.context_parallel_degree} and sep_parallel_degree = {self.config.sep_parallel_degree}, please check your env"
+        )

         self.strategy = dist.Strategy()
