10 changes: 10 additions & 0 deletions llm/run_pretrain.py
@@ -46,6 +46,7 @@
)
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device


def add_start_docstrings(*docstr):
@@ -483,6 +484,15 @@ def main():
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
pass
Contributor:
What is this for?

Contributor Author:
This is an XPU optimization for the accumulate_steps > 1 case; it is used together with the Linear layers from paddle_xpu below.
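For context, a minimal sketch of the accumulate_steps > 1 training pattern this optimization targets. It is not code from this PR; the model and dataloader names are placeholders.

import paddle

# Minimal gradient-accumulation loop: several micro-batch backward passes
# accumulate into the same gradients before one optimizer step. The
# LinearConfig hook above passes this step count to paddle_xpu's Linear
# layers (per the author's reply above).
model = paddle.nn.Linear(1024, 1024)
opt = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())
accumulation_steps = 4  # training_args.gradient_accumulation_steps

for step, batch in enumerate(dataloader):  # placeholder dataloader
    loss = model(batch).mean() / accumulation_steps
    loss.backward()  # gradients accumulate in place across micro-steps
    if (step + 1) % accumulation_steps == 0:
        opt.step()
        opt.clear_grad()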


print("Final pre-training config:", config)

# Set the dtype for loading model
148 changes: 131 additions & 17 deletions paddlenlp/transformers/llama/modeling.py
@@ -413,6 +413,13 @@ def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
elif get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821

return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
except ImportError:
pass
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

if paddle.in_dynamic_mode():
@@ -582,12 +589,49 @@ def __init__(self, config):

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

ColumnParallelLinear = XPUColumnSequenceParallelLinear
RowParallelLinear = XPURowSequenceParallelLinear
except ImportError:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ( # noqa: F401
ColumnParallelLinear as XPUColumnParallelLinear,
)
from paddle_xpu.layers.nn import ( # noqa: F401
RowParallelLinear as XPURowParallelLinear,
)

ColumnParallelLinear = XPUColumnParallelLinear
RowParallelLinear = XPURowParallelLinear
except ImportError:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401

Linear = XPULinear
except ImportError:
Linear = nn.Linear
else:
Linear = nn.Linear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
@@ -619,15 +663,29 @@ def __init__(self, config):
)
else:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

def forward(self, x):
if self.fuse_attention_ffn:
# FIXME(yangjianbang): use paddle's native swiglu
if get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821

out = self.gate_up_fused_proj(x)
out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
out = self.down_proj(out)
return out
except ImportError:
gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
out = self.down_proj(F.silu(gate_out) * up_out)
return out

x = swiglu(self.gate_up_fused_proj(x))
else:
x = swiglu(self.gate_proj(x), self.up_proj(x))
@@ -688,7 +746,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() != "npu":
if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -705,12 +763,49 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

ColumnParallelLinear = XPUColumnSequenceParallelLinear
RowParallelLinear = XPURowSequenceParallelLinear
except ImportError:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ( # noqa: F401
ColumnParallelLinear as XPUColumnParallelLinear,
)
from paddle_xpu.layers.nn import ( # noqa: F401
RowParallelLinear as XPURowParallelLinear,
)

ColumnParallelLinear = XPUColumnParallelLinear
RowParallelLinear = XPURowParallelLinear
except ImportError:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401

Linear = XPULinear
except ImportError:
Linear = nn.Linear
else:
Linear = nn.Linear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
Expand Down Expand Up @@ -741,36 +836,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
gather_output=False,
)
else:
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)

else:
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.qkv_proj = Linear(
self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
@@ -784,7 +879,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
input_is_parallel=True,
)
else:
self.o_proj = nn.Linear(
self.o_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
@@ -1428,6 +1523,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values):
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
expanded_attn_mask = expanded_attn_mask.astype("float16")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
Contributor Author:
When the x/y arguments are integer scalars, paddle.where treats them as int64 tensors of shape [1] and performs a broadcast_add; see search.py for details. (A small sketch follows this thread.)

Contributor:
This looks almost the same as the npu logic above; can it be reused?

Contributor Author:
In theory it could be reused, but the npu branch hard-codes the dtype to float16, while programs on xpu may run in either float16 or bfloat16. Should we modify the npu module?

Contributor:
@SylarTiaNII could you take a look?

Contributor Author:
Following @wuhuachaocoding's suggestion, these stay as two separate if/elif branches.
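As a rough illustration of the scalar-vs-tensor point made earlier in this thread, a standalone sketch (not part of this diff) contrasting the two ways of calling paddle.where; the xpu branch builds x/y as tensors of the target dtype so the mask computation stays in float16/bfloat16:

import paddle

dtype = "bfloat16"  # xpu runs may use float16 or bfloat16
mask = paddle.to_tensor([[True, False], [False, True]])

# Python scalars: paddle.where may promote them internally (per the review
# comment, integer scalars become int64 tensors of shape [1] plus a
# broadcast_add), which can change the compute dtype.
out_scalar = paddle.where(mask, 0.0, paddle.finfo(paddle.bfloat16).min)

# Tensors pre-cast to the desired dtype keep everything in that dtype,
# mirroring the xpu branch added in this PR.
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(paddle.bfloat16).min, dtype=dtype)
out_tensor = paddle.where(mask, x, y).astype(dtype)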

else:
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
return expanded_attn_mask
@@ -1708,6 +1808,15 @@ def __init__(self, config: LlamaConfig):
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if self.weight.is_distributed:
self.weight.split_axis = 1
if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ( # noqa: F401
parallel_matmul as xpu_parallel_matmul,
)

self.xpu_parallel_matmul = xpu_parallel_matmul()
except ImportError:
self.xpu_parallel_matmul = None

def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sequence_parallel:
@@ -1721,7 +1830,12 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
logits = self.xpu_parallel_matmul(
hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
Comment on lines +1742 to +1743
Contributor:

Is the training argument really required? If the arguments could stay the same, wouldn't it be enough to simply replace the parallel_matmul implementation under xpu?

Contributor Author:
There are two reasons:

  • One XPU optimization needs parallel_matmul to be an object so it can store some state.
  • XPU needs the training information to apply its optimizations.
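To illustrate the first point, a hypothetical sketch of a stateful matmul wrapper; the class name and internals are invented here and are not the paddle_xpu API:

import paddle

class StatefulParallelMatmul:
    """Hypothetical stand-in for paddle_xpu's parallel_matmul object.

    A callable object (unlike a plain function) can keep state between
    forward calls and branch on the training flag.
    """

    def __init__(self):
        self._setup_done = False  # state preserved across calls

    def __call__(self, x, weight, tensor_parallel_output=True, training=True):
        if not self._setup_done:
            # one-time, possibly expensive setup would be cached here
            self._setup_done = True
        if training:
            # a training-specific kernel choice could go here
            return paddle.matmul(x, weight)
        return paddle.matmul(x, weight)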

)
else:
logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
return logits

