@@ -682,10 +682,6 @@ def forward_absorb(
682682 forward_batch : ForwardBatch ,
683683 zero_allocator : BumpAllocator ,
684684 ) -> torch .Tensor :
685- q_len = hidden_states .shape [0 ]
686- q_input = hidden_states .new_empty (
687- q_len , self .num_local_heads , self .kv_lora_rank + self .qk_rope_head_dim
688- )
689685 if self .q_lora_rank is not None :
690686 q = self .q_a_proj (hidden_states )[0 ]
691687 q = self .q_a_layernorm (q )
@@ -729,20 +725,20 @@ def forward_absorb(
729725 )
730726 else :
731727 q_nope_out = torch .bmm (q_nope .transpose (0 , 1 ), self .w_kc )
732- q_input [..., : self .kv_lora_rank ] = q_nope_out .transpose (0 , 1 )
728+
729+ q_nope_out = q_nope_out .transpose (0 , 1 )
733730
734731 latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
735- v_input = latent_cache [..., : self .kv_lora_rank ]
736- v_input = self .kv_a_layernorm (v_input .contiguous ()).unsqueeze (1 )
737- k_input = latent_cache .unsqueeze (1 )
738- k_input [..., : self .kv_lora_rank ] = v_input
739- k_pe = k_input [..., self .kv_lora_rank :]
732+ k_nope = latent_cache [..., : self .kv_lora_rank ]
733+ k_nope = self .kv_a_layernorm (k_nope ).unsqueeze (1 )
734+ k_pe = latent_cache [..., self .kv_lora_rank :].unsqueeze (1 )
740735
741736 q_pe , k_pe = self .rotary_emb (positions , q_pe , k_pe )
742- q_input [..., self .kv_lora_rank :] = q_pe
743- k_input [..., self .kv_lora_rank :] = k_pe
744737
745- attn_output = self .attn_mqa (q_input , k_input , v_input , forward_batch )
738+ q = torch .cat ([q_nope_out , q_pe ], dim = - 1 )
739+ k = torch .cat ([k_nope , k_pe ], dim = - 1 )
740+
741+ attn_output = self .attn_mqa (q , k , k_nope , forward_batch )
746742 attn_output = attn_output .view (- 1 , self .num_local_heads , self .kv_lora_rank )
747743
748744 if self .use_deep_gemm_bmm :
0 commit comments