@@ -270,38 +270,47 @@ def __init__(
     ) -> None:
         super().__init__()
         self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.max_position_embeddings = getattr(config,
+                                               "max_position_embeddings", 8192)
+        self._init_attn_block()
+        self._init_ffn_block()
+
+    def _init_attn_block(self):
+        self.input_layernorm = RMSNorm(self.config.hidden_size,
+                                       eps=self.config.rms_norm_eps)
         self.self_attn = MiniCPMAttention(
             hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            num_heads=self.config.num_attention_heads,
+            num_kv_heads=self.config.num_key_value_heads,
+            rope_theta=self.rope_theta,
+            rope_scaling=self.rope_scaling,
+            max_position_embeddings=self.max_position_embeddings,
+            cache_config=self.cache_config,
+            quant_config=self.quant_config,
         )
+
+    def _init_ffn_block(self):
+        self.post_attention_layernorm = RMSNorm(self.config.hidden_size,
+                                                eps=self.config.rms_norm_eps)
         self.num_experts = getattr(self.config, "num_experts", 0)
         if self.num_experts == 0:
             self.mlp = MiniCPMMLP(
                 hidden_size=self.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
+                intermediate_size=self.config.intermediate_size,
+                hidden_act=self.config.hidden_act,
+                quant_config=self.quant_config,
             )
         else:
-            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
-                                  top_k=config.num_experts_per_tok,
-                                  hidden_size=config.hidden_size,
-                                  intermediate_size=config.intermediate_size)
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+            self.mlp = MiniCPMMoE(
+                num_experts=self.config.num_experts,
+                top_k=self.config.num_experts_per_tok,
+                hidden_size=self.config.hidden_size,
+                intermediate_size=self.config.intermediate_size)
 
     def forward(
         self,
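The decoder layer's `__init__` is split into `_init_attn_block` and `_init_ffn_block` hooks, with the constructor arguments kept on `self`, so a subclass can replace one block without copying the rest of the constructor. A minimal sketch of that pattern follows; the subclass and the `CustomAttention` alias are hypothetical, and the import paths are assumptions based on vLLM's usual module layout, not part of this diff.

```python
# Hypothetical sketch of overriding the new attention hook; not part of the diff.
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.models.minicpm import (MiniCPMAttention,
                                                MiniCPMDecoderLayer)

# Placeholder: a real subclass would point this at its own attention class.
CustomAttention = MiniCPMAttention


class MyDecoderLayer(MiniCPMDecoderLayer):
    """Reuses the inherited __init__, _init_ffn_block, and forward."""

    def _init_attn_block(self):
        self.input_layernorm = RMSNorm(self.config.hidden_size,
                                       eps=self.config.rms_norm_eps)
        self.self_attn = CustomAttention(
            hidden_size=self.hidden_size,
            num_heads=self.config.num_attention_heads,
            num_kv_heads=self.config.num_key_value_heads,
            rope_theta=self.rope_theta,
            rope_scaling=self.rope_scaling,
            max_position_embeddings=self.max_position_embeddings,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
        )
```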
@@ -344,6 +353,8 @@ def __init__(
     ) -> None:
         super().__init__()
         self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
         self.padding_idx = config.pad_token_id
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
@@ -354,11 +365,15 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
+        self._init_layers()
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def _init_layers(self):
         self.layers = nn.ModuleList([
-            MiniCPMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
+            MiniCPMDecoderLayer(self.config, self.cache_config,
+                                self.quant_config)
+            for _ in range(self.config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         embedding = self.embed_tokens(input_ids)
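`MiniCPMModel` gets the matching `_init_layers` hook: the decoder-layer class is now named in exactly one overridable method (`self.norm` stays in `__init__`, so the override only has to build the layer stack), and the stored `cache_config`/`quant_config` are what make that override possible. A sketch under the same assumptions (hypothetical names, assumed import paths):

```python
# Hypothetical sketch of swapping the decoder-layer class; not part of the diff.
from torch import nn

from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
                                                MiniCPMModel)

# Placeholder: a real subclass would use its own decoder-layer class here.
MyDecoderLayer = MiniCPMDecoderLayer


class MyModel(MiniCPMModel):

    def _init_layers(self):
        self.layers = nn.ModuleList([
            MyDecoderLayer(self.config, self.cache_config, self.quant_config)
            for _ in range(self.config.num_hidden_layers)
        ])
```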
@@ -431,13 +446,11 @@ def __init__(
 
         self.config = config
         self.lora_config = lora_config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
 
         self.num_experts = getattr(self.config, "num_experts", 0)
-        self.quant_config = quant_config
-        self.model = MiniCPMModel(config,
-                                  cache_config,
-                                  quant_config,
-                                  lora_config=lora_config)
+        self._init_model()
         unpadded_vocab_size = config.vocab_size
         if lora_config:
             unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -458,6 +471,12 @@ def __init__(
                                                 config.vocab_size)
         self.sampler = Sampler()
 
+    def _init_model(self):
+        self.model = MiniCPMModel(config=self.config,
+                                  cache_config=self.cache_config,
+                                  quant_config=self.quant_config,
+                                  lora_config=self.lora_config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
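At the top level, `_init_model` is now the only place the backbone is constructed, so a derived `*ForCausalLM` can reuse the vocab, LoRA, logits-processor, and sampler setup untouched. A sketch, again with hypothetical names and assumed import paths:

```python
# Hypothetical sketch of plugging in a different backbone; not part of the diff.
from vllm.model_executor.models.minicpm import (MiniCPMForCausalLM,
                                                MiniCPMModel)

# Placeholder: a real subclass would construct its own backbone class.
MyModel = MiniCPMModel


class MyForCausalLM(MiniCPMForCausalLM):

    def _init_model(self):
        self.model = MyModel(config=self.config,
                             cache_config=self.cache_config,
                             quant_config=self.quant_config,
                             lora_config=self.lora_config)
```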