Skip to content
14 changes: 9 additions & 5 deletions llm/finetune_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,16 @@ def read_local_dataset(path):
def main():
# Arguments
parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
json_indices = [index for index, string in enumerate(sys.argv) if string.endswith(".json")]
if len(json_indices) >= 2:
raise ValueError("Only support one file in json format at most, please check the command line parameters.")
elif len(json_indices) == 0:
gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses()
else:
json_file_idx = json_indices[0]
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines(
json_file_idx
)
training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
training_args.print_config(quant_args, "Quant")
Expand Down
10 changes: 7 additions & 3 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,14 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:

def main():
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
json_indices = [index for index, string in enumerate(sys.argv) if string.endswith(".json")]
Comment thread
JunnYu marked this conversation as resolved.
Outdated
if len(json_indices) >= 2:
raise ValueError("Only support one file in json format at most, please check the command line parameters.")
elif len(json_indices) == 0:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
else:
json_file_idx = json_indices[0]
model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines(json_file_idx)

if training_args.enable_linear_fused_grad_add:
from fused_layers import mock_layers
Expand Down
45 changes: 45 additions & 0 deletions paddlenlp/trainer/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import dataclasses
import json
import os
import sys
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
Expand Down Expand Up @@ -247,6 +248,50 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
outputs.append(obj)
return (*outputs,)

def parse_json_file_and_cmd_lines(self, json_file_idx: int) -> Tuple[DataClass, ...]:
    """
    Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON
    file.

    This method combines data from a JSON file and command line arguments to populate instances of dataclasses.
    The JSON file is identified using its index in the command line arguments array (`sys.argv`), is removed
    from `sys.argv`, and its entries are merged with the remaining command line flags. On conflict, the command
    line value wins.

    Args:
        json_file_idx :
            The index of the JSON file argument within the command line arguments array.
            This index is used to locate and extract the JSON file path from the command line arguments.

    Returns:
        Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.

    Raises:
        ValueError: if `--output_dir` is supplied neither on the command line nor in the JSON file.
    """
    json_file = os.path.abspath(sys.argv[json_file_idx])
    json_args = json.loads(Path(json_file).read_text())
    # Remove the JSON path so argparse only sees real flags below.
    del sys.argv[json_file_idx]
    # NOTE(review): `output_dir` is special-cased because `parse_args()` treats it as required;
    # without re-injecting it from the JSON file, a JSON-only `output_dir` would make argparse fail.
    output_dir_arg = next(
        (arg for arg in sys.argv if arg == "--output_dir" or arg.startswith("--output_dir=")), None
    )
    if output_dir_arg is None:
        if "output_dir" in json_args.keys():
            sys.argv.extend(["--output_dir", json_args["output_dir"]])
        else:
            raise ValueError("The following arguments are required: --output_dir")
    cmd_args = vars(self.parse_args())
    merged_args = {}
    for key in json_args.keys() | cmd_args.keys():
        if any(arg == f"--{key}" or arg.startswith(f"--{key}=") for arg in sys.argv):
            # The flag was given explicitly on the command line: it overrides the JSON value.
            merged_args[key] = cmd_args.get(key)
        elif key in json_args:
            # Membership test (not truthiness) so falsy JSON values like false/0/"" are preserved
            # instead of silently falling back to the dataclass defaults.
            merged_args[key] = json_args.get(key)
    outputs = []
    for dtype in self.dataclass_types:
        # Only forward keys that are actual init fields of each dataclass.
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in merged_args.items() if k in keys}
        obj = dtype(**inputs)
        outputs.append(obj)
    return (*outputs,)

def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
Expand Down
132 changes: 132 additions & 0 deletions tests/trainer/test_argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import tempfile
import unittest
from unittest.mock import patch

from llm.run_pretrain import PreTrainingArguments
from paddlenlp.trainer.argparser import PdArgumentParser


def parse_args():
    """Replicate the CLI/JSON dispatch used by the training entry points.

    At most one ``.json`` argument is accepted; with none, arguments come purely
    from the command line, otherwise the JSON file and command line are merged.
    """
    parser = PdArgumentParser((PreTrainingArguments,))
    json_positions = [pos for pos, arg in enumerate(sys.argv) if arg.endswith(".json")]
    if len(json_positions) > 1:
        raise ValueError("Only support one file in json format at most, please check the command line parameters.")
    if not json_positions:
        return parser.parse_args_into_dataclasses()
    return parser.parse_json_file_and_cmd_lines(json_positions[0])


