Merged
Changes from 6 commits
6 changes: 6 additions & 0 deletions llm/run_finetune.py
@@ -57,6 +57,7 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.trl import SFTTrainer
from paddlenlp.trl.llm_utils import (
ZeroPaddingIterDatasetCallback,
@@ -143,6 +144,11 @@ def main():
)

LlmMetaConfig.set_llm_config(model_config, training_args)
model_config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
training_args.sequence_parallel,
model_args.lora,
)
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

# Config for model using dropout, such as GPT.
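For orientation, here is a minimal sketch of what `update_refined_recompute` could do with these arguments. It assumes the option is a comma-separated list of op names, optionally with an `op:num_layers` count, and that the `*_ln` ops only apply under sequence parallelism; the real helper in `paddlenlp/transformers/refined_recompute.py` may parse and validate differently.

```python
# Hypothetical sketch, not the implementation from this PR.
SUPPORTED_OPS = {"mlp_row_ln", "mlp_column_ln", "attention_row_ln", "attention_column_ln", "flash_attn"}

def update_refined_recompute(spec: str, sequence_parallel: bool = False, lora: bool = False) -> dict:
    """Turn e.g. "flash_attn,mlp_row_ln:4" into {"flash_attn": -1, "mlp_row_ln": 4}.

    Assumptions: -1 means "every layer"; the *_ln ops are dropped when sequence
    parallelism is off because they replace sequence-parallel linear layers; the
    lora flag only mirrors the run_finetune.py call site and is unused here.
    """
    skip_ops = {}
    if not spec:
        return skip_ops
    for item in spec.split(","):
        name, _, count = item.strip().partition(":")
        if name not in SUPPORTED_OPS:
            raise ValueError(f"unknown refined_recompute op: {name!r}")
        if name.endswith("_ln") and not sequence_parallel:
            continue
        skip_ops[name] = int(count) if count else -1
    return skip_ops
```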
5 changes: 5 additions & 0 deletions llm/run_pretrain.py
@@ -44,6 +44,7 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device
@@ -413,6 +414,10 @@ def main():
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
# set all llm config
LlmMetaConfig.set_llm_config(config, training_args)
config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
training_args.sequence_parallel,
)
config.use_fast_layer_norm = model_args.use_fast_layer_norm

config.seq_length = data_args.max_seq_length
8 changes: 8 additions & 0 deletions paddlenlp/transformers/configuration_utils.py
@@ -268,6 +268,14 @@ class LlmMetaConfig:
"Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']",
),
("recompute_use_reentrant", bool, False, "recompute_use_reentrant"),
Contributor

Will this configuration information be passed down to downstream tasks?

Contributor

Do we need _set_unsavable_keys here?

Member Author

No need. zhonghui knows this usage best; I checked the implementation and it meets the requirement: first, llmmetaclass is added, and second, LlmMetaConfig.set_llm_config(model_config, training_args) is called:
@dataclass
@llmmetaclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):

# refined_recompute attributes
(
"refined_recompute",
str,
"",
"refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']",
),
("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"),
]

@classmethod
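For context, the llm entry scripts pick up these fields because their TrainingArguments class is decorated with @llmmetaclass, as in the author's reply above. Below is a hypothetical, hedged example of setting the new option; the accepted value syntax is an assumption, not something this diff shows.

```python
# Hypothetical sketch: assumes an @llmmetaclass-decorated TrainingArguments subclass
# (as in the author's reply above) so that refined_recompute exists as a field, and
# assumes a comma-separated op list as the value syntax.
from dataclasses import dataclass

from paddlenlp.trainer import TrainingArguments as PdTrainingArguments
from paddlenlp.transformers.configuration_utils import llmmetaclass


@dataclass
@llmmetaclass
class TrainingArguments(PdTrainingArguments):
    pass


args = TrainingArguments(
    output_dir="./checkpoints",
    recompute=True,  # refined_recompute is only consulted when recompute is enabled
    refined_recompute="attention_column_ln,attention_row_ln,flash_attn",
)
print(args.refined_recompute)
```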
14 changes: 11 additions & 3 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -51,6 +51,7 @@ def swiglu(x, y=None):
except:
flash_attention = None

from paddlenlp.transformers.refined_recompute import no_recompute
Contributor

Why is it called no_recompute? It feels a bit odd.

Member Author

Changing it to skip_recompute would also work.

Member Author

recompute(func, xxxxx) vs no_recompute(func, xxxxxx)

from paddlenlp.transformers.ring_flash_attention import RingFlashAttention


@@ -174,6 +175,7 @@ def fusion_flash_attention(
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
skip_recompute=False,
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape
@@ -257,28 +259,34 @@
attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)

if hasattr(F, "flashmask_attention"):
attn_output = F.flashmask_attention(
attn_output = no_recompute(
F.flashmask_attention,
query_states,
key_states,
value_states,
startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
causal=True,
enable=skip_recompute,
)
else:
attn_output = F.flash_attention_with_sparse_mask(
attn_output = no_recompute(
F.flash_attention_with_sparse_mask,
query_states,
key_states,
value_states,
attn_mask_start_row_indices=attn_mask_startend_row_indices,
is_causal=True,
enable=skip_recompute,
)
else:
attn_output = F.scaled_dot_product_attention(
attn_output = no_recompute(
F.scaled_dot_product_attention,
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=query_states.shape[1] != 1,
enable=skip_recompute,
)
attn_weights = None

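Reading the call sites above: no_recompute(func, *args, enable=..., **kwargs) wraps an op that sits inside a recomputed layer; with enable=False it is intended to behave like a direct call, and with enable=True the op's forward output is meant to be kept so it is not re-executed during the backward recompute pass. A small, hedged illustration of the call shape follows; whether the wrapper accepts an arbitrary callable like matmul outside a recompute region is an assumption made for the demo.

```python
import paddle
from paddlenlp.transformers.refined_recompute import no_recompute

x = paddle.randn([4, 16])
w = paddle.randn([16, 16])

# Ordinary call: the op would be recomputed together with its layer.
y_plain = paddle.matmul(x, w)

# Same call routed through the wrapper. `enable` comes from the per-layer
# skip_recompute flag: False keeps today's behaviour, True asks the refined
# recompute machinery to keep this op's output out of the recompute region.
y_wrapped = no_recompute(paddle.matmul, x, w, enable=False)

assert y_plain.shape == y_wrapped.shape
```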
32 changes: 30 additions & 2 deletions paddlenlp/transformers/llama/modeling.py
@@ -28,7 +28,13 @@
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.refined_recompute import (
RRColumnSequenceParallelLinear,
RRRowSequenceParallelLinear,
create_skip_config_for_refined_recompute,
recompute,
)

try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
@@ -215,6 +221,7 @@ def scaled_dot_product_attention(
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
skip_recompute=False,
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape
@@ -232,6 +239,7 @@
sequence_parallel,
reshard_layer,
npu_is_casual,
skip_recompute=skip_recompute,
)

# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
@@ -604,6 +612,12 @@ def __init__(self, config):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

if config.recompute:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
@@ -718,6 +732,12 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

if config.recompute:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
@@ -820,6 +840,9 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

self.attn_func = scaled_dot_product_attention

if config.recompute and config.skip_recompute_ops.get("flash_attn", False):
self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True)

def _init_rope(self):
if (
hasattr(self.config, "rope_scaling")
@@ -1470,7 +1493,12 @@ def __init__(self, config: LlamaConfig):
)

self.layers = nn.LayerList(
[LlamaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)]
[
LlamaDecoderLayer(
create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers
)
for i in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config)

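To make the per-layer wiring concrete, here is a hedged sketch of what create_skip_config_for_refined_recompute(i, config) could do, assuming config.refined_recompute holds the {op_name: num_layers} mapping produced earlier (with -1 meaning all layers) and that skip_recompute_ops ends up as the per-layer {op_name: bool} dict that the __init__ code above queries with .get(op, False). The actual helper may differ.

```python
import copy

# Hypothetical sketch; the real helper lives in paddlenlp/transformers/refined_recompute.py.
def create_skip_config_for_refined_recompute(layer_idx, config):
    """Return a copy of `config` whose skip_recompute_ops says, for this layer,
    which ops should be excluded from recomputation."""
    layer_config = copy.deepcopy(config)
    skip_ops = {}
    for op_name, num_layers in (getattr(config, "refined_recompute", None) or {}).items():
        # -1 is assumed to mean "every layer"; otherwise only the first num_layers layers.
        skip_ops[op_name] = num_layers == -1 or layer_idx < num_layers
    layer_config.skip_recompute_ops = skip_ops
    return layer_config
```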
9 changes: 8 additions & 1 deletion paddlenlp/transformers/llama/modeling_pp.py
@@ -25,6 +25,9 @@
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import (
create_skip_config_for_refined_recompute,
)
from paddlenlp.utils.tools import get_env_device

from .modeling import (
@@ -371,7 +374,11 @@ def get_hcg():

for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers),
LayerDesc(
LlamaDecoderLayerPipe,
config=create_skip_config_for_refined_recompute(i, config),
layerwise_recompute=i not in self.no_recompute_layers,
),
f"llama.layers.{i}",
)
self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama")