diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index b095c79dc954..fbfd02d08d08 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -595,7 +595,10 @@ def _forward_core(
             mixed_qkv_non_spec
         )

-        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
+        beta = b.sigmoid()
+        # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+        g = fused_gdn_gating(self.A_log, a, self.dt_bias)
+        g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))

         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
@@ -1335,13 +1338,12 @@ def gdn_attention_core_fake(
 )


+# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
 @triton.jit
 def fused_gdn_gating_kernel(
     g,
-    beta_output,
     A_log,
     a,
-    b,
     dt_bias,
     seq_len,
     NUM_HEADS: tl.constexpr,
@@ -1355,7 +1357,6 @@ def fused_gdn_gating_kernel(
     mask = head_off < NUM_HEADS
     blk_A_log = tl.load(A_log + head_off, mask=mask)
     blk_a = tl.load(a + off, mask=mask)
-    blk_b = tl.load(b + off, mask=mask)
     blk_bias = tl.load(dt_bias + head_off, mask=mask)
     # If the model is loaded in fp16, without the .float() here, A might be -inf
     x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
@@ -1364,42 +1365,20 @@
     )
     blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
     tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
-    # compute beta_output = sigmoid(b)
-    blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
-    tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)


 def fused_gdn_gating(
     A_log: torch.Tensor,
     a: torch.Tensor,
-    b: torch.Tensor,
     dt_bias: torch.Tensor,
     beta: float = 1.0,
     threshold: float = 20.0,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Fused computation of g and beta for Gated Delta Net.
-    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
-    beta_output = b.sigmoid()
-    TODO maybe use torch.compile to replace this triton kernel
-    """
+) -> torch.Tensor:
     batch, num_heads = a.shape
     seq_len = 1
     grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
-    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
+    g = torch.empty_like(a, dtype=torch.float32)
     fused_gdn_gating_kernel[grid](
-        g,
-        beta_output,
-        A_log,
-        a,
-        b,
-        dt_bias,
-        seq_len,
-        num_heads,
-        beta,
-        threshold,
-        8,
-        num_warps=1,
+        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
     )
-    return g, beta_output
+    return g
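
For reference, a minimal eager-mode sketch of the gating math that the rewritten kernel and the caller-side sigmoid compute together, matching the formula in the diff's comment and the removed docstring. The helper name gdn_gating_reference and the test shapes below are illustrative only, not part of the change; the Triton kernel applies the same softplus defaults (beta=1.0, threshold=20.0) as torch.nn.functional.softplus.

import torch
import torch.nn.functional as F

def gdn_gating_reference(A_log, a, b, dt_bias):
    # Same math as fused_gdn_gating_kernel plus the sigmoid now done in the caller:
    #   g    = -exp(A_log) * softplus(a + dt_bias)   (accumulated in fp32)
    #   beta = sigmoid(b)
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias)
    beta = b.sigmoid()
    return g, beta

# Illustrative shapes (hypothetical): [num_tokens, num_heads] activations,
# [num_heads] per-head parameters.
num_tokens, num_heads = 4, 16
a = torch.randn(num_tokens, num_heads, dtype=torch.bfloat16)
b = torch.randn(num_tokens, num_heads, dtype=torch.bfloat16)
A_log = torch.randn(num_heads)
dt_bias = torch.randn(num_heads)
g, beta = gdn_gating_reference(A_log, a, b, dt_bias)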