@@ -610,6 +610,119 @@ def forward(
         return query.flatten(-2), key.flatten(-2)
 
 
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
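+# Illustrative example (values assumed for this comment, not taken from any
+# model config): with scale=40 and mscale=1.0 this returns
+# 0.1 * 1.0 * ln(40) + 1.0 ~= 1.369, i.e. attention magnitudes get a mild
+# boost when the context window is stretched; for scale <= 1 it is a no-op.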
+
+
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        mscale: float = 1,
+        mscale_all_dim: float = 0,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation.
+        self.mscale = float(
+            yarn_get_mscale(self.scaling_factor, float(mscale)) /
+            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
+            attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
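+        # Dimensions below `low` rotate fast enough to keep their original
+        # (extrapolated) frequencies, dimensions above `high` are fully
+        # interpolated by scaling_factor, and the ramp mask below blends the
+        # two in between.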
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         device="cuda",
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        print("Cache shape", cache.shape)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
+        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
+                                     if offsets is not None else positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
+
 class GemmaRotaryEmbedding(RotaryEmbedding):
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -679,6 +792,19 @@ def get_rope(
                                                     base, is_neox_style,
                                                     scaling_factor, dtype,
                                                     **extra_kwargs)
+        elif scaling_type == "deepseek_yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            # assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow", "mscale", "mscale_all_dim")
+            }
+            rotary_emb = DeepseekScalingRotaryEmbedding(
+                head_size, rotary_dim, original_max_position, base,
+                is_neox_style, scaling_factor, dtype, **extra_kwargs)
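+            # For illustration only (key names and values here are assumed,
+            # not pulled from a real checkpoint config): a rope_scaling entry
+            # along the lines of
+            #   {"type": "deepseek_yarn", "factor": 40.0,
+            #    "original_max_position_embeddings": 4096,
+            #    "beta_fast": 32, "beta_slow": 1,
+            #    "mscale": 0.707, "mscale_all_dim": 0.707}
+            # would select this branch and stretch the 4096-token window by
+            # the given factor.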
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
         elif scaling_type == "su" or scaling_type == "longrope":