 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fla.ops import (
-    RMSNormGated,
     chunk_gated_delta_rule,
     fused_recurrent_gated_delta_rule,
 )
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm
+from vllm.model_executor.layers.layernorm import (
+    GemmaRMSNorm as Qwen3NextRMSNorm,
+)
+from vllm.model_executor.layers.layernorm import RMSNormGated
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -436,17 +438,66 @@ def forward(
         hidden_states: torch.Tensor,
         output: torch.Tensor,
     ):
-        return torch.ops.vllm.gdn_attention(
-            hidden_states,
-            output,
+        """
+        Forward pass with three parts:
+        1. Input projection
+        2. Core attention (custom op)
+        3. Output projection
+        """
+        num_tokens = hidden_states.size(0)
+
+        # ============================================================
+        # Part 1: Input Projection
+        # ============================================================
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+        projected_states_ba, _ = self.in_proj_ba(hidden_states)
+        query, key, value, z, b, a = self.fix_query_key_value_ordering(
+            projected_states_qkvz, projected_states_ba
+        )
+        query, key, value = map(
+            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
+        )
+        mixed_qkv = torch.cat((query, key, value), dim=-1)
+
+        # ============================================================
+        # Part 2: Core Attention (Custom Op)
+        # ============================================================
+        core_attn_out = torch.zeros(
+            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        torch.ops.vllm.gdn_attention_core(
+            mixed_qkv,
+            b,
+            a,
+            core_attn_out,
             self.prefix,
         )

-    def _forward(
+        # ============================================================
+        # Part 3: Output Projection
+        # ============================================================
+        z_shape_og = z.shape
+        # Reshape input data into 2D tensor
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+        output[:num_tokens], _ = self.out_proj(core_attn_out)
+
+    def _forward_core(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        mixed_qkv: torch.Tensor,
+        b: torch.Tensor,
+        a: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ):
+        """
+        Core attention computation (called by custom op).
+        """
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata

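For context on Part 3 above: `self.norm` is an `RMSNormGated` layer, i.e. an RMS norm whose input is gated by `z`. A minimal sketch of the common Mamba-2-style formulation follows; the exact gating and dtype handling in vLLM's `RMSNormGated` may differ, so treat it as illustrative only.

import torch
import torch.nn.functional as F


def gated_rms_norm(x: torch.Tensor, z: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gate the activations with silu(z), then RMS-normalize and apply the learned scale.
    x = x * F.silu(z)
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight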
@@ -471,18 +522,11 @@ def _forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

-        # 1. Set up dimensions for reshapes later
-        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
-        projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens])
-        query, key, value, z, b, a = self.fix_query_key_value_ordering(
-            projected_states_qkvz, projected_states_ba
-        )
-        query, key, value = map(
-            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
-        )
-        mixed_qkv = torch.cat((query, key, value), dim=-1)
+        mixed_qkv = mixed_qkv[:num_actual_tokens]
+        b = b[:num_actual_tokens]
+        a = a[:num_actual_tokens]

-        # 2. Convolution sequence transformation
+        # 1. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(
             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
         )
@@ -498,7 +542,7 @@ def _forward(
         mixed_qkv_spec = None
         mixed_qkv_non_spec = mixed_qkv

-        # 2.1: process the mutli-query part
+        # 1.1: Process the multi-query part
         if spec_sequence_masks is not None:
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
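`causal_conv1d_update` above is a fused decode-path kernel; roughly, it applies a short depthwise causal convolution to each new token while maintaining a rolling cache of the last few inputs per channel. A naive single-token sketch (shapes assumed, activation omitted; not the vLLM/fla kernel API):

import torch


def causal_conv1d_update_ref(x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: (batch, dim) new token; conv_state: (batch, dim, kernel) rolling cache (mutated in place);
    # weight: (dim, kernel) one filter per channel.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))   # shift cache left
    conv_state[:, :, -1] = x                                       # append the new token
    return torch.einsum("bdk,dk->bd", conv_state, weight)          # depthwise causal conv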
@@ -515,7 +559,7 @@ def _forward(
                 validate_data=False,
             )

-        # 2.2: process the remaining part
+        # 1.2: Process the remaining part
         if attn_metadata.num_prefills > 0:
             mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
             # - "cache_indices" updates the conv_state cache in positions
@@ -570,9 +614,9 @@ def _forward(
             g_non_spec = g
             beta_non_spec = beta

-        # 3. Recurrent attention
+        # 2. Recurrent attention

-        # 3.1: process the mutlti-query part
+        # 2.1: Process the multi-query part
         if spec_sequence_masks is not None:
             core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                 q=query_spec,
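`fused_recurrent_gated_delta_rule` evaluates the gated delta-rule recurrence in a fused kernel. Ignoring heads, chunking and normalization, the per-token recurrence is roughly the sketch below; the exact placement of the decay and any scaling differ in the real kernels, so this is illustrative only.

import torch


def gated_delta_rule_ref(q, k, v, g, beta):
    # q, k: (seq, d_k); v: (seq, d_v); g: (seq,) log-decay; beta: (seq,) write strength.
    state = torch.zeros(k.shape[-1], v.shape[-1], dtype=q.dtype, device=q.device)
    outs = []
    for t in range(q.shape[0]):
        state = state * g[t].exp()                          # gated decay of the memory
        delta = v[t] - k[t] @ state                         # prediction error for key k_t
        state = state + beta[t] * torch.outer(k[t], delta)  # delta-rule write
        outs.append(q[t] @ state)                           # read out with the query
    return torch.stack(outs)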
@@ -590,7 +634,7 @@ def _forward(
         else:
             core_attn_out_spec, last_recurrent_state = None, None

-        # 3.2: process the remaining part
+        # 2.2: Process the remaining part
         if attn_metadata.num_prefills > 0:
             initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
             initial_state[~has_initial_state, ...] = 0
@@ -633,30 +677,20 @@ def _forward(
         else:
             core_attn_out_non_spec, last_recurrent_state = None, None

-        # Merge core attention output
+        # 3. Merge core attention output
         if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
-            core_attn_out = torch.empty(
+            merged_out = torch.empty(
                 (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
-            core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-
+            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
+            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
+            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
         elif spec_sequence_masks is not None:
-            core_attn_out = core_attn_out_spec
+            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
         else:
-            core_attn_out = core_attn_out_non_spec
-
-        z_shape_og = z.shape
-        # reshape input data into 2D tensor
-        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
-        z = z.reshape(-1, z.shape[-1])
-        core_attn_out = self.norm(core_attn_out, z)
-        core_attn_out = core_attn_out.reshape(z_shape_og)
-        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
-
-        output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
+            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)


 class Qwen3NextAttention(nn.Module):
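Note that `_forward_core` now writes into the caller-allocated `core_attn_out` (declared in `mutates_args`) instead of returning a tensor, which is why the merge above scatters into a scratch buffer with `index_copy_` and then slice-assigns. A toy illustration of that pattern with made-up shapes and indices:

import torch

num_tokens, heads, dim = 6, 2, 4
out = torch.zeros(num_tokens, heads, dim)                 # caller-provided buffer

spec_idx = torch.tensor([1, 4])                           # tokens on the spec-decode path
non_spec_idx = torch.tensor([0, 2, 3, 5])                 # everything else
spec_part = torch.randn(1, spec_idx.numel(), heads, dim)
non_spec_part = torch.randn(1, non_spec_idx.numel(), heads, dim)

merged = torch.empty(1, num_tokens, heads, dim)
merged.index_copy_(1, spec_idx, spec_part)                # scatter spec results by token index
merged.index_copy_(1, non_spec_idx, non_spec_part)        # scatter the remaining tokens
out[:num_tokens] = merged.squeeze(0)                      # write into the mutated output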
@@ -1260,29 +1294,44 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()


-def gdn_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+def gdn_attention_core(
+    mixed_qkv: torch.Tensor,
+    b: torch.Tensor,
+    a: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
+    """
+    Custom op for the core attention computation.
+    Only handles the convolution + recurrent attention part.
+    Input/output projections are handled outside this op.
+    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward_core(
+        mixed_qkv=mixed_qkv,
+        b=b,
+        a=a,
+        core_attn_out=core_attn_out,
+    )


-def gdn_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+def gdn_attention_core_fake(
+    mixed_qkv: torch.Tensor,
+    b: torch.Tensor,
+    a: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
+    """Fake implementation for torch.compile."""
     return


 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_attention,
-    mutates_args=["output"],
-    fake_impl=gdn_attention_fake,
+    op_name="gdn_attention_core",
+    op_func=gdn_attention_core,
+    mutates_args=["core_attn_out"],
+    fake_impl=gdn_attention_core_fake,
 )

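The registration above follows the standard mutated-output custom-op pattern: the real op returns `None` and writes into an argument listed in `mutates_args`, while the fake impl gives `torch.compile` a side-effect-free stand-in for tracing. A minimal sketch of the same idea with stock `torch.library` (PyTorch 2.4+; the `demo::scale_into` op is invented for illustration, not vLLM API):

import torch


@torch.library.custom_op("demo::scale_into", mutates_args=("out",))
def scale_into(x: torch.Tensor, out: torch.Tensor, factor: float) -> None:
    out.copy_(x * factor)          # real kernel: write into the preallocated buffer


@scale_into.register_fake
def _(x: torch.Tensor, out: torch.Tensor, factor: float) -> None:
    return                         # fake impl: no data produced, shapes unchanged


x, out = torch.randn(4), torch.empty(4)
scale_into(x, out, 2.0)            # out now holds x * 2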