
Commit e980e99

[model] feat: Add Apertus (#3295)
Pre-release of Apertus from the Swiss AI Initiative.

Main modifications from Llama:
- xIELU activation
- QK-norm

Associated Transformers PR: huggingface/transformers#39381
Associated vLLM PR: vllm-project/vllm#23068
Associated SGLang PR: sgl-project/sglang#9774

GSM8K: results attached as two images in the original commit message (not reproduced here).
1 parent bf127e4 commit e980e99

File tree

6 files changed: +201 -3 lines


tests/models/test_transformer.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 import torch
 from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
 from transformers import (
+    ApertusConfig,
     AutoModelForCausalLM,
     AutoModelForTokenClassification,
     GemmaConfig,
@@ -33,6 +34,7 @@
     MistralConfig(num_hidden_layers=1),
     GemmaConfig(num_hidden_layers=1),
     Qwen2Config(num_hidden_layers=1),
+    ApertusConfig(num_hidden_layers=1),
 ]
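The new entry follows the same pattern as the existing ones: each config in this list is presumably used to build a small model for the padding/attention checks in this test file. A minimal standalone sketch of that construction step, assuming a transformers release that ships Apertus (>= 4.56 per the Ulysses test below):

from transformers import ApertusConfig, AutoModelForCausalLM

# Single-layer Apertus built straight from the config, mirroring the list entry above.
config = ApertusConfig(num_hidden_layers=1)
model = AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in model.parameters()))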

tests/models/test_transformers_ulysses.py

Lines changed: 16 additions & 1 deletion
@@ -18,7 +18,9 @@
 import pytest
 import torch
 import torch.distributed
+import transformers
 from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input
+from packaging import version
 from torch.distributed import init_device_mesh
 from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config

@@ -46,7 +48,7 @@ class SequenceParallelConfig:


 def test_configs():
-    return [
+    configs = [
         SequenceParallelConfig(
             LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True
         ),
@@ -68,6 +70,19 @@ def test_configs():
         ),
     ]

+    if version.parse(transformers.__version__) >= version.parse("4.56.0"):
+        from transformers import ApertusConfig
+
+        configs.append(
+            SequenceParallelConfig(
+                ApertusConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32, hidden_size=4096),
+                sp_size=8,
+                is_valid=True,
+            )
+        )
+
+    return configs
+

 def sync_model_parameters_global(layer):
     # synchronize weights
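For context on sp_size=8 with 32 attention heads being marked is_valid=True: Ulysses shards attention heads across the sequence-parallel group, so the head count must divide evenly by the SP size. A hedged sketch of that constraint (the real check lives in verl's validate_ulysses_config, whose exact error handling is not shown here):

def heads_divisible_by_sp(num_attention_heads: int, ulysses_sp_size: int) -> bool:
    # 32 heads / sp_size 8 -> 4 heads per rank, so the Apertus config above is valid;
    # e.g. 30 heads with sp_size 8 would not be.
    return num_attention_heads % ulysses_sp_size == 0

assert heads_divisible_by_sp(32, 8)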

tests/utils/test_flops_counter.py

Lines changed: 20 additions & 2 deletions
@@ -18,7 +18,7 @@

 from verl.utils.flops_counter import FlopsCounter

-VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text"}
+VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus"}


 class Config:
@@ -206,12 +206,30 @@ def __init__(self, config_dict):
         # total: 986195089686528 / 1e12 = 986.195089686528
         "expected_flops_tuple": (283517065887744 / 1e12, 986195089686528 / 1e12),
     },
+    "apertus": {
+        "config": {  # swiss-ai/Apertus-8B
+            "model_type": "apertus",
+            "vocab_size": 131072,
+            "hidden_size": 4096,
+            "intermediate_size": 21504,
+            "num_hidden_layers": 32,
+            "num_attention_heads": 32,
+            "num_key_value_heads": 32,
+            "hidden_act": "xielu",
+            # head_dim will be derived as 4096 / 32 = 128
+        },
+        "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
+        # Calculation for Apertus (hidden_act="xielu" -> MLP uses [k_mlp=2]*H*I params; qk_norm=True -> [k_qkn=2]*H):
+        # V=131072, H=4096, I=21504, L=32, k_mlp=2 (XIELU), k_qkn=2 (QK norm), S=6
+        # S*(2*V*H + L*(4*H**2 + k_mlp*H*I + k_qkn*H)) * (SUM[seqlen]) + 12*SUM[seqlen**2]*L*H
+        "expected_flops_tuple": (199154680725504 / 1e12, 732294071451648 / 1e12),
+    },
 }


 @pytest.mark.parametrize(
     "config_type",
-    ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text"],
+    ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus"],
 )
 def test_flops_counter(config_type: str):
     test_config = CONFIG[config_type]
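The expected values above follow directly from the formula in the added comment; evaluating it as plain arithmetic reproduces both raw-FLOP counts before the division by 1e12:

# Sanity check of expected_flops_tuple using the comment's formula.
V, H, I, L, S = 131072, 4096, 21504, 32, 6
k_mlp, k_qkn = 2, 2  # XIELU MLP (up/down only) and QK norm

def apertus_total_flops(seqlens):
    dense_params = 2 * V * H + L * (4 * H**2 + k_mlp * H * I + k_qkn * H)
    return S * dense_params * sum(seqlens) + 12 * sum(s * s for s in seqlens) * L * H

assert apertus_total_flops([512, 1024, 2048]) == 199154680725504
assert apertus_total_flops([4096, 4096, 4096]) == 732294071451648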

verl/models/registry.py

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@
         "mistral",
         ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"),
     ),
+    "ApertusForCausalLM": (
+        "apertus",
+        ("ParallelApertusForCausalLMRmPadPP", "ParallelApertusForValueRmPadPP", "ParallelApertusForCausalLMRmPad"),
+    ),
 }

New file (path not shown in this view)

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Copyright 2025 The SwissAI Initiative
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import Callable, Optional

import torch

if sys.version_info >= (3, 11):
    pass
else:
    pass

from transformers.cache_utils import Cache
from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
from transformers.utils import logging

# Import compatibility wrapper for flash_attn_supports_top_left_mask
from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_world_size,
    validate_ulysses_config,
)

logger = logging.get_logger(__name__)


def apertus_attn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    """
    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.

    Key differences from Llama attention:
    - QK normalization applied after Q/K projections

    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.
    """
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
    from transformers.models.apertus.modeling_apertus import eager_attention_forward

    bsz, q_len, _ = hidden_states.shape

    query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    query_states = self.q_norm(query_states)
    key_states = self.k_norm(key_states)

    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)

        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                "Falling back to eager attention. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
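This forward is a drop-in replacement for the Hugging Face Apertus attention forward; where and how it gets installed is not part of this view, but the usual pattern for such overrides is a simple method patch. A hedged sketch (the ApertusAttention class name comes from transformers; the module path for apertus_attn_forward is assumed, not confirmed here):

# Illustrative wiring only; the actual registration point is not shown in this commit view.
from transformers.models.apertus.modeling_apertus import ApertusAttention

from verl.models.transformers.apertus import apertus_attn_forward  # assumed module path

ApertusAttention.forward = apertus_attn_forward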

verl/utils/flops_counter.py

Lines changed: 41 additions & 0 deletions
@@ -31,6 +31,7 @@
     "mistral",
     "gemma3_text",
     "seed_oss",
+    "apertus",
 }


@@ -132,6 +133,7 @@ def __init__(self, config: PretrainedConfig):
             "mistral": self._estimate_qwen2_flops,
             "gemma3_text": self._estimate_gemma3_flops,
             "seed_oss": self._estimate_qwen2_flops,
+            "apertus": self._estimate_apertus_flops,
         }
         self.config = config

@@ -329,6 +331,45 @@ def _estimate_gemma3_flops(self, tokens_sum, batch_seqlens, delta_time):
         flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
         return flops_achieved

+    def _estimate_apertus_flops(self, tokens_sum, batch_seqlens, delta_time):
+        hidden_size = self.config.hidden_size
+        vocab_size = self.config.vocab_size
+        num_hidden_layers = self.config.num_hidden_layers
+        num_key_value_heads = self.config.num_key_value_heads
+        num_attention_heads = self.config.num_attention_heads
+        intermediate_size = self.config.intermediate_size
+
+        head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
+        q_size = num_attention_heads * head_dim
+        k_size = num_key_value_heads * head_dim
+        v_size = num_key_value_heads * head_dim
+
+        # Apertus MLP with XIELU activation uses only 2 linear layers (up_proj, down_proj)
+        # No gate_proj for XIELU, unlike SwiGLU which has 3 layers
+        mlp_N = hidden_size * intermediate_size * 2
+        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+
+        # ApertusConfig has qk_norm defaulting to True.
+        # This adds params for q_norm (on H) and k_norm (on num_kv_heads * head_dim)
+        qk_norm_params_per_layer = hidden_size + num_key_value_heads * head_dim  # q_norm + k_norm
+
+        emd_and_lm_head_N = vocab_size * hidden_size * 2
+        # non-attn all_layer params
+        dense_N = (mlp_N + attn_linear_N + qk_norm_params_per_layer) * num_hidden_layers + emd_and_lm_head_N
+        # non-attn all_layer & all_token fwd & bwd flops
+        dense_N_flops = 6 * dense_N * tokens_sum
+
+        # attn all_layer & all_token fwd & bwd flops
+        seqlen_square_sum = 0
+        for seqlen in batch_seqlens:
+            seqlen_square_sum += seqlen * seqlen
+        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+
+        # all_layer & all_token fwd & bwd flops
+        flops_all_token = dense_N_flops + attn_qkv_flops
+        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+        return flops_achieved
+
     def estimate_flops(self, batch_seqlens, delta_time):
         """
         Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
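A hedged usage sketch tying the new estimator to the test values above (ApertusConfig requires transformers >= 4.56; the return shape of estimate_flops is not shown in this diff, so the result is printed rather than unpacked):

from transformers import ApertusConfig

from verl.utils.flops_counter import FlopsCounter

# swiss-ai/Apertus-8B-like shape, matching the test config earlier in this commit.
config = ApertusConfig(
    vocab_size=131072,
    hidden_size=4096,
    intermediate_size=21504,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=32,
)

counter = FlopsCounter(config)
# With delta_time=1.0, the achieved figure should be ~199.15 TFLOPs for this batch,
# matching expected_flops_tuple[0] in the test above.
print(counter.estimate_flops([512, 1024, 2048], delta_time=1.0))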
