[PERF] Decouple projections from GDN custom op. Attempt 2 #28083
@@ -30,12 +30,14 @@
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import (
    RMSNormGated,
    chunk_gated_delta_rule,
    fused_recurrent_gated_delta_rule,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
@@ -436,17 +438,66 @@ def forward(
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        return torch.ops.vllm.gdn_attention(
            hidden_states,
            output,
        """
        Forward pass with three parts:
        1. Input projection
        2. Core attention (custom op)
        3. Output projection
        """
        num_tokens = hidden_states.size(0)

        # ============================================================
        # Part 1: Input Projection
        # ============================================================
        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
        projected_states_ba, _ = self.in_proj_ba(hidden_states)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(
            projected_states_qkvz, projected_states_ba
        )
        query, key, value = map(
            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
        )
        mixed_qkv = torch.cat((query, key, value), dim=-1)

        # ============================================================
        # Part 2: Core Attention (Custom Op)
        # ============================================================
        core_attn_out = torch.zeros(
            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        torch.ops.vllm.gdn_attention_core(
            mixed_qkv,
            b,
            a,
            core_attn_out,
            self.prefix,
        )

    def _forward(
        # ============================================================
        # Part 3: Output Projection
        # ============================================================
        z_shape_og = z.shape
        # Reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
Comment on lines 436 to +487:
The refactor now preallocates …
        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
        output[:num_tokens], _ = self.out_proj(core_attn_out)

    def _forward_core(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
    ):
        """
        Core attention computation (called by custom op).
        """
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
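Both projection paths above rely on `einops.rearrange` to flatten the per-head axis: `"l p d -> l (p d)"` before concatenating into `mixed_qkv`, and `"... h d -> ... (h d)"` before `out_proj`. A minimal, self-contained illustration of those two patterns (sizes are made up for this sketch):

```python
# Toy illustration of the rearrange patterns used above; sizes are arbitrary.
import torch
from einops import rearrange

tokens, heads, head_dim = 5, 4, 8
x = torch.randn(tokens, heads, head_dim)

# "l p d -> l (p d)": merge the head and head-dim axes into one feature axis.
flat = rearrange(x, "l p d -> l (p d)")
assert flat.shape == (tokens, heads * head_dim)

# "... h d -> ... (h d)": the same merge, agnostic to any leading batch axes.
y = torch.randn(2, tokens, heads, head_dim)
assert rearrange(y, "... h d -> ... (h d)").shape == (2, tokens, heads * head_dim)
```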
@@ -471,18 +522,11 @@ def _forward(
        num_actual_tokens = attn_metadata.num_actual_tokens
        num_accepted_tokens = attn_metadata.num_accepted_tokens

        # 1. Set up dimensions for reshapes later
        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
        projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens])
        query, key, value, z, b, a = self.fix_query_key_value_ordering(
            projected_states_qkvz, projected_states_ba
        )
        query, key, value = map(
            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
        )
        mixed_qkv = torch.cat((query, key, value), dim=-1)
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]

        # 2. Convolution sequence transformation
        # 1. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )
@@ -498,7 +542,7 @@
            mixed_qkv_spec = None
            mixed_qkv_non_spec = mixed_qkv

        # 2.1: process the mutli-query part
        # 1.1: Process the multi-query part
        if spec_sequence_masks is not None:
            mixed_qkv_spec = causal_conv1d_update(
                mixed_qkv_spec,
@@ -515,7 +559,7 @@
                validate_data=False,
            )

        # 2.2: process the remaining part
        # 1.2: Process the remaining part
        if attn_metadata.num_prefills > 0:
            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
            # - "cache_indices" updates the conv_state cache in positions
@@ -570,9 +614,9 @@
            g_non_spec = g
            beta_non_spec = beta

        # 3. Recurrent attention
        # 2. Recurrent attention

        # 3.1: process the mutlti-query part
        # 2.1: Process the multi-query part
        if spec_sequence_masks is not None:
            core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                q=query_spec,
@@ -590,7 +634,7 @@
        else:
            core_attn_out_spec, last_recurrent_state = None, None

        # 3.2: process the remaining part
        # 2.2: Process the remaining part
        if attn_metadata.num_prefills > 0:
            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
            initial_state[~has_initial_state, ...] = 0
@@ -633,30 +677,20 @@
        else:
            core_attn_out_non_spec, last_recurrent_state = None, None

        # Merge core attention output
        # 3. Merge core attention output
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
            core_attn_out = torch.empty(
            merged_out = torch.empty(
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
            core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
        elif spec_sequence_masks is not None:
            core_attn_out = core_attn_out_spec
            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
        else:
            core_attn_out = core_attn_out_non_spec

        z_shape_og = z.shape
        # reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")

        output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)


class Qwen3NextAttention(nn.Module):
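The merge step now writes into the caller-preallocated `core_attn_out` buffer (via slicing and `index_copy_`) instead of rebinding a local variable. A toy sketch of that `index_copy_` pattern, with made-up sizes and token indices:

```python
# Toy sketch of merging two partial results into a preallocated buffer with
# index_copy_; sizes and indices are made up for illustration.
import torch

num_tokens, heads, head_dim = 6, 2, 4
merged = torch.zeros(1, num_tokens, heads, head_dim)

spec_idx = torch.tensor([1, 4])            # positions produced by the spec path
non_spec_idx = torch.tensor([0, 2, 3, 5])  # remaining positions

spec_out = torch.ones(1, len(spec_idx), heads, head_dim)
non_spec_out = torch.full((1, len(non_spec_idx), heads, head_dim), 2.0)

# index_copy_(dim, index, source) scatters rows of `source` along dim 1.
merged.index_copy_(1, spec_idx, spec_out)
merged.index_copy_(1, non_spec_idx, non_spec_out)
assert merged[0, 1, 0, 0].item() == 1.0
assert merged[0, 0, 0, 0].item() == 2.0
```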
@@ -1260,29 +1294,44 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        return self.model.get_expert_mapping()


def gdn_attention(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
def gdn_attention_core(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """
    Custom op for the core attention computation.
    Only handles the convolution + recurrent attention part.
    Input/output projections are handled outside this op.
    """
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self._forward(hidden_states=hidden_states, output=output)
    self._forward_core(
        mixed_qkv=mixed_qkv,
        b=b,
        a=a,
        core_attn_out=core_attn_out,
    )


def gdn_attention_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
def gdn_attention_core_fake(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation for torch.compile."""
    return


direct_register_custom_op(
    op_name="gdn_attention",
    op_func=gdn_attention,
    mutates_args=["output"],
    fake_impl=gdn_attention_fake,
    op_name="gdn_attention_core",
    op_func=gdn_attention_core,
    mutates_args=["core_attn_out"],
    fake_impl=gdn_attention_core_fake,
)
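`direct_register_custom_op` is vLLM's thin wrapper around PyTorch's custom-op registration; the overall pattern of this PR can be sketched with the plain `torch.library` API (PyTorch ≥ 2.4). Everything below — the `demo` namespace, `toy_core`, `ToyGDNBlock`, and the sizes — is hypothetical and only illustrates the structure: projections stay in ordinary compiled code, the core runs behind an opaque op that mutates a preallocated buffer, and a fake implementation gives `torch.compile` the shape semantics it needs to trace through.

```python
# Hypothetical sketch of the decoupling pattern, not vLLM code.
import torch
from torch import nn


@torch.library.custom_op("demo::toy_core", mutates_args=("out",))
def toy_core(mixed_qkv: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for the opaque core computation: it writes its result into the
    # caller-preallocated buffer instead of returning a new tensor.
    q, k, v = mixed_qkv.chunk(3, dim=-1)
    out.copy_(q * k + v)


@toy_core.register_fake
def _(mixed_qkv: torch.Tensor, out: torch.Tensor) -> None:
    # Fake (meta) implementation: no computation, just "returns nothing and
    # mutates `out`", which is enough for torch.compile to trace around it.
    return


class ToyGDNBlock(nn.Module):
    def __init__(self, hidden: int = 64) -> None:
        super().__init__()
        self.in_proj = nn.Linear(hidden, 3 * hidden)
        self.out_proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 1. Input projection: stays visible to the compiler.
        mixed_qkv = self.in_proj(hidden_states)
        # 2. Core computation: preallocate the output and let the op fill it in.
        core_out = torch.zeros_like(hidden_states)
        torch.ops.demo.toy_core(mixed_qkv, core_out)
        # 3. Output projection: also visible to the compiler.
        return self.out_proj(core_out)


if __name__ == "__main__":
    block = torch.compile(ToyGDNBlock())
    print(block(torch.randn(4, 64)).shape)  # torch.Size([4, 64])
```

Keeping the projections outside the op lets the compiler see and fuse the surrounding matmuls, while `mutates_args` declares that the op writes into its output buffer in place rather than returning a fresh tensor.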
Importing `einops` inside the `forward_native` method can introduce a performance overhead, as the import statement will be executed every time this method is called. It's better to move this import to the top of the file to ensure it's only executed once when the module is loaded.
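The fix the reviewer suggests is the standard one; a hypothetical sketch (the module layout and names below are illustrative, not the actual vLLM file):

```python
# Hoist the einops import to module scope so it executes once at import time,
# not on every call of the method.
import torch
from einops import rearrange  # moved out of forward_native


class SomeLayer(torch.nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # previously: `from einops import rearrange` sat here, re-run per call
        return rearrange(x, "b h d -> b (h d)")
```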