@@ -174,7 +174,7 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list
 
 
-def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
+def parallel_matmul(matmul_op, x: Tensor, y: Tensor, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
     try:
@@ -192,15 +192,15 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
     if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
-        logits = paddle.matmul(input_parallel, y, transpose_y=False)
+        logits = matmul_op(input_parallel, y, transpose_y=False)
 
         if tensor_parallel_output:
             return logits
 
         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)
 
     else:
-        logits = paddle.matmul(x, y, transpose_y=False)
+        logits = matmul_op(x, y, transpose_y=False)
         return logits
 
 
@@ -413,6 +413,10 @@ def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
             if get_env_device() == "npu":
                 return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
+            elif get_env_device() == "xpu":
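+                # On XPU, use the fused RMSNorm kernel provided by paddle_xpu_nn.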
+                import paddle_xpu_nn
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
 
         if paddle.in_dynamic_mode():
@@ -582,12 +586,33 @@ def __init__(self, config):
 
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                import paddle_xpu  # noqa: F821
+
+                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear  # noqa: F821
+                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear  # noqa: F821
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
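+        # Choose the Linear implementation up front; the projection layers below use this alias instead of nn.Linear directly.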
+        if get_env_device() == "xpu":
+            import paddle_xpu  # noqa: F821
+
+            Linear = paddle_xpu.layers.nn.Linear  # noqa: F821
+        else:
+            Linear = nn.Linear
 
         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
@@ -619,15 +644,24 @@ def __init__(self, config):
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
 
-            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+            self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
 
     def forward(self, x):
         if self.fuse_attention_ffn:
+            # FIXME(yangjianbang): use paddle's native swiglu
+            if get_env_device() == "xpu":
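+                # Feed the fused gate/up projection through paddle_xpu_nn's SwiGLU kernel instead of the generic swiglu helper.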
+                import paddle_xpu_nn  # noqa: F821
+
+                out = self.gate_up_fused_proj(x)
+                out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
+                out = self.down_proj(out)
+                return out
+
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
@@ -689,7 +723,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
         self.use_fused_rope = config.use_fused_rope
         if self.use_fused_rope and get_env_device() != "npu":
-            if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
+            if (
+                "gpu" not in paddle.device.get_device()
+                and "xpu" not in paddle.device.get_device()
+                or fused_rotary_position_embedding is None
+            ):
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
                     "Will disable fuse rope. Try using latest gpu version of Paddle."
@@ -705,12 +743,33 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                import paddle_xpu  # noqa: F821
+
+                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear  # noqa: F821
+                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear  # noqa: F821
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            import paddle_xpu  # noqa: F821
+
+            Linear = paddle_xpu.layers.nn.Linear  # noqa: F821
+        else:
+            Linear = nn.Linear
 
         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
@@ -741,36 +800,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                         gather_output=False,
                     )
                 else:
-                    self.k_proj = nn.Linear(
+                    self.k_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
-                    self.v_proj = nn.Linear(
+                    self.v_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
 
         else:
             if self.fuse_attention_qkv:
-                self.qkv_proj = nn.Linear(
+                self.qkv_proj = Linear(
                     self.hidden_size,
                     self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
-                self.q_proj = nn.Linear(
+                self.q_proj = Linear(
                     self.hidden_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
@@ -784,7 +843,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                 input_is_parallel=True,
             )
         else:
-            self.o_proj = nn.Linear(
+            self.o_proj = Linear(
                 self.hidden_size,
                 self.hidden_size,
                 bias_attr=False,
@@ -1428,6 +1487,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
             expanded_attn_mask = expanded_attn_mask.astype("float16")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
+        elif get_env_device() == "xpu":
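+            # XPU: build the additive mask from tensor operands in the working dtype, mirroring the NPU branch above.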
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
         else:
             expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
@@ -1708,6 +1772,13 @@ def __init__(self, config: LlamaConfig):
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
         if self.weight.is_distributed:
             self.weight.split_axis = 1
+        if paddle.is_compiled_with_xpu():
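+            # Cache an XPU-specific matmul callable at init; forward() hands it to parallel_matmul.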
+            from paddle_xpu.layers.nn import xpu_matmul  # noqa: F401
+
+            self._xpu_matmul = xpu_matmul()
+            self.matmul_op = self._xpu_matmul.forward
+        else:
+            self.matmul_op = paddle.matmul
 
     def forward(self, hidden_states, tensor_parallel_output=None):
         if self.config.sequence_parallel:
@@ -1721,7 +1792,13 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output
 
-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        matmul_op = self.matmul_op
+        if paddle.is_compiled_with_xpu():
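+            # The XPU matmul wrapper also takes the layer's training flag, so bind it before the shared parallel_matmul call.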
+            from functools import partial
+
+            matmul_op = partial(matmul_op, training=self.training)
+
+        logits = parallel_matmul(matmul_op, hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits
 
 