@@ -317,7 +317,7 @@ def forward(
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -329,14 +329,13 @@ def forward(
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
+                    out=output[:num_prefill_tokens],
                 )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:num_prefill_tokens] = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -348,11 +347,12 @@ def forward(
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
+                    out=output[:num_prefill_tokens],
                 )

         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = flash_attn_with_kvcache(
+            flash_attn_with_kvcache(
                 decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
@@ -361,7 +361,8 @@ def forward(
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
-            ).squeeze(1)
+                out=output[num_prefill_tokens:].unsqueeze(1),
+            )

         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
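The change routes each kernel's result directly into the preallocated `output` buffer via an `out=` argument, rather than materializing a temporary tensor and copying it in. A minimal sketch of the pattern, with `torch.matmul` standing in for the attention kernels (the `out=` semantics of the flash-attn functions are assumed from the diff, not verified against the library):

import torch

# Hypothetical stand-in: torch.matmul plays the role of flash_attn_varlen_func /
# flash_attn_with_kvcache, which per this diff accept an analogous `out=` argument.
num_prefill_tokens, num_decode_tokens, hidden_size = 4, 2, 8
num_tokens = num_prefill_tokens + num_decode_tokens
output = torch.empty(num_tokens, hidden_size)
query = torch.randn(num_prefill_tokens, hidden_size)
weight = torch.randn(hidden_size, hidden_size)

# Old pattern: the kernel allocates a fresh result tensor, which is then
# copied element-wise into the output buffer.
out = torch.matmul(query, weight)
output[:num_prefill_tokens] = out

# New pattern: basic slicing returns a view, so the kernel writes its result
# straight into `output` with no intermediate allocation or copy.
torch.matmul(query, weight, out=output[:num_prefill_tokens])

# unsqueeze(1) also returns a view sharing the same storage, which is why the
# decode branch can pass output[num_prefill_tokens:].unsqueeze(1) as `out`
# and drop the old .squeeze(1).
view = output[num_prefill_tokens:].unsqueeze(1)   # shape (2, 1, 8), a view
view.fill_(0.0)
assert (output[num_prefill_tokens:] == 0).all()   # writes landed in `output`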