@@ -35,15 +35,20 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)
 
         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)
 
+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
@@ -53,6 +58,7 @@ def __init__(
         )
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx
 
         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
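The width switch introduced above can be checked in isolation. A minimal sketch (an illustrative helper, not part of the patch) of the rule that `qkv_input_size` encodes:

def qkv_input_size(hidden_size: int, layer_idx: int) -> int:
    # Layer 0 attends over the concatenation [embeds ; hidden_states],
    # so its QKV projection sees 2 * hidden_size input features;
    # every later draft layer only consumes hidden_states.
    return 2 * hidden_size if layer_idx == 0 else hidden_size

# e.g. with hidden_size = 4096:
assert [qkv_input_size(4096, i) for i in range(3)] == [8192, 4096, 4096]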
@@ -91,11 +97,15 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
-
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
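A toy sketch of the branch above in plain PyTorch (LayerNorm standing in for vLLM's RMSNorm, and the fused `input_layernorm(hidden_states, residual)` call approximated by an explicit add-then-normalize); shapes and residual handling are simplified placeholders:

import torch

hidden_size = 8
norm = torch.nn.LayerNorm(hidden_size)  # stand-in for RMSNorm
embeds = torch.randn(2, hidden_size)
hidden_states = torch.randn(2, hidden_size)
residual = None

for layer_idx in range(2):
    if layer_idx == 0:
        # First layer: keep hidden_states as the residual and attend over the
        # concatenation of the normalized embeds and hidden_states.
        residual = hidden_states
        attn_input = torch.cat([norm(embeds), norm(hidden_states)], dim=-1)
    else:
        # Later layers: usual pre-norm path, residual carried from the previous layer.
        hidden_states = hidden_states + residual
        residual = hidden_states
        attn_input = norm(hidden_states)
    print(layer_idx, attn_input.shape)  # 0 -> [2, 16], 1 -> [2, 8]

The widths match the QKV input sizes chosen in __init__: only layer 0 feeds a 2 * hidden_size tensor into attention.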
@@ -134,9 +144,11 @@ def __init__(
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
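Illustrative only (toy class, made-up numbers): how the comprehension above hands each draft layer its own `layer_idx` and weight prefix, with `start_layer_id` offsetting past the target model's layers:

class ToyLayer:
    def __init__(self, prefix: str, layer_idx: int):
        self.prefix, self.layer_idx = prefix, layer_idx

num_hidden_layers, start_layer_id = 3, 32
layers = [
    ToyLayer(prefix=f"layers.{layer_idx + start_layer_id}", layer_idx=layer_idx)
    for layer_idx in range(num_hidden_layers)
]
assert [layer.prefix for layer in layers] == ["layers.32", "layers.33", "layers.34"]
assert [layer.layer_idx for layer in layers] == [0, 1, 2]

Previously the ModuleList held a single layer named layers.{start_layer_id}; with num_hidden_layers == 1 the comprehension reduces to exactly that case.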
@@ -167,13 +179,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm
 
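A hedged sketch of the new loop with stand-in layers (numbers are arbitrary): residual threads from one layer to the next, and embeds is passed to every layer even though, per the decoder-layer change above, only layer 0 consumes it:

class ToyDecoderLayer:
    def __init__(self, layer_idx: int):
        self.layer_idx = layer_idx

    def __call__(self, positions, embeds, hidden_states, residual):
        # Stand-in for attention + MLP: only layer 0 mixes in the embeds.
        out = hidden_states + (embeds if self.layer_idx == 0 else 0.0)
        return out, hidden_states  # (new hidden_states, new residual)

layers = [ToyDecoderLayer(i) for i in range(3)]
hidden_states, residual = 1.0, None
for layer in layers:
    hidden_states, residual = layer(
        positions=None, embeds=0.5, hidden_states=hidden_states, residual=residual
    )
print(hidden_states, residual)  # 1.5 1.5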