@@ -1508,7 +1508,8 @@ def get_layers_start_end_indices(
         if (self.hf_text_config.model_type == "deepseek_mtp"
                 or self.hf_config.model_type == "mimo_mtp"
                 or self.hf_config.model_type == "glm4_moe_mtp"
-                or self.hf_config.model_type == "ernie_mtp"):
+                or self.hf_config.model_type == "ernie_mtp"
+                or self.hf_config.model_type == "qwen3_next_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_nextn_predict_layers", 0)
         else:
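
A note on this hunk: MTP draft checkpoints describe only their speculative
prediction heads, so the layer count is taken from num_nextn_predict_layers
rather than num_hidden_layers. A minimal sketch of the selection, with a
made-up stand-in for hf_text_config (attribute values are illustrative):

    from types import SimpleNamespace

    # Hypothetical stand-in for the HF text config; values are made up.
    cfg = SimpleNamespace(model_type="qwen3_next_mtp",
                          num_nextn_predict_layers=1,
                          num_hidden_layers=48)

    MTP_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp",
                 "ernie_mtp", "qwen3_next_mtp")
    if cfg.model_type in MTP_TYPES:
        # MTP models: only the prediction heads count as layers.
        total_num_hidden_layers = getattr(cfg, "num_nextn_predict_layers", 0)
    else:
        total_num_hidden_layers = cfg.num_hidden_layers

    assert total_num_hidden_layers == 1
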
@@ -1571,15 +1572,28 @@ def get_num_layers_by_block_type(
         if attn_type_list:
             return sum(t == 1 for t in attn_type_list[start:end])
 
-        if layers_block_type_value is None and attn_type_list is None:
+        # Hybrid model Qwen3Next
+        layer_types_value = getattr(self.hf_config, "layer_types", None)
+        if layer_types_value is not None:
+            if getattr(block_type, "value", block_type) == "attention":
+                return sum(t == "full_attention"
+                           for t in layer_types_value[start:end])
+            elif getattr(block_type, "value",
+                         block_type) == "linear_attention":
+                return sum(t == "linear_attention"
+                           for t in layer_types_value[start:end])
+            else:
+                return sum(t == getattr(block_type, "value", block_type)
+                           for t in layer_types_value[start:end])
+
+        if (layers_block_type_value is None and attn_type_list is None
+                and layer_types_value is None):
             raise ValueError(
                 "The model is an hybrid without a"
-                "layers_block_type or an attn_type_list in the hf_config, "
-                "cannot determine the num of "
+                "layers_block_type or an attn_type_list, or a layer_types "
+                "in the hf_config, cannot determine the num of "
                 f"{block_type.value} layers")
 
-        return sum(t == 1 for t in attn_type_list[start:end])
-
     def get_mamba_chunk_size(self) -> Optional[int]:
         """
         Returns the mamba chunk size if it exists
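
The new layer_types branch follows the labels HF publishes for hybrid
checkpoints: vLLM's "attention" block type is matched against HF's
"full_attention" label, and any other block type (e.g. "linear_attention")
is compared verbatim. A self-contained sketch of that counting, assuming an
illustrative layer_types list (a real checkpoint has far more entries):

    # Illustrative layer_types; not taken from a real Qwen3Next config.
    layer_types = ["linear_attention", "linear_attention", "full_attention",
                   "linear_attention", "linear_attention", "full_attention"]

    def count_layers(layer_types, block_type, start, end):
        # vLLM's "attention" maps onto HF's "full_attention" label.
        if block_type == "attention":
            return sum(t == "full_attention" for t in layer_types[start:end])
        # Other block types ("linear_attention", ...) match verbatim.
        return sum(t == block_type for t in layer_types[start:end])

    assert count_layers(layer_types, "attention", 0, 6) == 2
    assert count_layers(layer_types, "linear_attention", 0, 6) == 4
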
@@ -1866,7 +1880,7 @@ def __post_init__(self):
 
 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
-                            "ernie_mtp"]
+                            "ernie_mtp", "qwen3_next_mtp"]
 
 
 @config
@@ -2007,7 +2021,15 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
                 "n_predict": n_predict,
                 "architectures": ["ErnieMTPModel"]
             })
-        return hf_config
+
+        if hf_config.model_type == "qwen3_next":
+            hf_config.model_type = "qwen3_next_mtp"
+        if hf_config.model_type == "qwen3_next_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["Qwen3NextMTP"]
+            })
 
         return hf_config
 
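The override is two-phase: a target config arriving as qwen3_next is first
renamed to qwen3_next_mtp, then the MTP fields (n_predict, architectures) are
filled in. A rough standalone illustration, assuming SimpleNamespace as a
stand-in for transformers.PretrainedConfig and a plain attribute copy in
place of hf_config.update:

    from types import SimpleNamespace

    # Hypothetical stand-in for PretrainedConfig; num_nextn_predict_layers
    # is set to 1 purely for illustration.
    hf_config = SimpleNamespace(model_type="qwen3_next",
                                num_nextn_predict_layers=1)

    if hf_config.model_type == "qwen3_next":
        hf_config.model_type = "qwen3_next_mtp"
    if hf_config.model_type == "qwen3_next_mtp":
        # Equivalent of the hf_config.update({...}) call in the diff.
        hf_config.n_predict = getattr(hf_config,
                                      "num_nextn_predict_layers", None)
        hf_config.architectures = ["Qwen3NextMTP"]

    assert hf_config.n_predict == 1
    assert hf_config.architectures == ["Qwen3NextMTP"]
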
@@ -2028,9 +2050,13 @@ def __post_init__(self):
                 (self.target_model_config.hf_text_config.model_type \
                     == "deepseek_v3" or
                     self.target_model_config.hf_text_config.model_type in
-                    ("mimo", "ernie4_5_moe")):
+                    ("mimo", "ernie4_5_moe", "qwen3_next")):
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
+                # Align the quantization of draft model for cases such as
+                # --quantization fp8 with a bf16 checkpoint.
+                if not self.quantization:
+                    self.quantization = self.target_model_config.quantization
             elif self.method in ("ngram", "[ngram]"):
                 self.model = "ngram"
             else:
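
The quantization fallback above keeps a shared-checkpoint draft loadable when
the target is served with, say, --quantization fp8 over a bf16 checkpoint: an
unset draft quantization simply inherits the target's. The behavior in
isolation (field names follow the diff; the dataclass is a hypothetical
stand-in for ModelConfig):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _Cfg:  # hypothetical stand-in for vLLM's ModelConfig
        quantization: Optional[str] = None

    target = _Cfg(quantization="fp8")
    draft = _Cfg()  # user left the draft quantization unset

    if not draft.quantization:
        draft.quantization = target.quantization

    assert draft.quantization == "fp8"
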
@@ -2140,6 +2166,15 @@ def __post_init__(self):
                             "one layer. Might need some code changes " \
                             "to support multiple layers."
                         )
+            elif (self.draft_model_config.hf_config.model_type ==
+                  "qwen3_next_mtp"):
+                self.method = "qwen3_next_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                            "All Qwen3Next MTP models only have " \
+                            "one layer. Might need some code changes " \
+                            "to support multiple layers."
+                        )
             else:
                 self.method = "draft_model"
                 raise NotImplementedError(
@@ -2355,7 +2390,8 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
+        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
+                               "qwen3_next_mtp")
 
     def __repr__(self) -> str:
         method = self.method