@@ -232,8 +232,7 @@ def __init__(
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
 
-        self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type == "rope":
+        if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
                 self.head_dim,
                 rotary_dim=self.head_dim,
@@ -244,6 +243,8 @@ def __init__(
                 and config.rope_scaling is not None else None,
                 is_neox_style=True,
             )
+        else:
+            self.rotary_emb = None
 
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -263,7 +264,7 @@ def forward(
         key = self.k_proj(hidden_states)[0]
         value = self.v_proj(hidden_states)[0]
 
-        if self.position_embedding_type == "rope":
+        if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)
 
         hidden_states = self.attn(query, key, value)
@@ -349,11 +350,11 @@ def forward(
             hidden_states = hidden_states * self.embedding_multiplier
             residual = None
         else:
-            assert intermediate_tensors is not None
+            if intermediate_tensors is None:
+                raise RuntimeError('Intermediate tensors may not be None!')
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        residual = None
         num_attn = 0
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -463,18 +464,19 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
     embedding_padding_modules = ["lm_head"]
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         lora_config = vllm_config.lora_config
         scheduler_config = vllm_config.scheduler_config
-        assert not cache_config.enable_prefix_caching, \
-            "GraniteMoeHybrid currently does not support prefix caching"
+        if cache_config.enable_prefix_caching:
+            raise RuntimeError(
+                "GraniteMoeHybrid currently does not support prefix caching")
 
         self.quant_config = vllm_config.quant_config
-
-        super().__init__()
         self.config = config
         self.scheduler_config = scheduler_config
         self.model = GraniteMoeHybridModel(vllm_config=vllm_config,