Skip to content
6 changes: 4 additions & 2 deletions llm/finetune_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def read_local_dataset(path):
def main():
# Arguments
parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file(
# Support format as "args.json --args1 value1 --args2 value2.”
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines(
json_file=os.path.abspath(sys.argv[1])
Comment thread
greycooker marked this conversation as resolved.
Outdated
)
else:
Expand Down
8 changes: 6 additions & 2 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,12 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:

def main():
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
# Support format as "args.json --args1 value1 --args2 value2.”
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines(
json_file=os.path.abspath(sys.argv[1])
)
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

Expand Down
46 changes: 46 additions & 0 deletions paddlenlp/trainer/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,52 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
outputs.append(obj)
return (*outputs,)

def parse_json_file_and_cmd_lines(self, json_file: str) -> Tuple[DataClass, ...]:
"""
Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON
file.

When there is a conflict between the command line arguments and the JSON file configuration,
the command line arguments will take precedence.

This method combines data from a JSON file and command line arguments to populate instances of dataclasses.

Args:
json_file :
The path to the JSON formatted file should be at index position 1 in the command line
arguments array (sys.argv[1]).
Any JSON file path at other positions will be considered invalid.

Returns:
Tuple consisting of:

- the dataclass instances in the same order as they were passed to the initializer.abspath
"""
json_args = json.loads(Path(json_file).read_text())
del sys.argv[1]
output_dir_arg = next(
(arg for arg in sys.argv if arg == "--output_dir" or arg.startswith("--output_dir=")), None
)
if output_dir_arg is None:
if "output_dir" in json_args.keys():
sys.argv.extend(["--output_dir", json_args["output_dir"]])
else:
raise ValueError("The following arguments are required: --output_dir")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么这里要 特判 output_dir ?

Copy link
Copy Markdown
Contributor Author

@greycooker greycooker Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果不特判output_dir,出现json文件里有output_dir参数,但是命令行里没有的情况,执行281行vars(self.parse_args())的时候就会报错,但是我们现在不希望让它报错,所以进行了output_dir的特判

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/PaddlePaddle/PaddleNLP/blob/d89c01130a7f27c39d762cefb15926c4c69aa711/paddlenlp/trainer/argparser.py#L177C9-L177C36

你参考一下这个函数,这个函数也是一样的支持本地文件。
看看这个是怎么处理的。

这个作为通用的parser,在这里做output_dir之类的特判是不太合理的。

cmd_args = vars(self.parse_args())
merged_args = {}
for key in json_args.keys() | cmd_args.keys():
if any(arg == f"--{key}" or arg.startswith(f"--{key}=") for arg in sys.argv):
merged_args[key] = cmd_args.get(key)
elif json_args.get(key):
merged_args[key] = json_args.get(key)
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in merged_args.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)

def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
Expand Down
125 changes: 125 additions & 0 deletions tests/trainer/test_argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import tempfile
import unittest
from unittest.mock import patch

from llm.run_pretrain import PreTrainingArguments
from paddlenlp.trainer.argparser import PdArgumentParser


def parse_args():
parser = PdArgumentParser((PreTrainingArguments,))
# Support format as "args.json --args1 value1 --args2 value2.”
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
model_args = parser.parse_json_file_and_cmd_lines(json_file=os.path.abspath(sys.argv[1]))
else:
model_args = parser.parse_args_into_dataclasses()
return model_args


def create_json_from_dict(data_dict, file_path):
with open(file_path, "w") as f:
json.dump(data_dict, f)


class ArgparserTest(unittest.TestCase):
script_name = "test_argparser.py"
args_dict = {
"max_steps": 3000,
"amp_master_grad": False,
"adam_beta1": 0.9,
"adam_beta2": 0.999,
"adam_epsilon": 1e-08,
"bf16": False,
"enable_linear_fused_grad_add": False,
"eval_steps": 3216,
"flatten_param_grads": False,
"fp16": 1,
"log_on_each_node": True,
"logging_dir": "./checkpoints/llama2_pretrain_ckpts/runs/Dec27_04-28-35_instance-047hzlt0-4",
"logging_first_step": False,
"logging_steps": 1,
"lr_end": 1e-07,
"max_evaluate_steps": -1,
"max_grad_norm": 1.0,
"min_learning_rate": 3e-06,
"no_cuda": False,
"num_cycles": 0.5,
"num_train_epochs": 3.0,
"output_dir": "./checkpoints/llama2_pretrain_ckpts",
}

def test_parse_cmd_lines(self):
cmd_line_args = [ArgparserTest.script_name]
for key, value in ArgparserTest.args_dict.items():
cmd_line_args.extend([f"--{key}", str(value)])
with patch("sys.argv", cmd_line_args):
model_args = vars(parse_args()[0])
for key, value in ArgparserTest.args_dict.items():
self.assertEqual(model_args.get(key), value)

def test_parse_json_file(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
create_json_from_dict(ArgparserTest.args_dict, tmpfile.name)
tmpfile_path = tmpfile.name
with patch("sys.argv", [ArgparserTest.script_name, tmpfile_path]):
model_args = vars(parse_args()[0])
for key, value in ArgparserTest.args_dict.items():
self.assertEqual(model_args.get(key), value)
os.remove(tmpfile_path)

def test_parse_json_file_and_cmd_lines(self):
half_size = len(ArgparserTest.args_dict) // 2
json_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[:half_size]}
cmd_line_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[half_size:]}
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
create_json_from_dict(json_part, tmpfile.name)
tmpfile_path = tmpfile.name
cmd_line_args = [ArgparserTest.script_name, tmpfile_path]
for key, value in cmd_line_part.items():
cmd_line_args.extend([f"--{key}", str(value)])
with patch("sys.argv", cmd_line_args):
model_args = vars(parse_args()[0])
for key, value in ArgparserTest.args_dict.items():
self.assertEqual(model_args.get(key), value)
os.remove(tmpfile_path)

def test_parse_json_file_and_cmd_lines_with_conflict(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
json.dump(ArgparserTest.args_dict, tmpfile)
tmpfile_path = tmpfile.name
cmd_line_args = [
ArgparserTest.script_name,
tmpfile_path,
"--min_learning_rate",
"2e-5",
"--max_steps",
"3000",
"--log_on_each_node",
"False",
]
with patch("sys.argv", cmd_line_args):
model_args = vars(parse_args()[0])
self.assertEqual(model_args.get("min_learning_rate"), 2e-5)
self.assertEqual(model_args.get("max_steps"), 3000)
self.assertEqual(model_args.get("log_on_each_node"), False)
for key, value in ArgparserTest.args_dict.items():
if key not in ["min_learning_rate", "max_steps", "log_on_each_node"]:
self.assertEqual(model_args.get(key), value)
os.remove(tmpfile_path)