
Commit 160c79d

[XPU] llama add xpu support

1 parent 0790824 commit 160c79d

2 files changed: 103 additions & 20 deletions

File tree

llm/run_pretrain.py
paddlenlp/transformers/llama/modeling.py


llm/run_pretrain.py

Lines changed: 6 additions & 0 deletions
@@ -483,6 +483,12 @@ def main():
             config.num_attention_heads % config.sep_parallel_degree == 0
         ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

+    if paddle.is_compiled_with_xpu() and training_args.gradient_accumulation_steps > 1:
+        from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+        LinearConfig.enable_accumulate_steps_opt()
+        LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+
     print("Final pre-training config:", config)

     # Set the dtype for loading model
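
The new block is only active on XPU builds with gradient accumulation enabled: it tells the paddle_xpu Linear layers how many micro-batches one optimizer step accumulates, so their accumulate-steps optimization can be turned on. A minimal sketch of the same wiring outside run_pretrain.py, assuming a paddle_xpu build and an illustrative accumulation count of 4:

import paddle

# Sketch only: enable the XPU accumulate-steps optimization when the
# paddle_xpu build is present. The step count (4) stands in for
# training_args.gradient_accumulation_steps.
if paddle.is_compiled_with_xpu():
    from paddle_xpu.layers.nn.linear import LinearConfig

    LinearConfig.enable_accumulate_steps_opt()
    LinearConfig.set_accumulate_steps(4)
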

paddlenlp/transformers/llama/modeling.py

Lines changed: 97 additions & 20 deletions
@@ -174,7 +174,7 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list


-def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
+def parallel_matmul(matmul_op, x: Tensor, y: Tensor, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
     try:
@@ -192,15 +192,15 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
     if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
-        logits = paddle.matmul(input_parallel, y, transpose_y=False)
+        logits = matmul_op(input_parallel, y, transpose_y=False)

         if tensor_parallel_output:
             return logits

         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

     else:
-        logits = paddle.matmul(x, y, transpose_y=False)
+        logits = matmul_op(x, y, transpose_y=False)
         return logits

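parallel_matmul now takes the matmul callable as its first argument instead of hard-coding paddle.matmul, so the XPU build can substitute its own matmul while every other build keeps the default. A sketch of the new call shape in a single-process setting (shapes are illustrative; paddle.matmul stands in for whatever op the caller selects):

import paddle
from paddlenlp.transformers.llama.modeling import parallel_matmul

hidden = paddle.randn([2, 8, 64])    # [batch, seq, hidden]
weight = paddle.randn([64, 128])     # lm_head weight, [hidden, vocab]

# Non-XPU callers pass paddle.matmul; the XPU path passes xpu_matmul().forward.
logits = parallel_matmul(paddle.matmul, hidden, weight, tensor_parallel_output=False)
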
@@ -413,6 +413,10 @@ def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
             if get_env_device() == "npu":
                 return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
+            elif get_env_device() == "xpu":
+                import paddle_xpu_nn
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

         if paddle.in_dynamic_mode():
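
On XPU, the fused path now dispatches to paddle_xpu_nn.xpu_rms_norm, mirroring the existing rms_norm_npu custom op. As a reference for what the fused kernel computes, an unfused RMSNorm in plain Paddle (standard formulation, not code from this commit):

import paddle

def rms_norm_reference(hidden_states, weight, epsilon):
    # Normalize by the root-mean-square over the last axis (in float32 for
    # stability), then apply the learned scale; fused rms_norm ops compute the same.
    dtype = hidden_states.dtype
    x = hidden_states.astype("float32")
    variance = x.pow(2).mean(axis=-1, keepdim=True)
    x = x * paddle.rsqrt(variance + epsilon)
    return x.astype(dtype) * weight
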
@@ -582,12 +586,33 @@ def __init__(self, config):

                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                import paddle_xpu  # noqa: F821
+
+                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear  # noqa: F821
+                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear  # noqa: F821
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            import paddle_xpu  # noqa: F821
+
+            Linear = paddle_xpu.layers.nn.Linear  # noqa: F821
+        else:
+            Linear = nn.Linear

         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
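
The MLP now resolves its linear-layer classes per device up front: with sequence parallelism the XPU build takes paddle_xpu's sequence-parallel classes, with plain tensor parallelism it takes paddle_xpu's ColumnParallelLinear/RowParallelLinear, and all non-parallel projections go through a device-selected Linear alias instead of nn.Linear. The alias selection, reduced to its core (probing via paddle.is_compiled_with_xpu() here is illustrative; the diff keys off get_env_device()):

import paddle
import paddle.nn as nn

# Sketch: route plain projections through paddle_xpu's Linear on XPU builds,
# and fall back to paddle.nn.Linear everywhere else.
if paddle.is_compiled_with_xpu():
    import paddle_xpu

    Linear = paddle_xpu.layers.nn.Linear
else:
    Linear = nn.Linear

# Projections are then created through the alias, e.g.:
# self.gate_proj = Linear(hidden_size, intermediate_size, bias_attr=False)
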
@@ -619,15 +644,24 @@ def __init__(self, config):
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

-            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+            self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

     def forward(self, x):
         if self.fuse_attention_ffn:
+            # FIXME(yangjianbang): use paddle's native swiglu
+            if get_env_device() == "xpu":
+                import paddle_xpu_nn  # noqa: F821
+
+                out = self.gate_up_fused_proj(x)
+                out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
+                out = self.down_proj(out)
+                return out
+
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
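
In the fused-FFN forward, the XPU branch feeds the concatenated gate/up projection to paddle_xpu_nn.xpu_swiglu and then to down_proj, instead of the generic swiglu helper. For reference, a plain-Paddle SwiGLU over a fused gate/up output (the split order implied by xpu_swiglu's turn=True argument is not documented in this commit, so the ordering below is an assumption):

import paddle
import paddle.nn.functional as F

def swiglu_reference(gate_up):
    # Split the fused projection into gate and up halves along the last axis,
    # apply SiLU to the gate, and scale the up half with it.
    gate, up = paddle.chunk(gate_up, chunks=2, axis=-1)
    return F.silu(gate) * up
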
@@ -689,7 +723,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

         self.use_fused_rope = config.use_fused_rope
         if self.use_fused_rope and get_env_device() != "npu":
-            if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
+            if (
+                "gpu" not in paddle.device.get_device()
+                or "xpu" not in paddle.device.get_device()
+                or fused_rotary_position_embedding is None
+            ):
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
                     "Will disable fuse rope. Try using latest gpu version of Paddle."
@@ -705,12 +743,33 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                import paddle_xpu  # noqa: F821
+
+                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear  # noqa: F821
+                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear  # noqa: F821
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            import paddle_xpu  # noqa: F821
+
+            Linear = paddle_xpu.layers.nn.Linear  # noqa: F821
+        else:
+            Linear = nn.Linear

         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
@@ -741,36 +800,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                         gather_output=False,
                     )
                 else:
-                    self.k_proj = nn.Linear(
+                    self.k_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
-                    self.v_proj = nn.Linear(
+                    self.v_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )

         else:
             if self.fuse_attention_qkv:
-                self.qkv_proj = nn.Linear(
+                self.qkv_proj = Linear(
                     self.hidden_size,
                     self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
-                self.q_proj = nn.Linear(
+                self.q_proj = Linear(
                     self.hidden_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
@@ -784,7 +843,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                 input_is_parallel=True,
             )
         else:
-            self.o_proj = nn.Linear(
+            self.o_proj = Linear(
                 self.hidden_size,
                 self.hidden_size,
                 bias_attr=False,
@@ -1428,6 +1487,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
             expanded_attn_mask = expanded_attn_mask.astype("float16")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
         else:
             expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
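
The XPU branch builds the additive mask by selecting between explicit 0 and finfo(dtype).min tensors of the model dtype, rather than passing Python scalars to paddle.where as the default branch does. A self-contained sketch of that construction (shapes and dtype are illustrative):

import paddle

dtype = paddle.float32                            # stands in for the model dtype
keep = paddle.to_tensor([[True, True, False]])    # True = attend, False = masked

zero = paddle.to_tensor(0.0, dtype=dtype)
neg_inf = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
# 0 where attention is allowed, a large negative value where it is masked.
additive_mask = paddle.where(keep, zero, neg_inf)
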
@@ -1708,6 +1772,13 @@ def __init__(self, config: LlamaConfig):
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
         if self.weight.is_distributed:
             self.weight.split_axis = 1
+        if paddle.is_compiled_with_xpu():
+            from paddle_xpu.layers.nn import xpu_matmul  # noqa: F401
+
+            self._xpu_matmul = xpu_matmul()
+            self.matmul_op = self._xpu_matmul.forward
+        else:
+            self.matmul_op = paddle.matmul

     def forward(self, hidden_states, tensor_parallel_output=None):
         if self.config.sequence_parallel:
@@ -1721,7 +1792,13 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output

-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        matmul_op = self.matmul_op
+        if paddle.is_compiled_with_xpu():
+            from functools import partial
+
+            matmul_op = partial(matmul_op, training=self.training)
+
+        logits = parallel_matmul(matmul_op, hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits

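LlamaLMHead now caches the matmul callable at construction (xpu_matmul().forward on XPU builds, paddle.matmul otherwise) and, on XPU, binds the module's training flag with functools.partial before handing it to parallel_matmul. The binding pattern on its own, with a stand-in for the XPU op (plain paddle.matmul takes no training argument, which is why the partial is only applied on the XPU path):

from functools import partial

import paddle

def matmul_with_training(x, y, transpose_y=False, training=False):
    # Stand-in for paddle_xpu's xpu_matmul().forward, which per this commit
    # accepts an extra training flag; the flag is ignored here.
    return paddle.matmul(x, y, transpose_y=transpose_y)

# Bind the training state once, then use it like a two-operand matmul,
# the way LlamaLMHead.forward passes it into parallel_matmul.
matmul_op = partial(matmul_with_training, training=True)
logits = matmul_op(paddle.randn([4, 8]), paddle.randn([8, 16]), transpose_y=False)
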