@@ -463,7 +463,12 @@ class QwenVLOnlineEagle3Model(Eagle3Model):
     """

     def __init__(
-        self, target_model, draft_model: Eagle3DraftModel, processor, length: int = 7
+        self,
+        target_model,
+        draft_model: Eagle3DraftModel,
+        processor,
+        length: int = 7,
+        attention_backend: str = "sdpa",
     ):
         """
         Args:
@@ -476,6 +481,7 @@ def __init__(
         self.draft_model = draft_model
         self.processor = processor
         self.length = length
+        self.attention_backend = attention_backend

     @torch.no_grad()
     def _prepare_data(
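For context, a hypothetical call site showing the new keyword; `target_model`, `draft_model`, and `processor` below are placeholders for objects built elsewhere in the training script, and `"sdpa"` remains the default when the argument is omitted.

```python
# Hypothetical construction of the online Eagle3 wrapper with the new argument.
model = QwenVLOnlineEagle3Model(
    target_model=target_model,
    draft_model=draft_model,
    processor=processor,
    length=7,
    attention_backend="flex_attention",  # or the default "sdpa"
)
```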
@@ -605,11 +611,20 @@ def forward(
             pixel_values: batch image pixel values, used for VLM models
             image_grid_thw: (batch, 3), image grid thw, used for VLM models
         """
-        # Step 1: prepare data with the target model
+        # Step 0: prepare data with the target model
         hidden_states, target, loss_mask, input_ids = self._prepare_data(
             input_ids, attention_mask, loss_mask, pixel_values, image_grid_thw
         )

+        # Step 1: handle vocab size
+        target_p_padded, position_mask = _compute_target_p_padded(
+            target=target,
+            t2d=self.draft_model.t2d,
+            loss_mask=loss_mask,
+            length=self.length,
+        )
+        del target
+
         # basic info
         batch_size, seq_length, _ = hidden_states.shape
         seq_length_with_past = seq_length
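`_compute_target_p_padded` is defined elsewhere in this change. A minimal sketch of what it plausibly does, reconstructed from the per-step "handle vocab size" block that this diff removes from the TTT loop and from the `[:, idx : idx + seq_length, :]` slicing added below, could look like the following; the exact signature and the padding amount are assumptions, not the actual implementation.

```python
import torch
import torch.nn.functional as F


def _compute_target_p_padded(target, t2d, loss_mask, length):
    # Sketch only: hoists the old in-loop "handle vocab size" logic out of the
    # TTT loop and right-pads the sequence dimension so every step can take a
    # fixed-size slice target_p_padded[:, idx : idx + seq_length, :].
    with torch.no_grad():
        target_max_token = target.argmax(-1)
        target_mask = t2d[target_max_token][..., None].int()
        position_mask = target_mask * loss_mask
        target_head = target[..., t2d].float()
        target_p = torch.softmax(target_head, dim=-1)
        # pad (length - 1) zero rows at the end of the sequence dimension
        target_p_padded = F.pad(target_p, (0, 0, 0, length - 1))
    return target_p_padded, position_mask
```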
@@ -656,21 +671,28 @@ def forward(
             dtype=torch.bool,
             device=hidden_states.device,
         )
-        attention_mask = self.draft_model.prepare_decoder_attention_mask(
-            attention_mask=attention_mask,
-            hidden_states=hidden_states,
-            batch_size=batch_size,
-            seq_length=seq_length,
-            past_key_values_length=past_key_values_length,
-        )
+        if self.attention_backend == "sdpa":
+            attention_mask = self.draft_model.prepare_decoder_attention_mask(
+                attention_mask=attention_mask,
+                hidden_states=hidden_states,
+                batch_size=batch_size,
+                seq_length=seq_length,
+                past_key_values_length=past_key_values_length,
+            )

         # Step 5: run TTT
         plosses = []
         vlosses = []
         acces = []
-        cache_hidden = [[], []]
+        if self.attention_backend == "sdpa":
+            cache_hidden = [[], []]
+            past_key_values = None
+        elif self.attention_backend == "flex_attention":
+            cache_hidden = None
+            past_key_values = DynamicCache()

         for idx in range(self.length):
+            target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous()
             is_last = idx == self.length - 1

             # Step 5.1: embed the input ids
@@ -685,55 +707,44 @@ def forward(
                 cache_hidden=cache_hidden,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
+                past_key_values=past_key_values,
                 use_cache=True,
             )

-            # Step 5.3: handle vocab size
-            with torch.no_grad():
-                target_head = target
-                target_max_token = target_head.argmax(-1)
-                target_mask = self.draft_model.t2d[target_max_token]
-                target_mask = target_mask[..., None].int()
-                position_mask = target_mask * loss_mask
-                target_head = target_head[..., self.draft_model.t2d]
-                target_head = target_head.float()
-                target_p = nn.Softmax(dim=2)(target_head)
-                target_p = target_p.detach()
-
             # update hidden states for next step
             hidden_states = hidden_states_out

             # Step 5.4: get logits
             logits = self.draft_model.compute_logits(hidden_states)
-            logits = logits.float()
-
-            # Step 5.5: calculate loss
-            out_logp = nn.LogSoftmax(dim=2)(logits)
-            plogp = target_p * out_logp
-            loss = -torch.sum(position_mask * plogp, 2).mean()

-            # Step 5.6: record metrics
-            plosses.append(loss)
+            # Step 5.5: record metrics first as we in-place modify logits
             with torch.no_grad():
                 acces.append(
-                    (
-                        (logits.argmax(-1) == target_p.argmax(-1))
-                        * position_mask.squeeze(-1)
+                    _compute_metric_acc(
+                        logits=logits,
+                        target_p=target_p,
+                        position_mask=position_mask,
+                        loss_mask=loss_mask,
                     )
-                    .sum()
-                    .item()
-                    / (loss_mask.sum().item() + 1e-6)
                 )

+            # Step 5.6: calculate loss, in-place modifies logits!
+            loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
+            plosses.append(loss)
+
             if not is_last:
                 # Step 5.7: we need to update the loss mask
                 input_ids = padding(input_ids, left=False)
-                target = padding(target, left=False)
+                position_mask = padding(position_mask, left=False)
                 loss_mask = padding(loss_mask, left=False)
-                ind = torch.arange(seq_length, device=attention_mask.device)
-                ind0 = ind[idx:]
-                ind1 = ind[: seq_length - idx]
-                attention_mask[:, :, ind0, ind1] = torch.finfo(attention_mask.dtype).min
+                if self.attention_backend == "sdpa":
+                    ind = torch.arange(seq_length, device=attention_mask.device)
+                    ind0 = ind[idx:]
+                    ind1 = ind[: seq_length - idx]
+                    attention_mask[:, :, ind0, ind1] = torch.finfo(
+                        attention_mask.dtype
+                    ).min
+                # Flex attention mask shrinking is handled inside the attention module
         return plosses, vlosses, acces


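`LogSoftmaxLoss` itself is not shown in this excerpt; the comment above only tells us that it overwrites `logits` in place so no extra full-vocabulary buffer has to be kept for backward. One possible shape of such a custom autograd function, written as a sketch under that assumption rather than as the actual implementation (the real code may handle the in-place bookkeeping differently), is:

```python
import torch


class LogSoftmaxLoss(torch.autograd.Function):
    """Sketch of a soft cross-entropy whose backward reuses the logits buffer."""

    @staticmethod
    def forward(ctx, logits, target_p, position_mask):
        logp = torch.log_softmax(logits.float(), dim=-1)
        loss = -torch.sum(position_mask * (target_p * logp), dim=2).mean()
        # d(loss)/d(logits) = position_mask * (softmax(logits) - target_p) / (B * S);
        # write it into the logits tensor itself instead of allocating a new buffer.
        grad = position_mask * (logp.exp() - target_p)
        grad = grad / (logits.shape[0] * logits.shape[1])
        logits.copy_(grad.to(logits.dtype))
        ctx.save_for_backward(logits)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad_logits,) = ctx.saved_tensors
        return (grad_output * grad_logits).to(grad_logits.dtype), None, None
```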
@@ -775,4 +786,4 @@ def _compute_target_p(target, t2d, loss_mask):
 def _compute_metric_acc(logits, target_p, position_mask, loss_mask):
     return (
         (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)
-    ).sum().item() / (loss_mask.sum().item() + 1e-6)
+    ).sum() / loss_mask.sum().clamp_min(1e-6)
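Because `_compute_metric_acc` now returns a tensor instead of a Python float, callers can defer the GPU-to-CPU sync until they actually log the metric. A hypothetical logging-side helper (not part of this diff) might look like:

```python
import torch


def summarize_acc(acces: list[torch.Tensor]) -> float:
    # One .item() at logging time instead of one per TTT step.
    return torch.stack([a.detach() for a in acces]).mean().item()
```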