20 changes: 7 additions & 13 deletions llm/auto_parallel/gpt-3/run_pretrain_auto.py
@@ -26,7 +26,12 @@
import paddle.distributed as dist

from paddlenlp.ops import Topology
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.trainer import (
AutoTrainingArguments,
PdArgumentParser,
TrainingArguments,
get_last_checkpoint,
)
from paddlenlp.trainer.auto_trainer import AutoTrainer
from paddlenlp.trainer.trainer_utils import IntervalStrategy, _get_distributed_seeds
from paddlenlp.transformers import (
@@ -60,7 +65,7 @@ def docstring_decorator(fn):

@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class PreTrainingArguments(TrainingArguments):
class PreTrainingArguments(AutoTrainingArguments):
min_learning_rate: float = field(
default=1e-5,
metadata={"help": "Minimum learning rate deacyed to."},
@@ -77,12 +82,6 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
},
)
fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -124,11 +123,6 @@ def __post_init__(self):
self.save_strategy = IntervalStrategy.NO
self.evaluation_strategy = IntervalStrategy.NO

if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

logger.info(self.strategy)


40 changes: 7 additions & 33 deletions llm/auto_parallel/llama/run_pretrain_auto.py
@@ -28,7 +28,12 @@
from paddle.distributed import fleet

from paddlenlp.ops import Topology
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.trainer import (
AutoTrainingArguments,
PdArgumentParser,
TrainingArguments,
get_last_checkpoint,
)
from paddlenlp.trainer.auto_trainer import AutoTrainer
from paddlenlp.trainer.trainer_utils import IntervalStrategy, _get_distributed_seeds
from paddlenlp.transformers import (
@@ -63,7 +68,7 @@ def docstring_decorator(fn):

@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class PreTrainingArguments(TrainingArguments):
class PreTrainingArguments(AutoTrainingArguments):
min_learning_rate: float = field(
default=1e-5,
metadata={"help": "Minimum learning rate deacyed to."},
@@ -80,22 +85,6 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
},
)
fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
fuse_allreduce_split_to_reducescatter: bool = field(
default=False,
metadata={"help": "Enable fuse_allreduce_split_to_reducescatter pass."},
)
eliminate_transpose: bool = field(
default=False,
metadata={
"help": "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -137,21 +126,6 @@ def __post_init__(self):
self.save_strategy = IntervalStrategy.NO
self.evaluation_strategy = IntervalStrategy.NO

if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

if self.fuse_allreduce_split_to_reducescatter:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fuse_allreduce_split_to_reducescatter_pass")

if self.eliminate_transpose:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("eliminate_transpose")

logger.info(self.strategy)


14 changes: 3 additions & 11 deletions llm/auto_parallel/llama/run_pretrain_auto_static.py
@@ -32,6 +32,7 @@

from paddlenlp.ops import Topology
from paddlenlp.trainer import (
AutoTrainingArguments,
PdArgumentParser,
Trainer,
TrainingArguments,
@@ -88,7 +89,7 @@ def exec_mode_guard():

@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class PreTrainingArguments(TrainingArguments):
class PreTrainingArguments(AutoTrainingArguments):
min_learning_rate: float = field(
default=1e-5,
metadata={"help": "Minimum learning rate deacyed to."},
@@ -99,12 +100,6 @@ class PreTrainingArguments(TrainingArguments):
"help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate."
},
)
fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -127,10 +122,7 @@ class PreTrainingArguments(TrainingArguments):
def __post_init__(self):
super().__post_init__()
assert self.enable_auto_parallel
if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

logger.info(self.strategy)


20 changes: 7 additions & 13 deletions llm/auto_parallel/qwen/run_pretrain_3D_auto.py
@@ -27,7 +27,12 @@
import paddle.distributed as dist
from paddle.distributed import fleet

from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.trainer import (
AutoTrainingArguments,
PdArgumentParser,
TrainingArguments,
get_last_checkpoint,
)
from paddlenlp.trainer.auto_trainer import AutoTrainer
from paddlenlp.trainer.trainer_utils import IntervalStrategy
from paddlenlp.transformers import (
@@ -61,7 +66,7 @@ def docstring_decorator(fn):

@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
Contributor:

Suggested change
-@add_start_docstrings(TrainingArguments.__doc__)
+@add_start_docstrings(AutoTrainingArguments.__doc__)

Try changing it like this and check whether the parameter hints still show up correctly.

Contributor (Author):

Done
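For context on the suggestion above: add_start_docstrings prepends the given docstrings to the decorated class, which is why passing AutoTrainingArguments.__doc__ surfaces the new fields in the generated argument help. A minimal sketch of that pattern, assuming the helper in paddlenlp.trainer.utils follows the usual implementation:

# Sketch only; the real helper lives in paddlenlp.trainer.utils and may differ.
def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
        # Prepend the parent docstring(s) so the decorated dataclass inherits
        # the documented argument list used for help output.
        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn
    return docstring_decorator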

class PreTrainingArguments(TrainingArguments):
class PreTrainingArguments(AutoTrainingArguments):
min_learning_rate: float = field(
default=1e-5,
metadata={"help": "Minimum learning rate deacyed to."},
@@ -78,12 +83,6 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
},
)
fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -133,11 +132,6 @@ def __post_init__(self):
self.save_strategy = IntervalStrategy.NO
self.evaluation_strategy = IntervalStrategy.NO

if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

logger.info(self.strategy)


1 change: 1 addition & 0 deletions paddlenlp/trainer/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from .argparser import *
from .auto_training_args import *
from .compression_args import *
from .plugins.timer import *
from .trainer import *
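A note on the wildcard import above: from .auto_training_args import * re-exports the names listed in that module's __all__, or every public name if __all__ is absent. A one-line sketch of the declaration the new module would carry under that convention (assumed; it is not shown in this diff):

# In paddlenlp/trainer/auto_training_args.py (assumed, not part of this diff):
__all__ = ["AutoTrainingArguments"]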
62 changes: 62 additions & 0 deletions paddlenlp/trainer/auto_training_args.py
@@ -0,0 +1,62 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from .training_args import TrainingArguments
from .utils import add_start_docstrings


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class AutoTrainingArguments(TrainingArguments):
"""
Arguments for configuring auto-parallel training.
"""
Contributor:

Code in the main repository should use English comments.

Contributor (Author):

Done


fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
fuse_allreduce_split_to_reducescatter: bool = field(
default=False,
metadata={"help": "Enable fuse_allreduce_split_to_reducescatter pass."},
)
eliminate_transpose: bool = field(
default=False,
metadata={
"help": "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
},
)

def __post_init__(self):
super().__post_init__()
assert self.enable_auto_parallel

if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

if self.fuse_allreduce_split_to_reducescatter:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fuse_allreduce_split_to_reducescatter_pass")

if self.eliminate_transpose:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("eliminate_transpose")