def create_json_from_dict(data_dict, file_path):
    """Serialize *data_dict* as JSON and write it to *file_path*."""
    with open(file_path, "w") as handle:
        handle.write(json.dumps(data_dict))


class ArgparserTest(unittest.TestCase):
    """Tests for argument parsing from the command line, a JSON file, and both combined.

    Fix over the previous version: temporary JSON files were removed with a trailing
    ``os.remove`` that only ran when every assertion passed, leaking files on failure.
    Cleanup is now registered with ``addCleanup`` so it always runs.
    """

    script_name = "test_argparser.py"
    # Reference arguments; every value must round-trip through both the JSON and CLI paths.
    args_dict = {
        "max_steps": 3000,
        "amp_master_grad": False,
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "adam_epsilon": 1e-08,
        "bf16": False,
        "enable_linear_fused_grad_add": False,
        "eval_steps": 3216,
        "flatten_param_grads": False,
        "fp16": 1,
        "log_on_each_node": True,
        "logging_dir": "./checkpoints/llama2_pretrain_ckpts/runs/Dec27_04-28-35_instance-047hzlt0-4",
        "logging_first_step": False,
        "logging_steps": 1,
        "lr_end": 1e-07,
        "max_evaluate_steps": -1,
        "max_grad_norm": 1.0,
        "min_learning_rate": 3e-06,
        "no_cuda": False,
        "num_cycles": 0.5,
        "num_train_epochs": 3.0,
        "output_dir": "./checkpoints/llama2_pretrain_ckpts",
    }

    def _write_json(self, data):
        """Write *data* to a temp JSON file and schedule its removal even if the test fails."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
            create_json_from_dict(data, tmpfile.name)
            path = tmpfile.name
        self.addCleanup(os.remove, path)
        return path

    def test_parse_args_with_multiple_json_files(self):
        # More than one JSON file on the command line must be rejected.
        with self.assertRaises(ValueError):
            with patch("sys.argv", [ArgparserTest.script_name, "config1.json", "config2.json"]):
                parse_args()

    def test_parse_cmd_lines(self):
        # All arguments supplied on the command line only.
        cmd_line_args = [ArgparserTest.script_name]
        for key, value in ArgparserTest.args_dict.items():
            cmd_line_args.extend([f"--{key}", str(value)])
        with patch("sys.argv", cmd_line_args):
            model_args = vars(parse_args()[0])
        for key, value in ArgparserTest.args_dict.items():
            self.assertEqual(model_args.get(key), value)

    def test_parse_json_file(self):
        # All arguments supplied through a JSON file only.
        tmpfile_path = self._write_json(ArgparserTest.args_dict)
        with patch("sys.argv", [ArgparserTest.script_name, tmpfile_path]):
            model_args = vars(parse_args()[0])
        for key, value in ArgparserTest.args_dict.items():
            self.assertEqual(model_args.get(key), value)

    def test_parse_json_file_and_cmd_lines(self):
        # Half the arguments come from JSON, the other half from the command line.
        half_size = len(ArgparserTest.args_dict) // 2
        keys = list(ArgparserTest.args_dict)
        json_part = {k: ArgparserTest.args_dict[k] for k in keys[:half_size]}
        cmd_line_part = {k: ArgparserTest.args_dict[k] for k in keys[half_size:]}
        tmpfile_path = self._write_json(json_part)
        cmd_line_args = [ArgparserTest.script_name, tmpfile_path]
        for key, value in cmd_line_part.items():
            cmd_line_args.extend([f"--{key}", str(value)])
        with patch("sys.argv", cmd_line_args):
            model_args = vars(parse_args()[0])
        for key, value in ArgparserTest.args_dict.items():
            self.assertEqual(model_args.get(key), value)

    def test_parse_json_file_and_cmd_lines_with_conflict(self):
        # Keys present in both sources: the command line value must win.
        tmpfile_path = self._write_json(ArgparserTest.args_dict)
        cmd_line_args = [
            ArgparserTest.script_name,
            tmpfile_path,
            "--min_learning_rate",
            "2e-5",
            "--max_steps",
            "3000",
            "--log_on_each_node",
            "False",
        ]
        with patch("sys.argv", cmd_line_args):
            model_args = vars(parse_args()[0])
        overridden = {"min_learning_rate": 2e-5, "max_steps": 3000, "log_on_each_node": False}
        for key, value in overridden.items():
            self.assertEqual(model_args.get(key), value)
        for key, value in ArgparserTest.args_dict.items():
            if key not in overridden:
                self.assertEqual(model_args.get(key), value)