-
Notifications
You must be signed in to change notification settings - Fork 3k
Trainer support simultaneously parse JSON files and cmd arguments. #7768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
5dc3ab5
e348363
2cede49
286ca2f
6652ed6
36c50e0
d89c011
52cb70d
6321165
9f8e8d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -247,6 +247,52 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: | |
| outputs.append(obj) | ||
| return (*outputs,) | ||
|
|
||
| def parse_json_file_and_cmd_lines(self, json_file: str) -> Tuple[DataClass, ...]: | ||
| """ | ||
| Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON | ||
| file. | ||
|
|
||
| When there is a conflict between the command line arguments and the JSON file configuration, | ||
| the command line arguments will take precedence. | ||
|
|
||
| This method combines data from a JSON file and command line arguments to populate instances of dataclasses. | ||
|
|
||
| Args: | ||
| json_file : | ||
| The path to the JSON formatted file should be at index position 1 in the command line | ||
| arguments array (sys.argv[1]). | ||
| Any JSON file path at other positions will be considered invalid. | ||
|
|
||
| Returns: | ||
| Tuple consisting of: | ||
|
|
||
| - the dataclass instances in the same order as they were passed to the initializer.abspath | ||
| """ | ||
| json_args = json.loads(Path(json_file).read_text()) | ||
| del sys.argv[1] | ||
| output_dir_arg = next( | ||
| (arg for arg in sys.argv if arg == "--output_dir" or arg.startswith("--output_dir=")), None | ||
| ) | ||
| if output_dir_arg is None: | ||
| if "output_dir" in json_args.keys(): | ||
| sys.argv.extend(["--output_dir", json_args["output_dir"]]) | ||
| else: | ||
| raise ValueError("The following arguments are required: --output_dir") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为什么这里要 特判 output_dir ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果不特判output_dir,出现json文件里有output_dir参数,但是命令行里没有的情况,执行281行vars(self.parse_args())的时候就会报错,但是我们现在不希望让它报错,所以进行了output_dir的特判
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 你参考一下这个函数,这个函数也是一样的支持本地文件。 这个作为通用的parser,在这里做 |
||
| cmd_args = vars(self.parse_args()) | ||
| merged_args = {} | ||
| for key in json_args.keys() | cmd_args.keys(): | ||
| if any(arg == f"--{key}" or arg.startswith(f"--{key}=") for arg in sys.argv): | ||
| merged_args[key] = cmd_args.get(key) | ||
| elif json_args.get(key): | ||
| merged_args[key] = json_args.get(key) | ||
| outputs = [] | ||
| for dtype in self.dataclass_types: | ||
| keys = {f.name for f in dataclasses.fields(dtype) if f.init} | ||
| inputs = {k: v for k, v in merged_args.items() if k in keys} | ||
| obj = dtype(**inputs) | ||
| outputs.append(obj) | ||
| return (*outputs,) | ||
|
|
||
| def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: | ||
| """ | ||
| Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import json | ||
| import os | ||
| import sys | ||
| import tempfile | ||
| import unittest | ||
| from unittest.mock import patch | ||
|
|
||
| from llm.run_pretrain import PreTrainingArguments | ||
| from paddlenlp.trainer.argparser import PdArgumentParser | ||
|
|
||
|
|
||
| def parse_args(): | ||
| parser = PdArgumentParser((PreTrainingArguments,)) | ||
| # Support format as "args.json --args1 value1 --args2 value2.” | ||
| # In case of conflict, command line arguments take precedence. | ||
| if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): | ||
| model_args = parser.parse_json_file_and_cmd_lines(json_file=os.path.abspath(sys.argv[1])) | ||
| else: | ||
| model_args = parser.parse_args_into_dataclasses() | ||
| return model_args | ||
|
|
||
|
|
||
| def create_json_from_dict(data_dict, file_path): | ||
| with open(file_path, "w") as f: | ||
| json.dump(data_dict, f) | ||
|
|
||
|
|
||
| class ArgparserTest(unittest.TestCase): | ||
| script_name = "test_argparser.py" | ||
| args_dict = { | ||
| "max_steps": 3000, | ||
| "amp_master_grad": False, | ||
| "adam_beta1": 0.9, | ||
| "adam_beta2": 0.999, | ||
| "adam_epsilon": 1e-08, | ||
| "bf16": False, | ||
| "enable_linear_fused_grad_add": False, | ||
| "eval_steps": 3216, | ||
| "flatten_param_grads": False, | ||
| "fp16": 1, | ||
| "log_on_each_node": True, | ||
| "logging_dir": "./checkpoints/llama2_pretrain_ckpts/runs/Dec27_04-28-35_instance-047hzlt0-4", | ||
| "logging_first_step": False, | ||
| "logging_steps": 1, | ||
| "lr_end": 1e-07, | ||
| "max_evaluate_steps": -1, | ||
| "max_grad_norm": 1.0, | ||
| "min_learning_rate": 3e-06, | ||
| "no_cuda": False, | ||
| "num_cycles": 0.5, | ||
| "num_train_epochs": 3.0, | ||
| "output_dir": "./checkpoints/llama2_pretrain_ckpts", | ||
| } | ||
|
|
||
| def test_parse_cmd_lines(self): | ||
| cmd_line_args = [ArgparserTest.script_name] | ||
| for key, value in ArgparserTest.args_dict.items(): | ||
| cmd_line_args.extend([f"--{key}", str(value)]) | ||
| with patch("sys.argv", cmd_line_args): | ||
| model_args = vars(parse_args()[0]) | ||
| for key, value in ArgparserTest.args_dict.items(): | ||
| self.assertEqual(model_args.get(key), value) | ||
|
|
||
| def test_parse_json_file(self): | ||
| with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile: | ||
| create_json_from_dict(ArgparserTest.args_dict, tmpfile.name) | ||
| tmpfile_path = tmpfile.name | ||
| with patch("sys.argv", [ArgparserTest.script_name, tmpfile_path]): | ||
| model_args = vars(parse_args()[0]) | ||
| for key, value in ArgparserTest.args_dict.items(): | ||
| self.assertEqual(model_args.get(key), value) | ||
| os.remove(tmpfile_path) | ||
|
|
||
| def test_parse_json_file_and_cmd_lines(self): | ||
| half_size = len(ArgparserTest.args_dict) // 2 | ||
| json_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[:half_size]} | ||
| cmd_line_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[half_size:]} | ||
| with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile: | ||
| create_json_from_dict(json_part, tmpfile.name) | ||
| tmpfile_path = tmpfile.name | ||
| cmd_line_args = [ArgparserTest.script_name, tmpfile_path] | ||
| for key, value in cmd_line_part.items(): | ||
| cmd_line_args.extend([f"--{key}", str(value)]) | ||
| with patch("sys.argv", cmd_line_args): | ||
| model_args = vars(parse_args()[0]) | ||
| for key, value in ArgparserTest.args_dict.items(): | ||
| self.assertEqual(model_args.get(key), value) | ||
| os.remove(tmpfile_path) | ||
|
|
||
| def test_parse_json_file_and_cmd_lines_with_conflict(self): | ||
| with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile: | ||
| json.dump(ArgparserTest.args_dict, tmpfile) | ||
| tmpfile_path = tmpfile.name | ||
| cmd_line_args = [ | ||
| ArgparserTest.script_name, | ||
| tmpfile_path, | ||
| "--min_learning_rate", | ||
| "2e-5", | ||
| "--max_steps", | ||
| "3000", | ||
| "--log_on_each_node", | ||
| "False", | ||
| ] | ||
| with patch("sys.argv", cmd_line_args): | ||
| model_args = vars(parse_args()[0]) | ||
| self.assertEqual(model_args.get("min_learning_rate"), 2e-5) | ||
| self.assertEqual(model_args.get("max_steps"), 3000) | ||
| self.assertEqual(model_args.get("log_on_each_node"), False) | ||
| for key, value in ArgparserTest.args_dict.items(): | ||
| if key not in ["min_learning_rate", "max_steps", "log_on_each_node"]: | ||
| self.assertEqual(model_args.get(key), value) | ||
| os.remove(tmpfile_path) |
Uh oh!
There was an error while loading. Please reload this page.