@@ -920,13 +920,10 @@ def get_rope(
920920 rotary_emb = RotaryEmbedding (head_size , rotary_dim , max_position , base ,
921921 is_neox_style , dtype )
922922 else :
923- scaling_type = rope_scaling [
924- "type" ] if "type" in rope_scaling else rope_scaling ["rope_type" ]
925- # The correct one should be "longrope" but keep "su" here
926- # for backward compatible
927- if scaling_type not in {"su" , "longrope" }:
928- scaling_factor = rope_scaling .get ("factor" , 1.0 )
923+ scaling_type = rope_scaling ["rope_type" ]
924+
929925 if scaling_type == "llama3" :
926+ scaling_factor = rope_scaling ["factor" ]
930927 low_freq_factor = rope_scaling ["low_freq_factor" ]
931928 high_freq_factor = rope_scaling ["high_freq_factor" ]
932929 original_max_position = rope_scaling [
@@ -937,16 +934,39 @@ def get_rope(
937934 scaling_factor , low_freq_factor ,
938935 high_freq_factor ,
939936 original_max_position )
937+ elif scaling_type == "default" :
938+ if "mrope_section" in rope_scaling :
939+ rotary_emb = MRotaryEmbedding (
940+ head_size ,
941+ rotary_dim ,
942+ max_position ,
943+ base ,
944+ is_neox_style ,
945+ dtype ,
946+ mrope_section = rope_scaling ["mrope_section" ],
947+ )
948+ else :
949+ rotary_emb = RotaryEmbedding (
950+ head_size ,
951+ rotary_dim ,
952+ max_position ,
953+ base ,
954+ is_neox_style ,
955+ dtype ,
956+ )
940957 elif scaling_type == "linear" :
958+ scaling_factor = rope_scaling ["factor" ]
941959 rotary_emb = LinearScalingRotaryEmbedding (head_size , rotary_dim ,
942960 max_position , base ,
943961 is_neox_style ,
944962 scaling_factor , dtype )
945963 elif scaling_type == "dynamic" :
964+ scaling_factor = rope_scaling ["factor" ]
946965 rotary_emb = DynamicNTKScalingRotaryEmbedding (
947966 head_size , rotary_dim , max_position , base , is_neox_style ,
948967 scaling_factor , dtype )
949968 elif scaling_type == "yarn" :
969+ scaling_factor = rope_scaling ["factor" ]
950970 original_max_position = rope_scaling [
951971 "original_max_position_embeddings" ]
952972 extra_kwargs = {
@@ -961,6 +981,7 @@ def get_rope(
961981 scaling_factor , dtype ,
962982 ** extra_kwargs )
963983 elif scaling_type == "deepseek_yarn" :
984+ scaling_factor = rope_scaling ["factor" ]
964985 original_max_position = rope_scaling [
965986 "original_max_position_embeddings" ]
966987 # assert max_position == original_max_position * scaling_factor
@@ -973,9 +994,7 @@ def get_rope(
973994 rotary_emb = DeepseekScalingRotaryEmbedding (
974995 head_size , rotary_dim , original_max_position , base ,
975996 is_neox_style , scaling_factor , dtype , ** extra_kwargs )
976- # The correct one should be "longrope" but keep "su" here
977- # for backward compatible
978- elif scaling_type == "su" or scaling_type == "longrope" :
997+ elif scaling_type == "longrope" :
979998 short_factor = rope_scaling ["short_factor" ]
980999 long_factor = rope_scaling ["long_factor" ]
9811000 original_max_position = rope_scaling [
@@ -989,16 +1008,6 @@ def get_rope(
9891008 head_size , rotary_dim , max_position , original_max_position ,
9901009 base , is_neox_style , dtype , short_factor , long_factor ,
9911010 ** extra_kwargs )
992- elif scaling_type == "mrope" :
993- rotary_emb = MRotaryEmbedding (
994- head_size ,
995- rotary_dim ,
996- max_position ,
997- base ,
998- is_neox_style ,
999- dtype ,
1000- mrope_section = rope_scaling ["mrope_section" ],
1001- )
10021011 else :
10031012 raise ValueError (f"Unknown RoPE scaling type { scaling_type } " )
10041013 _ROPE_DICT [key ] = rotary_emb
0 commit comments