From 77810b8ae5fa020136c3c6f3808859f621b58699 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Wed, 26 Feb 2025 14:23:35 +0100 Subject: [PATCH 01/44] init swissai model --- src/transformers/models/swissai/__init__.py | 14 + .../models/swissai/configuration_swissai.py | 172 ++++ .../models/swissai/modeling_swissai.py | 864 ++++++++++++++++++ 3 files changed, 1050 insertions(+) create mode 100644 src/transformers/models/swissai/__init__.py create mode 100644 src/transformers/models/swissai/configuration_swissai.py create mode 100644 src/transformers/models/swissai/modeling_swissai.py diff --git a/src/transformers/models/swissai/__init__.py b/src/transformers/models/swissai/__init__.py new file mode 100644 index 000000000000..764b8d2fbf7a --- /dev/null +++ b/src/transformers/models/swissai/__init__.py @@ -0,0 +1,14 @@ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_swissai import * + from .modeling_swissai import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/swissai/configuration_swissai.py b/src/transformers/models/swissai/configuration_swissai.py new file mode 100644 index 000000000000..19397b6df3c5 --- /dev/null +++ b/src/transformers/models/swissai/configuration_swissai.py @@ -0,0 +1,172 @@ +from ...configuration_utils import PretrainedConfig + + +class SwissAIConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SwissAIModel`]. It is used to instantiate an OLMo2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/SwissAI-7B-1124-hf](https://huggingface.co/allenai/SwissAI-7B-1124-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the SwissAI model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`SwissAIModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + + ```python + >>> from transformers import SwissAIModel, SwissAIConfig + + >>> # Initializing a SwissAI 8B style configuration + >>> configuration = SwissAIConfig() + + >>> # Initializing a model from the SwissAI 8B style configuration + >>> model = SwissAIModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "swissai" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=131072, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="xielu", + max_position_embeddings=8192, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=131071, # TODO: what's our eos token id? + tie_word_embeddings=False, + rope_theta=500000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + +__all__ = ["SwissAIConfig"] diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/swissai/modeling_swissai.py new file mode 100644 index 000000000000..b40b7f09b66a --- /dev/null +++ b/src/transformers/models/swissai/modeling_swissai.py @@ -0,0 +1,864 @@ +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_swissai import SwissAIConfig + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "SwissAIConfig" + +class XIELU(nn.Module): + def __init__(self, alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6): + super(XIELU, self).__init__() + self.beta = beta + self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1).unsqueeze(0)) + self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - self.beta)) - 1).unsqueeze(0)) + self.eps = torch.tensor(eps) + + def forward(self, x): + alpha_p = F.softplus(self.alpha_p) + alpha_n = self.beta + F.softplus(self.alpha_n) + return torch.where(x > 0, + alpha_p * x * x + self.beta * x, + alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x) + + +class SwissAIRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + SwissAIRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SwissAIAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: SwissAIConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = SwissAIRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = SwissAIRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class SwissAIMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = XIELU() + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.up_proj(x))) + return down_proj + + +class SwissAIDecoderLayer(nn.Module): + def __init__(self, config: SwissAIConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = SwissAIAttention(config=config, layer_idx=layer_idx) + + self.mlp = SwissAIMLP(config) + self.pre_attention_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.pre_attention_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class SwissAIRotaryEmbedding(nn.Module): + def __init__(self, config: SwissAIConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +SWISSAI_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SwissAIConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare SwissAI Model outputting raw hidden-states without any specific head on top.", + SWISSAI_START_DOCSTRING, +) +class SwissAIPreTrainedModel(PreTrainedModel): + config_class = SwissAIConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["SwissAIDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +SWISSAI_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare SwissAI Model outputting raw hidden-states without any specific head on top.", + SWISSAI_START_DOCSTRING, +) +class SwissAIModel(SwissAIPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SwissAIDecoderLayer`] + + Args: + config: SwissAIConfig + """ + + def __init__(self, config: SwissAIConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [SwissAIDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = SwissAIRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(SWISSAI_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class SwissAIForCausalLM(SwissAIPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = SwissAIModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(SWISSAI_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, SwissAIForCausalLM + + >>> model = SwissAIForCausalLM.from_pretrained("SwissAI-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("SwissAI-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["SwissAIForCausalLM", "SwissAIModel", "SwissAIPreTrainedModel"] From bcdaf706b15e774c6348fa3eeab110d3bd9c827e Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Wed, 26 Feb 2025 14:35:32 +0100 Subject: [PATCH 02/44] AutoModelForCausalLM --- src/transformers/models/auto/modeling_auto.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 68830ccc3d72..9dc1d28724ab 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -323,6 +323,7 @@ ("swin", "SwinModel"), ("swin2sr", "Swin2SRModel"), ("swinv2", "Swinv2Model"), + ("swissai", "SwissAIModel"), ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), ("t5gemma", "T5GemmaModel"), From 53a3755cd6002cbd2ad5ad76f50dde94b9a156ac Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Wed, 26 Feb 2025 14:39:03 +0100 Subject: [PATCH 03/44] AutoModelForCausalLM mapping --- src/transformers/models/auto/modeling_auto.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 9dc1d28724ab..6c9ed869b69b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -25,6 +25,7 @@ auto_class_update, ) from .configuration_auto import CONFIG_MAPPING_NAMES +from ..swissai import SwissAIForCausalLM logger = logging.get_logger(__name__) @@ -669,6 +670,7 @@ ("speech_to_text_2", "Speech2Text2ForCausalLM"), ("stablelm", "StableLmForCausalLM"), ("starcoder2", "Starcoder2ForCausalLM"), + ("swissai", "SwissAIForCausalLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("trocr", "TrOCRForCausalLM"), ("whisper", "WhisperForCausalLM"), @@ -2098,6 +2100,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): "`AutoModelForSeq2SeqLM` for encoder-decoder models.", FutureWarning, ) + if config_class.model_type == "swissai": + return SwissAIForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) From 7c648e738a587aa4066882e227126bfa03830de0 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Fri, 28 Feb 2025 11:00:55 +0100 Subject: [PATCH 04/44] qk norm and post ln optional --- .../models/swissai/configuration_swissai.py | 12 ++++++- .../models/swissai/modeling_swissai.py | 34 ++++++++++++++----- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/swissai/configuration_swissai.py b/src/transformers/models/swissai/configuration_swissai.py index 19397b6df3c5..9269269d62bd 100644 --- a/src/transformers/models/swissai/configuration_swissai.py +++ b/src/transformers/models/swissai/configuration_swissai.py @@ -64,7 +64,11 @@ class SwissAIConfig(PretrainedConfig): The dropout ratio for the attention probabilities. rms_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. - + qk_norm (`bool`, *optional*, defaults to `True`): + Whether to use a normalization in the query and key projection layers during self-attention. + post_norm (`bool`, *optional*, defaults to `False`): + Whether to use a normalization after the self-attention and MLP layers, i.e. x = norm(f(x)) + x. + If `False`, the model will use a pre-normalization, i.e. x = f(norm(x)) + x. ```python >>> from transformers import SwissAIModel, SwissAIConfig @@ -88,6 +92,7 @@ class SwissAIConfig(PretrainedConfig): "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), @@ -116,6 +121,8 @@ def __init__( attention_bias=False, attention_dropout=0.0, rms_norm_eps=1e-5, + qk_norm=True, + post_norm=False, **kwargs, ): super().__init__( @@ -148,6 +155,9 @@ def __init__( self.rms_norm_eps = rms_norm_eps + self.qk_norm = qk_norm + self.post_norm = post_norm + def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/swissai/modeling_swissai.py index b40b7f09b66a..c897f238f9d8 100644 --- a/src/transformers/models/swissai/modeling_swissai.py +++ b/src/transformers/models/swissai/modeling_swissai.py @@ -160,8 +160,12 @@ def __init__(self, config: SwissAIConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.q_norm = SwissAIRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) - self.k_norm = SwissAIRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + if self.config.qk_norm: + self.q_norm = SwissAIRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = SwissAIRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() def forward( self, @@ -225,10 +229,18 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = XIELU() + if config.hidden_act == "xielu": + self.act_fn = XIELU() + else: + self.act_fn = ACT2FN[config.hidden_act] + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.up_proj(x))) + if self.config.hidden_act == "xielu": + # in case of xielu, no gated MLP + down_proj = self.down_proj(self.act_fn(self.up_proj(x))) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -239,8 +251,8 @@ def __init__(self, config: SwissAIConfig, layer_idx: int): self.self_attn = SwissAIAttention(config=config, layer_idx=layer_idx) self.mlp = SwissAIMLP(config) - self.pre_attention_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.feedforward_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -256,7 +268,8 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.pre_attention_layernorm(hidden_states) + if not self.config.post_norm: + hidden_states = self.attention_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( @@ -270,12 +283,17 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) + if self.config.post_norm: + hidden_states = self.attention_layernorm(hidden_states) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - hidden_states = self.pre_feedforward_layernorm(hidden_states) + if not self.config.post_norm: + hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + if self.config.post_norm: + hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) From d9a923dabf553ef580984c98ed163fb6a37ab3d5 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Sun, 2 Mar 2025 13:56:47 +0100 Subject: [PATCH 05/44] fix wrong shape of qk norm: megatron uses head_dim --- .../models/swissai/configuration_swissai.py | 5 ++--- .../models/swissai/modeling_swissai.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/swissai/configuration_swissai.py b/src/transformers/models/swissai/configuration_swissai.py index 9269269d62bd..2a62e83d7527 100644 --- a/src/transformers/models/swissai/configuration_swissai.py +++ b/src/transformers/models/swissai/configuration_swissai.py @@ -3,9 +3,8 @@ class SwissAIConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`SwissAIModel`]. It is used to instantiate an OLMo2 - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the [allenai/SwissAI-7B-1124-hf](https://huggingface.co/allenai/SwissAI-7B-1124-hf). + This is the configuration class to store the configuration of a [`SwissAIModel`]. It is used to instantiate a SwissAI + model according to the specified arguments, defining the model architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/swissai/modeling_swissai.py index c897f238f9d8..ce806e4db6a7 100644 --- a/src/transformers/models/swissai/modeling_swissai.py +++ b/src/transformers/models/swissai/modeling_swissai.py @@ -161,8 +161,8 @@ def __init__(self, config: SwissAIConfig, layer_idx: Optional[int] = None): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) if self.config.qk_norm: - self.q_norm = SwissAIRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) - self.k_norm = SwissAIRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + self.q_norm = SwissAIRMSNorm(self.head_dim, config.rms_norm_eps) + self.k_norm = SwissAIRMSNorm(self.head_dim, config.rms_norm_eps) else: self.q_norm = nn.Identity() self.k_norm = nn.Identity() @@ -179,12 +179,14 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - + query_states = self.q_proj(hidden_states) query_states = query_states.view(hidden_shape).transpose(1, 2) + query_states = self.q_norm(query_states) + key_states = self.k_proj(hidden_states) key_states = key_states.view(hidden_shape).transpose(1, 2) + key_states = self.k_norm(key_states) + + value_states = self.v_proj(hidden_states) value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings From f35ee015e82e342c6821609a7c4431522a493478 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Sun, 2 Mar 2025 15:56:35 +0100 Subject: [PATCH 06/44] automodel fixes --- src/transformers/models/auto/configuration_auto.py | 2 ++ src/transformers/models/auto/modeling_auto.py | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index e758f2eab81b..4cc3c9f86d3f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -347,6 +347,7 @@ ("swin", "SwinConfig"), ("swin2sr", "Swin2SRConfig"), ("swinv2", "Swinv2Config"), + ("swissai", "SwissAIConfig"), ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), ("t5gemma", "T5GemmaConfig"), @@ -752,6 +753,7 @@ ("swin", "Swin Transformer"), ("swin2sr", "Swin2SR"), ("swinv2", "Swin Transformer V2"), + ("swissai", "SwissAI"), ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), ("t5gemma", "T5Gemma"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 6c9ed869b69b..993925bfcb3f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -25,7 +25,6 @@ auto_class_update, ) from .configuration_auto import CONFIG_MAPPING_NAMES -from ..swissai import SwissAIForCausalLM logger = logging.get_logger(__name__) @@ -2100,8 +2099,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): "`AutoModelForSeq2SeqLM` for encoder-decoder models.", FutureWarning, ) - if config_class.model_type == "swissai": - return SwissAIForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) From e6921f7e138e9f02e9e4eab978c0459aebf4b014 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Sun, 2 Mar 2025 15:57:30 +0100 Subject: [PATCH 07/44] minor fix in forward --- src/transformers/models/swissai/modeling_swissai.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/swissai/modeling_swissai.py index ce806e4db6a7..7dececc80f3a 100644 --- a/src/transformers/models/swissai/modeling_swissai.py +++ b/src/transformers/models/swissai/modeling_swissai.py @@ -255,6 +255,8 @@ def __init__(self, config: SwissAIConfig, layer_idx: int): self.mlp = SwissAIMLP(config) self.attention_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.feedforward_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.post_norm = config.post_norm def forward( self, @@ -270,7 +272,7 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - if not self.config.post_norm: + if not self.post_norm: hidden_states = self.attention_layernorm(hidden_states) # Self Attention @@ -285,16 +287,16 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) - if self.config.post_norm: + if self.post_norm: hidden_states = self.attention_layernorm(hidden_states) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - if not self.config.post_norm: + if not self.post_norm: hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - if self.config.post_norm: + if self.post_norm: hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states From 46ca1ae40f7f4a71bb7206ab6831aa507a4382a6 Mon Sep 17 00:00:00 2001 From: dhia680 Date: Fri, 30 May 2025 10:56:49 +0200 Subject: [PATCH 08/44] fix rope validation to accept llama3 scaling --- .../models/swissai/configuration_swissai.py | 27 +++---------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/swissai/configuration_swissai.py b/src/transformers/models/swissai/configuration_swissai.py index 2a62e83d7527..52847ebf2805 100644 --- a/src/transformers/models/swissai/configuration_swissai.py +++ b/src/transformers/models/swissai/configuration_swissai.py @@ -1,5 +1,5 @@ from ...configuration_utils import PretrainedConfig - +from ...modeling_rope_utils import rope_config_validation class SwissAIConfig(PretrainedConfig): r""" @@ -148,34 +148,15 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling - self._rope_scaling_validation() + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.rms_norm_eps = rms_norm_eps self.qk_norm = qk_norm self.post_norm = post_norm - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - __all__ = ["SwissAIConfig"] From 994b1d72239c65e6ba6fbc487d345c1279c20934 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Fri, 23 May 2025 03:07:59 +0200 Subject: [PATCH 09/44] `SwissAIForTokenClassification` support --- src/transformers/models/auto/modeling_auto.py | 1 + .../models/swissai/modeling_swissai.py | 91 ++++++++++++++++++- 2 files changed, 88 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 993925bfcb3f..a3d145f617db 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1388,6 +1388,7 @@ ("squeezebert", "SqueezeBertForTokenClassification"), ("stablelm", "StableLmForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), + ("swissai", "SwissAIForTokenClassification"), ("t5", "T5ForTokenClassification"), ("t5gemma", "T5GemmaForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/swissai/modeling_swissai.py index 7dececc80f3a..a9c2a502bf7d 100644 --- a/src/transformers/models/swissai/modeling_swissai.py +++ b/src/transformers/models/swissai/modeling_swissai.py @@ -9,7 +9,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, TokenClassifierOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -19,6 +19,8 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + add_code_sample_docstrings, + can_return_tuple, ) from ...utils.deprecation import deprecate_kwarg from .configuration_swissai import SwissAIConfig @@ -185,7 +187,7 @@ def forward( key_states = self.k_proj(hidden_states) key_states = key_states.view(hidden_shape).transpose(1, 2) key_states = self.k_norm(key_states) - + value_states = self.v_proj(hidden_states) value_states = value_states.view(hidden_shape).transpose(1, 2) @@ -255,7 +257,7 @@ def __init__(self, config: SwissAIConfig, layer_idx: int): self.mlp = SwissAIMLP(config) self.attention_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.feedforward_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.post_norm = config.post_norm def forward( @@ -883,4 +885,85 @@ def forward( ) -__all__ = ["SwissAIForCausalLM", "SwissAIModel", "SwissAIPreTrainedModel"] +@add_start_docstrings( + """ + The SwissAI Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + SWISSAI_START_DOCSTRING, +) +class SwissAIForTokenClassification(SwissAIPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = SwissAIModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(SWISSAI_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["SwissAIForCausalLM", "SwissAIModel", "SwissAIPreTrainedModel", "SwissAIForTokenClassification"] From 8b38b5a4603c6cf37dbd2625d8b9863515d1e03f Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 8 Jun 2025 18:31:06 +0200 Subject: [PATCH 10/44] Align `SwissAI` to v4.52.4 --- .../models/swissai/modeling_swissai.py | 347 +++++------------- 1 file changed, 102 insertions(+), 245 deletions(-) diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/swissai/modeling_swissai.py index a9c2a502bf7d..cfd6d835c83c 100644 --- a/src/transformers/models/swissai/modeling_swissai.py +++ b/src/transformers/models/swissai/modeling_swissai.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -9,25 +9,22 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, TokenClassifierOutput -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - LossKwargs, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, - add_code_sample_docstrings, - can_return_tuple, -) -from ...utils.deprecation import deprecate_kwarg +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging +from ...integrations import use_kernel_forward_from_hub from .configuration_swissai import SwissAIConfig logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "SwissAIConfig" + class XIELU(nn.Module): def __init__(self, alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6): @@ -45,6 +42,7 @@ def forward(self, x): alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x) +@use_kernel_forward_from_hub("RMSNorm") class SwissAIRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -65,6 +63,43 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +ALL_LAYERNORM_LAYERS.append(SwissAIRMSNorm) + + +class SwissAIRotaryEmbedding(nn.Module): + def __init__(self, config: SwissAIConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -99,6 +134,29 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +class SwissAIMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_act == "xielu": + self.act_fn = XIELU() + else: + self.act_fn = ACT2FN[config.hidden_act] + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + + def forward(self, x): + if self.config.hidden_act == "xielu": + # in case of xielu, no gated MLP + down_proj = self.down_proj(self.act_fn(self.up_proj(x))) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -200,6 +258,7 @@ def forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -225,33 +284,11 @@ def forward( return attn_output, attn_weights -class SwissAIMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_act == "xielu": - self.act_fn = XIELU() - else: - self.act_fn = ACT2FN[config.hidden_act] - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - - def forward(self, x): - if self.config.hidden_act == "xielu": - # in case of xielu, no gated MLP - down_proj = self.down_proj(self.act_fn(self.up_proj(x))) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - class SwissAIDecoderLayer(nn.Module): def __init__(self, config: SwissAIConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size + self.self_attn = SwissAIAttention(config=config, layer_idx=layer_idx) self.mlp = SwissAIMLP(config) @@ -309,88 +346,7 @@ def forward( return outputs -class SwissAIRotaryEmbedding(nn.Module): - def __init__(self, config: SwissAIConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -SWISSAI_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`SwissAIConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare SwissAI Model outputting raw hidden-states without any specific head on top.", - SWISSAI_START_DOCSTRING, -) +@auto_docstring class SwissAIPreTrainedModel(PreTrainedModel): config_class = SwissAIConfig base_model_prefix = "model" @@ -417,93 +373,8 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -SWISSAI_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare SwissAI Model outputting raw hidden-states without any specific head on top.", - SWISSAI_START_DOCSTRING, -) +@auto_docstring class SwissAIModel(SwissAIPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SwissAIDecoderLayer`] - - Args: - config: SwissAIConfig - """ - def __init__(self, config: SwissAIConfig): super().__init__(config) self.padding_idx = config.pad_token_id @@ -526,10 +397,11 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - @add_start_docstrings_to_model_forward(SWISSAI_INPUTS_DOCSTRING) + @can_return_tuple + @auto_docstring def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, @@ -540,7 +412,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -625,13 +497,12 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -639,7 +510,7 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): @@ -760,6 +631,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... +@auto_docstring class SwissAIForCausalLM(SwissAIPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -792,15 +664,14 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(SWISSAI_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @can_return_tuple + @auto_docstring def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -810,21 +681,12 @@ def forward( cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: @@ -849,7 +711,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -863,7 +725,7 @@ def forward( **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -885,13 +747,7 @@ def forward( ) -@add_start_docstrings( - """ - The SwissAI Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - SWISSAI_START_DOCSTRING, -) +@auto_docstring class SwissAIForTokenClassification(SwissAIPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -916,11 +772,7 @@ def set_input_embeddings(self, value): self.model.embed_tokens = value @can_return_tuple - @add_start_docstrings_to_model_forward(SWISSAI_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) + @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -966,4 +818,9 @@ def forward( ) -__all__ = ["SwissAIForCausalLM", "SwissAIModel", "SwissAIPreTrainedModel", "SwissAIForTokenClassification"] +__all__ = [ + "SwissAIForCausalLM", + "SwissAIModel", + "SwissAIPreTrainedModel", + "SwissAIForTokenClassification", +] From 0ffc9b9e42425c08490594ed952261218c72f9e7 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sat, 12 Jul 2025 23:35:58 +0200 Subject: [PATCH 11/44] Align `SwissAI` to v4.53.1 --- .../models/swissai/modeling_swissai.py | 309 ++++-------------- 1 file changed, 55 insertions(+), 254 deletions(-) diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/swissai/modeling_swissai.py index cfd6d835c83c..a7d0ff8aa3f8 100644 --- a/src/transformers/models/swissai/modeling_swissai.py +++ b/src/transformers/models/swissai/modeling_swissai.py @@ -1,14 +1,15 @@ -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch -import torch.nn as nn +from torch import nn from torch.nn import functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -17,9 +18,8 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging -from ...integrations import use_kernel_forward_from_hub +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import check_model_inputs from .configuration_swissai import SwissAIConfig @@ -63,14 +63,11 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -ALL_LAYERNORM_LAYERS.append(SwissAIRMSNorm) - - class SwissAIRotaryEmbedding(nn.Module): def __init__(self, config: SwissAIConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" @@ -177,7 +174,7 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -230,25 +227,21 @@ def __init__(self, config: SwissAIConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - query_states = query_states.view(hidden_shape).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_norm(query_states) - key_states = self.k_proj(hidden_states) - key_states = key_states.view(hidden_shape).transpose(1, 2) key_states = self.k_norm(key_states) - value_states = self.v_proj(hidden_states) - value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -258,15 +251,8 @@ def forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -284,7 +270,7 @@ def forward( return attn_output, attn_weights -class SwissAIDecoderLayer(nn.Module): +class SwissAIDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SwissAIConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -303,24 +289,20 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: residual = hidden_states - if not self.post_norm: hidden_states = self.attention_layernorm(hidden_states) - # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -338,12 +320,7 @@ def forward( if self.post_norm: hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs + return hidden_states @auto_docstring @@ -354,12 +331,17 @@ class SwissAIPreTrainedModel(PreTrainedModel): _no_split_modules = ["SwissAIDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": SwissAIDecoderLayer, + "attentions": SwissAIAttention, + } def _init_weights(self, module): std = self.config.initializer_range @@ -371,6 +353,8 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, SwissAIRMSNorm): + module.weight.data.fill_(1.0) @auto_docstring @@ -397,7 +381,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, @@ -406,230 +390,59 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( + cache_position: torch.Tensor = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, + past_key_values=past_key_values, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - @auto_docstring class SwissAIForCausalLM(SwissAIPreTrainedModel, GenerationMixin): @@ -675,12 +488,10 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -704,13 +515,7 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -718,8 +523,6 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, @@ -782,8 +585,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + **kwargs, ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -799,8 +601,7 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + **kwargs, ) sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) From 7793c8784587350554d65a9b607047bb36b604e2 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sat, 12 Jul 2025 23:37:11 +0200 Subject: [PATCH 12/44] Init CUDA xIELU --- src/transformers/activations.py | 69 +++++++++++++++++++ .../models/swissai/modeling_swissai.py | 24 +------ 2 files changed, 72 insertions(+), 21 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 423e3cee8760..8f81d8f600be 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -185,6 +185,74 @@ def __getitem__(self, key): return cls(**kwargs) +class xIELUActivation(nn.Module): + """ + Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010 + + If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA + Otherwise, we emit a single warning and use xIELU Python + """ + + def __init__(self, alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6, dtype=torch.bfloat16, with_vector_loads=True): + super().__init__() + self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(0)) + self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1).unsqueeze(0)) + self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) + self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) + self.with_vector_loads = with_vector_loads + + self._xielu_cuda_obj = None + try: + import xielu.ops + self._xielu_cuda_obj = torch.classes.xielu.XIELU() + logger.warning_once( + "CUDA-fused xIELU currently in development. Please use the Python version for now.", + str(err) + ) + try: + from torch._dynamo import allow_in_graph + self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) + except Exception as err: + logger.warning_once( + "Could not enable torch._dynamo for xIELU (%s) - this may result in slower performance.", + str(err) + ) + self._xielu_cuda_fn = self._xielu_cuda + except Exception as err: + logger.warning_once( + "CUDA-fused xIELU not available (%s) – falling back to a Python version.\n" + "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`", + str(err) + ) + + def _xielu_python(self, x: Tensor) -> Tensor: + alpha_p = nn.functional.softplus(self.alpha_p) + alpha_n = self.beta + nn.functional.softplus(self.alpha_n) + return torch.where( + x > 0, + alpha_p * x * x + self.beta * x, + alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x + ) + + def _xielu_cuda(self, x: Tensor) -> Tensor: + """Firewall function to prevent torch.compile from seeing .item() calls""" + original_shape = x.shape + # CUDA kernel expects 3D tensors, reshape if needed + while x.dim() < 3: + x = x.unsqueeze(0) + if x.dim() > 3: + x = x.view(-1, 1, x.size(-1)) + if original_shape != x.shape: + logger.warning_once("Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).", original_shape, x.shape) + result = self._xielu_cuda_obj.forward(x, self.alpha_p, self.alpha_n, self.beta.item(), self.eps.item(), self.with_vector_loads) + return result.view(original_shape) + + def forward(self, input: Tensor) -> Tensor: + if self._xielu_cuda_obj is not None and input.is_cuda: + return self._xielu_cuda_fn(input) + return self._xielu_python(input) + + ACT2CLS = { "gelu": GELUActivation, "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), @@ -206,6 +274,7 @@ def __getitem__(self, key): "swish": nn.SiLU, "tanh": nn.Tanh, "prelu": nn.PReLU, + "xielu": xIELUActivation, } ACT2FN = ClassInstantier(ACT2CLS) diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/swissai/modeling_swissai.py index a7d0ff8aa3f8..b3ed7dfd12f6 100644 --- a/src/transformers/models/swissai/modeling_swissai.py +++ b/src/transformers/models/swissai/modeling_swissai.py @@ -26,22 +26,6 @@ logger = logging.get_logger(__name__) -class XIELU(nn.Module): - def __init__(self, alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6): - super(XIELU, self).__init__() - self.beta = beta - self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1).unsqueeze(0)) - self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - self.beta)) - 1).unsqueeze(0)) - self.eps = torch.tensor(eps) - - def forward(self, x): - alpha_p = F.softplus(self.alpha_p) - alpha_n = self.beta + F.softplus(self.alpha_n) - return torch.where(x > 0, - alpha_p * x * x + self.beta * x, - alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x) - - @use_kernel_forward_from_hub("RMSNorm") class SwissAIRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -139,10 +123,8 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_act == "xielu": - self.act_fn = XIELU() - else: - self.act_fn = ACT2FN[config.hidden_act] + self.act_fn = ACT2FN[config.hidden_act] + if config.hidden_act != "xielu": self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) def forward(self, x): @@ -150,7 +132,7 @@ def forward(self, x): # in case of xielu, no gated MLP down_proj = self.down_proj(self.act_fn(self.up_proj(x))) else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj From 590957b78c81786c49b58b670354843e95c74594 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sat, 12 Jul 2025 23:50:43 +0200 Subject: [PATCH 13/44] `SwissAI*`->`Apertus*` --- .../models/{swissai => apertus}/__init__.py | 4 +- .../configuration_apertus.py} | 22 +++--- .../modeling_apertus.py} | 76 +++++++++---------- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 6 +- 5 files changed, 56 insertions(+), 56 deletions(-) rename src/transformers/models/{swissai => apertus}/__init__.py (80%) rename src/transformers/models/{swissai/configuration_swissai.py => apertus/configuration_apertus.py} (93%) rename src/transformers/models/{swissai/modeling_swissai.py => apertus/modeling_apertus.py} (92%) diff --git a/src/transformers/models/swissai/__init__.py b/src/transformers/models/apertus/__init__.py similarity index 80% rename from src/transformers/models/swissai/__init__.py rename to src/transformers/models/apertus/__init__.py index 764b8d2fbf7a..de4b061700ac 100644 --- a/src/transformers/models/swissai/__init__.py +++ b/src/transformers/models/apertus/__init__.py @@ -5,8 +5,8 @@ if TYPE_CHECKING: - from .configuration_swissai import * - from .modeling_swissai import * + from .configuration_apertus import * + from .modeling_apertus import * else: import sys diff --git a/src/transformers/models/swissai/configuration_swissai.py b/src/transformers/models/apertus/configuration_apertus.py similarity index 93% rename from src/transformers/models/swissai/configuration_swissai.py rename to src/transformers/models/apertus/configuration_apertus.py index 52847ebf2805..d800690efa01 100644 --- a/src/transformers/models/swissai/configuration_swissai.py +++ b/src/transformers/models/apertus/configuration_apertus.py @@ -1,9 +1,9 @@ from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation -class SwissAIConfig(PretrainedConfig): +class ApertusConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`SwissAIModel`]. It is used to instantiate a SwissAI + This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate a Apertus model according to the specified arguments, defining the model architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the @@ -12,8 +12,8 @@ class SwissAIConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 50304): - Vocabulary size of the SwissAI model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`SwissAIModel`] + Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ApertusModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 11008): @@ -69,20 +69,20 @@ class SwissAIConfig(PretrainedConfig): Whether to use a normalization after the self-attention and MLP layers, i.e. x = norm(f(x)) + x. If `False`, the model will use a pre-normalization, i.e. x = f(norm(x)) + x. ```python - >>> from transformers import SwissAIModel, SwissAIConfig + >>> from transformers import ApertusModel, ApertusConfig - >>> # Initializing a SwissAI 8B style configuration - >>> configuration = SwissAIConfig() + >>> # Initializing a Apertus 8B style configuration + >>> configuration = ApertusConfig() - >>> # Initializing a model from the SwissAI 8B style configuration - >>> model = SwissAIModel(configuration) + >>> # Initializing a model from the Apertus 8B style configuration + >>> model = ApertusModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ``` """ - model_type = "swissai" + model_type = "apertus" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k @@ -159,4 +159,4 @@ def __init__( self.post_norm = post_norm -__all__ = ["SwissAIConfig"] +__all__ = ["ApertusConfig"] diff --git a/src/transformers/models/swissai/modeling_swissai.py b/src/transformers/models/apertus/modeling_apertus.py similarity index 92% rename from src/transformers/models/swissai/modeling_swissai.py rename to src/transformers/models/apertus/modeling_apertus.py index b3ed7dfd12f6..85742906ee3d 100644 --- a/src/transformers/models/swissai/modeling_swissai.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -20,17 +20,17 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import check_model_inputs -from .configuration_swissai import SwissAIConfig +from .configuration_apertus import ApertusConfig logger = logging.get_logger(__name__) @use_kernel_forward_from_hub("RMSNorm") -class SwissAIRMSNorm(nn.Module): +class ApertusRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - SwissAIRMSNorm is equivalent to T5LayerNorm + ApertusRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -47,8 +47,8 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class SwissAIRotaryEmbedding(nn.Module): - def __init__(self, config: SwissAIConfig, device=None): +class ApertusRotaryEmbedding(nn.Module): + def __init__(self, config: ApertusConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): @@ -115,7 +115,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class SwissAIMLP(nn.Module): +class ApertusMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config @@ -174,10 +174,10 @@ def eager_attention_forward( return attn_output, attn_weights -class SwissAIAttention(nn.Module): +class ApertusAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: SwissAIConfig, layer_idx: Optional[int] = None): + def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -200,8 +200,8 @@ def __init__(self, config: SwissAIConfig, layer_idx: Optional[int] = None): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) if self.config.qk_norm: - self.q_norm = SwissAIRMSNorm(self.head_dim, config.rms_norm_eps) - self.k_norm = SwissAIRMSNorm(self.head_dim, config.rms_norm_eps) + self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) + self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) else: self.q_norm = nn.Identity() self.k_norm = nn.Identity() @@ -252,16 +252,16 @@ def forward( return attn_output, attn_weights -class SwissAIDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: SwissAIConfig, layer_idx: int): +class ApertusDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: ApertusConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = SwissAIAttention(config=config, layer_idx=layer_idx) + self.self_attn = ApertusAttention(config=config, layer_idx=layer_idx) - self.mlp = SwissAIMLP(config) - self.attention_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.feedforward_layernorm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = ApertusMLP(config) + self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_norm = config.post_norm @@ -306,11 +306,11 @@ def forward( @auto_docstring -class SwissAIPreTrainedModel(PreTrainedModel): - config_class = SwissAIConfig +class ApertusPreTrainedModel(PreTrainedModel): + config_class = ApertusConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["SwissAIDecoderLayer"] + _no_split_modules = ["ApertusDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_flash_attn_3 = True @@ -321,8 +321,8 @@ class SwissAIPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { - "hidden_states": SwissAIDecoderLayer, - "attentions": SwissAIAttention, + "hidden_states": ApertusDecoderLayer, + "attentions": ApertusAttention, } def _init_weights(self, module): @@ -335,23 +335,23 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, SwissAIRMSNorm): + elif isinstance(module, ApertusRMSNorm): module.weight.data.fill_(1.0) @auto_docstring -class SwissAIModel(SwissAIPreTrainedModel): - def __init__(self, config: SwissAIConfig): +class ApertusModel(ApertusPreTrainedModel): + def __init__(self, config: ApertusConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [SwissAIDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = SwissAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = SwissAIRotaryEmbedding(config=config) + self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = ApertusRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -427,14 +427,14 @@ def forward( @auto_docstring -class SwissAIForCausalLM(SwissAIPreTrainedModel, GenerationMixin): +class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) - self.model = SwissAIModel(config) + self.model = ApertusModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -484,10 +484,10 @@ def forward( Example: ```python - >>> from transformers import AutoTokenizer, SwissAIForCausalLM + >>> from transformers import AutoTokenizer, ApertusForCausalLM - >>> model = SwissAIForCausalLM.from_pretrained("SwissAI-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("SwissAI-2-7b-hf") + >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -533,11 +533,11 @@ def forward( @auto_docstring -class SwissAIForTokenClassification(SwissAIPreTrainedModel): +class ApertusForTokenClassification(ApertusPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.model = SwissAIModel(config) + self.model = ApertusModel(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: @@ -602,8 +602,8 @@ def forward( __all__ = [ - "SwissAIForCausalLM", - "SwissAIModel", - "SwissAIPreTrainedModel", - "SwissAIForTokenClassification", + "ApertusForCausalLM", + "ApertusModel", + "ApertusPreTrainedModel", + "ApertusForTokenClassification", ] diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4cc3c9f86d3f..cadc1907f2ad 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -41,6 +41,7 @@ ("albert", "AlbertConfig"), ("align", "AlignConfig"), ("altclip", "AltCLIPConfig"), + ("apertus", "ApertusConfig"), ("arcee", "ArceeConfig"), ("aria", "AriaConfig"), ("aria_text", "AriaTextConfig"), @@ -347,7 +348,6 @@ ("swin", "SwinConfig"), ("swin2sr", "Swin2SRConfig"), ("swinv2", "Swinv2Config"), - ("swissai", "SwissAIConfig"), ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), ("t5gemma", "T5GemmaConfig"), @@ -419,6 +419,7 @@ ("albert", "ALBERT"), ("align", "ALIGN"), ("altclip", "AltCLIP"), + ("apertus", "Apertus"), ("arcee", "Arcee"), ("aria", "Aria"), ("aria_text", "AriaText"), @@ -753,7 +754,6 @@ ("swin", "Swin Transformer"), ("swin2sr", "Swin2SR"), ("swinv2", "Swin Transformer V2"), - ("swissai", "SwissAI"), ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), ("t5gemma", "T5Gemma"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a3d145f617db..979b490ba7d4 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -37,6 +37,7 @@ ("albert", "AlbertModel"), ("align", "AlignModel"), ("altclip", "AltCLIPModel"), + ("apertus", "ApertusModel"), ("arcee", "ArceeModel"), ("aria", "AriaModel"), ("aria_text", "AriaTextModel"), @@ -323,7 +324,6 @@ ("swin", "SwinModel"), ("swin2sr", "Swin2SRModel"), ("swinv2", "Swinv2Model"), - ("swissai", "SwissAIModel"), ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), ("t5gemma", "T5GemmaModel"), @@ -560,6 +560,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping + ("apertus", "ApertusForCausalLM"), ("arcee", "ArceeForCausalLM"), ("aria_text", "AriaTextForCausalLM"), ("bamba", "BambaForCausalLM"), @@ -669,7 +670,6 @@ ("speech_to_text_2", "Speech2Text2ForCausalLM"), ("stablelm", "StableLmForCausalLM"), ("starcoder2", "Starcoder2ForCausalLM"), - ("swissai", "SwissAIForCausalLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("trocr", "TrOCRForCausalLM"), ("whisper", "WhisperForCausalLM"), @@ -1316,6 +1316,7 @@ [ # Model for Token Classification mapping ("albert", "AlbertForTokenClassification"), + ("apertus", "ApertusForTokenClassification"), ("arcee", "ArceeForTokenClassification"), ("bert", "BertForTokenClassification"), ("big_bird", "BigBirdForTokenClassification"), @@ -1388,7 +1389,6 @@ ("squeezebert", "SqueezeBertForTokenClassification"), ("stablelm", "StableLmForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), - ("swissai", "SwissAIForTokenClassification"), ("t5", "T5ForTokenClassification"), ("t5gemma", "T5GemmaForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), From 353c6c0adf2c0b9c413fd9dd86db7ac0f10876e4 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 13 Jul 2025 00:27:41 +0200 Subject: [PATCH 14/44] ci fix --- src/transformers/activations.py | 44 ++++++++++++++----- .../models/apertus/configuration_apertus.py | 5 ++- .../models/apertus/modeling_apertus.py | 28 +++++++++--- 3 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 8f81d8f600be..337570777e3e 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -193,36 +193,45 @@ class xIELUActivation(nn.Module): Otherwise, we emit a single warning and use xIELU Python """ - def __init__(self, alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6, dtype=torch.bfloat16, with_vector_loads=True): + def __init__( + self, + alpha_p_init=0.8, + alpha_n_init=0.8, + beta=0.5, + eps=-1e-6, + dtype=torch.bfloat16, + with_vector_loads=True, + ): super().__init__() self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(0)) - self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1).unsqueeze(0)) + self.alpha_n = nn.Parameter( + torch.log(torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1).unsqueeze(0) + ) self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) self.with_vector_loads = with_vector_loads self._xielu_cuda_obj = None try: - import xielu.ops + import xielu.ops # noqa: F401 + self._xielu_cuda_obj = torch.classes.xielu.XIELU() - logger.warning_once( - "CUDA-fused xIELU currently in development. Please use the Python version for now.", - str(err) - ) + logger.warning_once("CUDA-fused xIELU currently in development. Please use the Python version for now.") try: from torch._dynamo import allow_in_graph + self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) except Exception as err: logger.warning_once( "Could not enable torch._dynamo for xIELU (%s) - this may result in slower performance.", - str(err) + str(err), ) self._xielu_cuda_fn = self._xielu_cuda except Exception as err: logger.warning_once( "CUDA-fused xIELU not available (%s) – falling back to a Python version.\n" "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`", - str(err) + str(err), ) def _xielu_python(self, x: Tensor) -> Tensor: @@ -231,7 +240,7 @@ def _xielu_python(self, x: Tensor) -> Tensor: return torch.where( x > 0, alpha_p * x * x + self.beta * x, - alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x + alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x, ) def _xielu_cuda(self, x: Tensor) -> Tensor: @@ -243,8 +252,19 @@ def _xielu_cuda(self, x: Tensor) -> Tensor: if x.dim() > 3: x = x.view(-1, 1, x.size(-1)) if original_shape != x.shape: - logger.warning_once("Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).", original_shape, x.shape) - result = self._xielu_cuda_obj.forward(x, self.alpha_p, self.alpha_n, self.beta.item(), self.eps.item(), self.with_vector_loads) + logger.warning_once( + "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).", + original_shape, + x.shape, + ) + result = self._xielu_cuda_obj.forward( + x, + self.alpha_p, + self.alpha_n, + self.beta.item(), + self.eps.item(), + self.with_vector_loads, + ) return result.view(original_shape) def forward(self, input: Tensor) -> Tensor: diff --git a/src/transformers/models/apertus/configuration_apertus.py b/src/transformers/models/apertus/configuration_apertus.py index d800690efa01..91bbd3e7e00e 100644 --- a/src/transformers/models/apertus/configuration_apertus.py +++ b/src/transformers/models/apertus/configuration_apertus.py @@ -1,6 +1,7 @@ from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation + class ApertusConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate a Apertus @@ -113,7 +114,7 @@ def __init__( use_cache=True, pad_token_id=1, bos_token_id=None, - eos_token_id=131071, # TODO: what's our eos token id? + eos_token_id=131071, # TODO: what's our eos token id? tie_word_embeddings=False, rope_theta=500000.0, rope_scaling=None, @@ -121,7 +122,7 @@ def __init__( attention_dropout=0.0, rms_norm_eps=1e-5, qk_norm=True, - post_norm=False, + post_norm=False, **kwargs, ): super().__init__( diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 85742906ee3d..b76c1b295083 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -2,7 +2,6 @@ import torch from torch import nn -from torch.nn import functional as F from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache @@ -188,16 +187,24 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): self.is_causal = True self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, ) self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, ) self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, ) self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, ) if self.config.qk_norm: self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) @@ -390,7 +397,9 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position: torch.Tensor = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, ) if position_ids is None: @@ -517,7 +526,12 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) if not return_dict: output = (logits,) + outputs[1:] From 833f5fe32eb50fd964b83c9c60585e096eca9143 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 13 Jul 2025 05:06:39 +0200 Subject: [PATCH 15/44] check_docstring ignore ApertusConfig --- src/transformers/models/apertus/configuration_apertus.py | 5 ++++- utils/check_docstrings.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/apertus/configuration_apertus.py b/src/transformers/models/apertus/configuration_apertus.py index 91bbd3e7e00e..e12dfe70c5fa 100644 --- a/src/transformers/models/apertus/configuration_apertus.py +++ b/src/transformers/models/apertus/configuration_apertus.py @@ -5,7 +5,10 @@ class ApertusConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate a Apertus - model according to the specified arguments, defining the model architecture. + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apertus-8B. + e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 946f678fb33f..9ae8e1da46dc 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -79,6 +79,7 @@ # docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the # line before the docstring. OBJECTS_TO_IGNORE = [ + "ApertusConfig", "SmolLM3Config", "Gemma3nVisionConfig", "Llama4Processor", From f0ec65c8361ddd7ca35d0561f568ced30a60a4e5 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 13 Jul 2025 06:05:28 +0200 Subject: [PATCH 16/44] Licensing and placeholder tests --- src/transformers/models/apertus/__init__.py | 18 + .../models/apertus/configuration_apertus.py | 18 + .../models/apertus/modeling_apertus.py | 18 + tests/models/apertus/__init__.py | 0 tests/models/apertus/test_modeling_apertus.py | 607 ++++++++++++++++++ 5 files changed, 661 insertions(+) create mode 100644 tests/models/apertus/__init__.py create mode 100644 tests/models/apertus/test_modeling_apertus.py diff --git a/src/transformers/models/apertus/__init__.py b/src/transformers/models/apertus/__init__.py index de4b061700ac..e12640dbe389 100644 --- a/src/transformers/models/apertus/__init__.py +++ b/src/transformers/models/apertus/__init__.py @@ -1,3 +1,21 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved. +# +# This code is based on HuggingFace's LLaMA implementation in this library. +# It has been modified from its original forms to accommodate minor architectural +# differences compared to LLaMA used by the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING from ...utils import _LazyModule diff --git a/src/transformers/models/apertus/configuration_apertus.py b/src/transformers/models/apertus/configuration_apertus.py index e12dfe70c5fa..7cf5dd5647e3 100644 --- a/src/transformers/models/apertus/configuration_apertus.py +++ b/src/transformers/models/apertus/configuration_apertus.py @@ -1,3 +1,21 @@ +# coding=utf-8 +# Copyright 2025 EleutherAI, the HuggingFace Inc. team, and the Swiss AI Initiative. All rights reserved. +# +# This code is based on HuggingFace's LLaMA implementation in this library. +# It has been modified from its original forms to accommodate minor architectural +# differences compared to LLaMA used by the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index b76c1b295083..5f0e5112fd0b 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -1,3 +1,21 @@ +# coding=utf-8 +# Copyright 2025 EleutherAI, the HuggingFace Inc. team, and the Swiss AI Initiative. All rights reserved. +# +# This code is based on HuggingFace's LLaMA implementation in this library. +# It has been modified from its original forms to accommodate minor architectural +# differences compared to LLaMA used by the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, Optional, Union import torch diff --git a/tests/models/apertus/__init__.py b/tests/models/apertus/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/apertus/test_modeling_apertus.py b/tests/models/apertus/test_modeling_apertus.py new file mode 100644 index 000000000000..384461257db4 --- /dev/null +++ b/tests/models/apertus/test_modeling_apertus.py @@ -0,0 +1,607 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved. +# +# This code is based on HuggingFace's LLaMA implementation in this library. +# It has been modified from its original forms to accommodate minor architectural +# differences compared to LLaMA used by the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Apertus model.""" + +import unittest + +from packaging import version + +from transformers import AutoTokenizer, StaticCache, is_torch_available +from transformers.generation.configuration_utils import GenerationConfig +from transformers.testing_utils import ( + Expectations, + cleanup, + require_read_token, + require_torch, + require_torch_accelerator, + run_test_using_subprocess, + slow, + torch_device, +) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +if is_torch_available(): + import torch + + from transformers import ( + ApertusConfig, + ApertusForCausalLM, + ApertusForTokenClassification, + ApertusModel, + ) + from transformers.models.apertus.modeling_apertus import ApertusRotaryEmbedding + + +class ApertusModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = ApertusConfig + base_model_class = ApertusModel + causal_lm_class = ApertusForCausalLM + token_class = ApertusForTokenClassification + + +@require_torch +class ApertusModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = ( + ( + ApertusModel, + ApertusForCausalLM, + ApertusForTokenClassification, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": ApertusModel, + "text-generation": ApertusForCausalLM, + "token-classification": ApertusForTokenClassification, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + model_tester_class = ApertusModelTester + rotary_embedding_layer = ApertusRotaryEmbedding # Enables RoPE tests if set + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = ApertusForCausalLM if is_torch_available() else None + + +@require_torch_accelerator +@require_read_token +class ApertusIntegrationTest(unittest.TestCase): + def setup(self): + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves + # some memory allocated in the cache, which means some object is not being released properly. This causes some + # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU. + # Investigate the root cause. + cleanup(torch_device, gc_collect=True) + + @slow + def test_apertus_8b_hard(self): + """ + An integration test for apertus 8b. It tests against a long output to ensure the subtle numerical differences + from apertus 8b's RoPE can be detected + """ + expected_texts = Expectations( + { + ("rocm", (9, 5)): 'Tell me about the french revolution. The french revolution was a period of radical social and political upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative assembly that had not met since 1614. The Third Estate, which represented the common people, demanded greater representation and eventually broke away to form the National Assembly. This marked the beginning of the end of the absolute monarchy and the rise of the middle class.\n', + ("cuda", None): 'Tell me about the french revolution. The french revolution was a period of radical political and social upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative assembly that had not met since 1614. The Third Estate, which represented the common people, demanded greater representation and eventually broke away to form the National Assembly. The National Assembly adopted the Declaration of the Rights of Man and of the Citizen, which enshr', + } + ) # fmt: skip + EXPECTED_TEXT = expected_texts.get_expectation() + + tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B") + model = ApertusForCausalLM.from_pretrained( + "swiss-ai/Apertus-8B", device_map="auto", torch_dtype=torch.bfloat16 + ) + input_text = ["Tell me about the french revolution."] + model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(generated_text, EXPECTED_TEXT) + + @slow + def test_model_8b_logits_bf16(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = ApertusForCausalLM.from_pretrained( + "swiss-ai/Apertus-8B", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + # Expected mean on dim = -1 + + # fmt: off + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), + ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), + ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), + ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), + ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) + }) + # fmt: on + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) + + @slow + def test_model_8b_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = ApertusForCausalLM.from_pretrained( + "swiss-ai/Apertus-8B", device_map="auto", torch_dtype=torch.bfloat16 + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + + # fmt: off + # Expected mean on dim = -1 + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), + ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), + ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), + ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) + }) + # fmt: on + + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) + + # TODO: check why we have the following strange situation. + # without running in subprocess, this test causes subsequent tests failing with `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!` + @run_test_using_subprocess + @slow + def test_model_8b_dola_generation(self): + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXT_COMPLETION = ( + "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " + "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " + "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " + "understanding of space and time." + ) + prompt = "Simply put, the theory of relativity states that " + tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B") + model = ApertusForCausalLM.from_pretrained( + "swiss-ai/Apertus-8B", device_map="sequential", torch_dtype=torch.bfloat16 + ) + model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate( + **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_torch_accelerator + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B", pad_token="", padding_side="right") + model = ApertusForCausalLM.from_pretrained( + "swiss-ai/Apertus-8B", device_map=torch_device, torch_dtype=torch.bfloat16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + @slow + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + ) + + apertus_models = { + "swiss-ai/Apertus-8B": [ + "Simply put, the theory of relativity states that 1) the speed of light is the same for all " + "observers, regardless of their location, and 2) the laws of physics are the same for all observers" + ], + } + + for apertus_model_ckp, EXPECTED_TEXT_COMPLETION in apertus_models.items(): + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(apertus_model_ckp, pad_token="", padding_side="right") + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = ApertusForCausalLM.from_pretrained( + apertus_model_ckp, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + "device": device, + }, + ), + ) + + prompts = ["Simply put, the theory of relativity states that "] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + export + from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM + + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) + exported_program = exportable_module.export() + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + +@slow +@require_torch_accelerator +class Mask4DTestHard(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def setUp(self): + cleanup(torch_device, gc_collect=True) + model_name = "swiss-ai/Apertus-8B" + self.model_dtype = torch.bfloat16 + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = ApertusForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) + + def get_test_data(self): + template = "my favorite {}" + items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item + + batch_separate = [template.format(x) for x in items] # 3 separate lines + batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated + + input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) + input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) + + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ] + ], + device=torch_device, + ) + + position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) + + # building custom positions ids based on custom mask + position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) + # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) + + # inverting the mask + min_dtype = torch.finfo(self.model_dtype).min + mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + def test_stacked_causal_mask(self): + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # 2 forward runs with custom 4D masks + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) + past_key_values_a = outs_1a["past_key_values"] + + # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + outs_1b = self.model.forward( + input_1b, + attention_mask=mask_1b, + position_ids=position_ids_1b, + past_key_values=past_key_values_a, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) + + def test_stacked_causal_mask_static_cache(self): + """same as above but with StaticCache""" + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache( + config=self.model.config, + max_batch_size=1, + max_cache_len=max_cache_len, + device=torch_device, + dtype=self.model.dtype, + ) + + padded_attention_mask = torch.nn.functional.pad( + input=mask_shared_prefix, + pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, + attention_mask=padded_attention_mask, + position_ids=position_ids_shared_prefix, + cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), + past_key_values=past_key_values, + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask_static_cache(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache( + config=self.model.config, + max_batch_size=1, + max_cache_len=max_cache_len, + device=torch_device, + dtype=self.model.dtype, + ) + + # forward run for the first part of input + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + padded_mask_1a = torch.nn.functional.pad( + input=mask_1a, + pad=(0, max_cache_len - mask_1a.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + _ = self.model.forward( + input_1a, + attention_mask=padded_mask_1a, + position_ids=position_ids_1a, + cache_position=torch.arange(part_a, device=torch_device), + past_key_values=past_key_values, + ) + + # forward run for the second part of input + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + + padded_mask_1b = torch.nn.functional.pad( + input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 + ) + + outs_1b = self.model.forward( + input_1b, + attention_mask=padded_mask_1b, + position_ids=position_ids_1b, + cache_position=torch.arange( + part_a, + input_ids_shared_prefix.shape[-1], + device=torch_device, + ), + past_key_values=past_key_values, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) From 1f4e71589e897872c73f93721a81451997b32cd1 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 13 Jul 2025 06:33:48 +0200 Subject: [PATCH 17/44] Placeholder doc --- docs/source/en/model_doc/apertus.md | 100 ++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 docs/source/en/model_doc/apertus.md diff --git a/docs/source/en/model_doc/apertus.md b/docs/source/en/model_doc/apertus.md new file mode 100644 index 000000000000..edf91a6e0c3b --- /dev/null +++ b/docs/source/en/model_doc/apertus.md @@ -0,0 +1,100 @@ + + +
+
+ PyTorch + FlashAttention + SDPA + Tensor parallelism +
+
+ +# Apertus + +[Apertus](https://www.swiss-ai.org) is a family of large language models from the Swiss AI Initiative. + +> [!TIP] +> Coming soon + +The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`], and from the command line. + + + + +```py +import torch +from transformers import pipeline + +pipeline = pipeline( + task="text-generation", + model="swiss-ai/Apertus-8B", + torch_dtype=torch.bfloat16, + device=0 +) +pipeline("Plants create energy through a process known as") +``` + + + + +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained( + "swiss-ai/Apertus-8B", +) +model = AutoModelForCausalLM.from_pretrained( + "swiss-ai/Apertus-8B", + torch_dtype=torch.bfloat16, + device_map="auto", + attn_implementation="sdpa" +) +input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda") + +output = model.generate(**input_ids, cache_implementation="static") +print(tokenizer.decode(output[0], skip_special_tokens=True)) +``` + + + + +```bash +echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model swiss-ai/Apertus-8B --device 0 +``` + + + + +## ApertusConfig + +[[autodoc]] ApertusConfig + +## ApertusModel + +[[autodoc]] ApertusModel + - forward + +## ApertusForCausalLM + +[[autodoc]] ApertusForCausalLM + - forward + +## ApertusForTokenClassification + +[[autodoc]] ApertusForTokenClassification + - forward From cf125820f825a477dde9fc84ba54231cb1d76ee4 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 13 Jul 2025 13:16:55 +0200 Subject: [PATCH 18/44] XIELU syntax --- src/transformers/activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 337570777e3e..7454d63a72ab 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -185,7 +185,7 @@ def __getitem__(self, key): return cls(**kwargs) -class xIELUActivation(nn.Module): +class XIELUActivation(nn.Module): """ Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010 @@ -294,7 +294,7 @@ def forward(self, input: Tensor) -> Tensor: "swish": nn.SiLU, "tanh": nn.Tanh, "prelu": nn.PReLU, - "xielu": xIELUActivation, + "xielu": XIELUActivation, } ACT2FN = ClassInstantier(ACT2CLS) From 331fc0d2787b9dc9ab0100106856074941b537fc Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 13 Jul 2025 13:18:37 +0200 Subject: [PATCH 19/44] `_xielu_python` optimization --- src/transformers/activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 7454d63a72ab..95d7e2c2a0eb 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -239,8 +239,8 @@ def _xielu_python(self, x: Tensor) -> Tensor: alpha_n = self.beta + nn.functional.softplus(self.alpha_n) return torch.where( x > 0, - alpha_p * x * x + self.beta * x, - alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x, + (alpha_p * x + self.beta) * x, + alpha_n * torch.expm1(torch.min(x, self.eps)) - (alpha_n + self.beta) * x, ) def _xielu_cuda(self, x: Tensor) -> Tensor: From 2728d3cbd649d1c587739000db13686e001c827c Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 13 Jul 2025 15:28:55 +0200 Subject: [PATCH 20/44] Fix xIELU --- src/transformers/activations.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 95d7e2c2a0eb..fe3be33ca951 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -200,7 +200,7 @@ def __init__( beta=0.5, eps=-1e-6, dtype=torch.bfloat16, - with_vector_loads=True, + with_vector_loads=False, ): super().__init__() self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(0)) @@ -216,7 +216,6 @@ def __init__( import xielu.ops # noqa: F401 self._xielu_cuda_obj = torch.classes.xielu.XIELU() - logger.warning_once("CUDA-fused xIELU currently in development. Please use the Python version for now.") try: from torch._dynamo import allow_in_graph @@ -239,8 +238,8 @@ def _xielu_python(self, x: Tensor) -> Tensor: alpha_n = self.beta + nn.functional.softplus(self.alpha_n) return torch.where( x > 0, - (alpha_p * x + self.beta) * x, - alpha_n * torch.expm1(torch.min(x, self.eps)) - (alpha_n + self.beta) * x, + alpha_p * x * x + self.beta * x, + (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x, ) def _xielu_cuda(self, x: Tensor) -> Tensor: From d0d42cdde4934633f42b1efe000379d0ab009249 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Thu, 14 Aug 2025 05:21:26 +0200 Subject: [PATCH 21/44] [tmp] `{beta,eps}` persistent=False until {beta,eps} saved in checkpoint --- src/transformers/activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index fe3be33ca951..4597eb4ca180 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -207,8 +207,8 @@ def __init__( self.alpha_n = nn.Parameter( torch.log(torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1).unsqueeze(0) ) - self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) - self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) + self.register_buffer("beta", torch.tensor(beta, dtype=dtype), persistent=False) + self.register_buffer("eps", torch.tensor(eps, dtype=dtype), persistent=False) self.with_vector_loads = with_vector_loads self._xielu_cuda_obj = None From 543b343044f6fc1e17c5f5b62643aace25c0ab4d Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Thu, 14 Aug 2025 05:48:26 +0200 Subject: [PATCH 22/44] Modular `Apertus` --- src/transformers/models/apertus/__init__.py | 4 +- .../models/apertus/configuration_apertus.py | 123 ++++-- .../models/apertus/modeling_apertus.py | 274 ++++++------ .../models/apertus/modular_apertus.py | 410 ++++++++++++++++++ 4 files changed, 615 insertions(+), 196 deletions(-) create mode 100644 src/transformers/models/apertus/modular_apertus.py diff --git a/src/transformers/models/apertus/__init__.py b/src/transformers/models/apertus/__init__.py index e12640dbe389..dea6f28438b4 100644 --- a/src/transformers/models/apertus/__init__.py +++ b/src/transformers/models/apertus/__init__.py @@ -2,8 +2,8 @@ # Copyright 2025 The HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved. # # This code is based on HuggingFace's LLaMA implementation in this library. -# It has been modified from its original forms to accommodate minor architectural -# differences compared to LLaMA used by the Swiss AI Initiative that trained the model. +# It has been modified from its original forms to accommodate the architectural +# differences made by the Swiss AI Initiative that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/apertus/configuration_apertus.py b/src/transformers/models/apertus/configuration_apertus.py index 7cf5dd5647e3..735bfc6140ee 100644 --- a/src/transformers/models/apertus/configuration_apertus.py +++ b/src/transformers/models/apertus/configuration_apertus.py @@ -1,9 +1,15 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/apertus/modular_apertus.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_apertus.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 EleutherAI, the HuggingFace Inc. team, and the Swiss AI Initiative. All rights reserved. # # This code is based on HuggingFace's LLaMA implementation in this library. -# It has been modified from its original forms to accommodate minor architectural -# differences compared to LLaMA used by the Swiss AI Initiative that trained the model. +# It has been modified from its original forms to accommodate the architectural +# differences made by the Swiss AI Initiative that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +22,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation @@ -27,18 +34,17 @@ class ApertusConfig(PretrainedConfig): defaults will yield a similar configuration to that of the Apertus-8B. e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: - vocab_size (`int`, *optional*, defaults to 50304): + vocab_size (`int`, *optional*, defaults to 131072): Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`ApertusModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): + intermediate_size (`int`, *optional*, defaults to 14336): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer decoder. @@ -49,60 +55,88 @@ class ApertusConfig(PretrainedConfig): `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `num_attention_heads`. hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`): The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. + max_position_embeddings (`int`, *optional*, defaults to 65536): + The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 1): + pad_token_id (`int`, *optional*, defaults to 3): Padding token id. - bos_token_id (`int`, *optional*): + bos_token_id (`int`, *optional*, defaults to 1): Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 50279): + eos_token_id (`int`, *optional*, defaults to 2): End of stream token id. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): + rope_theta (`float`, *optional*, defaults to 12000000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is - `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. qk_norm (`bool`, *optional*, defaults to `True`): - Whether to use a normalization in the query and key projection layers during self-attention. + Whether to use a normalization on the query and key states during self-attention. post_norm (`bool`, *optional*, defaults to `False`): - Whether to use a normalization after the self-attention and MLP layers, i.e. x = norm(f(x)) + x. - If `False`, the model will use a pre-normalization, i.e. x = f(norm(x)) + x. + Whether to use a normalization on the output of the attention layer. + ```python >>> from transformers import ApertusModel, ApertusConfig - >>> # Initializing a Apertus 8B style configuration + >>> # Initializing a Apertus-8B style configuration >>> configuration = ApertusConfig() - >>> # Initializing a model from the Apertus 8B style configuration + >>> # Initializing a model from the Apertus-8B style configuration >>> model = ApertusModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config - ``` - """ + ```""" model_type = "apertus" keys_to_ignore_at_inference = ["past_key_values"] @@ -130,18 +164,24 @@ def __init__( num_attention_heads=32, num_key_value_heads=None, hidden_act="xielu", - max_position_embeddings=8192, + max_position_embeddings=65536, initializer_range=0.02, + rms_norm_eps=1e-5, use_cache=True, - pad_token_id=1, - bos_token_id=None, - eos_token_id=131071, # TODO: what's our eos token id? + pad_token_id=3, + bos_token_id=1, + eos_token_id=2, tie_word_embeddings=False, - rope_theta=500000.0, - rope_scaling=None, + rope_theta=12000000.0, + rope_scaling={ + "rope_type": "llama3", + "factor": 8.0, + "original_max_position_embeddings": 8192, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + }, attention_bias=False, attention_dropout=0.0, - rms_norm_eps=1e-5, qk_norm=True, post_norm=False, **kwargs, @@ -167,16 +207,17 @@ def __init__( self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.rms_norm_eps = rms_norm_eps - self.qk_norm = qk_norm self.post_norm = post_norm diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 5f0e5112fd0b..305530395f40 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -1,9 +1,15 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/apertus/modular_apertus.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_apertus.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 EleutherAI, the HuggingFace Inc. team, and the Swiss AI Initiative. All rights reserved. # # This code is based on HuggingFace's LLaMA implementation in this library. -# It has been modified from its original forms to accommodate minor architectural -# differences compared to LLaMA used by the Swiss AI Initiative that trained the model. +# It has been modified from its original forms to accommodate the architectural +# differences made by the Swiss AI Initiative that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,20 +33,34 @@ from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - TokenClassifierOutput, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, TokenClassifierOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import check_model_inputs from .configuration_apertus import ApertusConfig -logger = logging.get_logger(__name__) +class ApertusMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + if config.hidden_act != "xielu": + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + + def forward(self, x): + if self.config.hidden_act == "xielu": + # in case of xielu, no gated MLP + down_proj = self.down_proj(self.act_fn(self.up_proj(x))) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj @use_kernel_forward_from_hub("RMSNorm") @@ -132,27 +152,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class ApertusMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - if config.hidden_act != "xielu": - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - - def forward(self, x): - if self.config.hidden_act == "xielu": - # in case of xielu, no gated MLP - down_proj = self.down_proj(self.act_fn(self.up_proj(x))) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -205,24 +204,16 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): self.is_causal = True self.q_proj = nn.Linear( - config.hidden_size, - config.num_attention_heads * self.head_dim, - bias=config.attention_bias, + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, - config.hidden_size, - bias=config.attention_bias, + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) if self.config.qk_norm: self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) @@ -253,7 +244,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -298,13 +288,12 @@ def forward( past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states if not self.post_norm: hidden_states = self.attention_layernorm(hidden_states) - # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, @@ -364,95 +353,6 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) -@auto_docstring -class ApertusModel(ApertusPreTrainedModel): - def __init__(self, config: ApertusConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = ApertusRotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @check_model_inputs - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithPast: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: torch.Tensor = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = create_causal_mask( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=position_ids, - ) - - hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - - hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - ) - - @auto_docstring class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -497,7 +397,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], @@ -524,7 +423,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -532,7 +430,6 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -544,16 +441,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithPast( loss=loss, @@ -633,9 +521,89 @@ def forward( ) -__all__ = [ - "ApertusForCausalLM", - "ApertusModel", - "ApertusPreTrainedModel", - "ApertusForTokenClassification", -] +@auto_docstring +class ApertusModel(ApertusPreTrainedModel): + def __init__(self, config: ApertusConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = ApertusRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +__all__ = ["ApertusForCausalLM", "ApertusForTokenClassification", "ApertusModel", "ApertusPreTrainedModel"] diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py new file mode 100644 index 000000000000..7cd443227d65 --- /dev/null +++ b/src/transformers/models/apertus/modular_apertus.py @@ -0,0 +1,410 @@ +# coding=utf-8 +# Copyright 2025 EleutherAI, the HuggingFace Inc. team, and the Swiss AI Initiative. All rights reserved. +# +# This code is based on HuggingFace's LLaMA implementation in this library. +# It has been modified from its original forms to accommodate the architectural +# differences made by the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...utils import TransformersKwargs, logging +from ..llama.configuration_llama import LlamaConfig +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaForTokenClassification, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + eager_attention_forward, + apply_rotary_pos_emb, +) +from ...cache_utils import Cache +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + +class ApertusConfig(LlamaConfig): + r""" + This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate a Apertus + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apertus-8B. + e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 131072): + Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ApertusModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 65536): + The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 3): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 12000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qk_norm (`bool`, *optional*, defaults to `True`): + Whether to use a normalization on the query and key states during self-attention. + post_norm (`bool`, *optional*, defaults to `False`): + Whether to use a normalization on the output of the attention layer. + + ```python + >>> from transformers import ApertusModel, ApertusConfig + + >>> # Initializing a Apertus-8B style configuration + >>> configuration = ApertusConfig() + + >>> # Initializing a model from the Apertus-8B style configuration + >>> model = ApertusModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "apertus" + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + } + + def __init__( + self, + vocab_size=131072, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="xielu", + max_position_embeddings=65536, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=3, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=12000000.0, + rope_scaling={ + "rope_type": "llama3", + "factor": 8.0, + "original_max_position_embeddings": 8192, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0 + }, + attention_bias=False, + attention_dropout=0.0, + qk_norm=True, + post_norm=False, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + rms_norm_eps=rms_norm_eps, + use_cache=use_cache, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + **kwargs, + ) + self.qk_norm = qk_norm + self.post_norm = post_norm + del self.pretraining_tp + del self.mlp_bias + del self.head_dim + + +class ApertusMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + if config.hidden_act != "xielu": + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + + def forward(self, x): + if self.config.hidden_act == "xielu": + # in case of xielu, no gated MLP + down_proj = self.down_proj(self.act_fn(self.up_proj(x))) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class ApertusRMSNorm(LlamaRMSNorm): + pass + + +class ApertusRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class ApertusAttention(LlamaAttention): + def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + if self.config.qk_norm: + self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) + self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class ApertusDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: ApertusConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.post_norm = config.post_norm + + del self.input_layernorm + del self.post_attention_layernorm + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + if not self.post_norm: + hidden_states = self.attention_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + if self.post_norm: + hidden_states = self.attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + if not self.post_norm: + hidden_states = self.feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.post_norm: + hidden_states = self.feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class ApertusPreTrainedModel(LlamaPreTrainedModel): + pass + + +class ApertusForCausalLM(LlamaForCausalLM): + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, ApertusForCausalLM + + >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + return super().forward(**super_kwargs) + pass + + +class ApertusForTokenClassification(LlamaForTokenClassification): + pass + + +class ApertusModel(LlamaModel): + pass + + + + +__all__ = [ + "ApertusConfig", + "ApertusForCausalLM", + "ApertusForTokenClassification", + "ApertusModel", + "ApertusPreTrainedModel", +] \ No newline at end of file From 4d436d093fbeeb61bb5cead1be711d74110ebb87 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Thu, 14 Aug 2025 05:48:59 +0200 Subject: [PATCH 23/44] CUDA xIELU logging --- src/transformers/activations.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 4597eb4ca180..754ba184e6b7 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -216,16 +216,16 @@ def __init__( import xielu.ops # noqa: F401 self._xielu_cuda_obj = torch.classes.xielu.XIELU() + msg = f"Using experimental xIELU CUDA." try: from torch._dynamo import allow_in_graph self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) + msg += f" Enabled torch._dynamo for xIELU CUDA." except Exception as err: - logger.warning_once( - "Could not enable torch._dynamo for xIELU (%s) - this may result in slower performance.", - str(err), - ) + msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance." self._xielu_cuda_fn = self._xielu_cuda + logger.warning_once(msg) except Exception as err: logger.warning_once( "CUDA-fused xIELU not available (%s) – falling back to a Python version.\n" @@ -268,7 +268,12 @@ def _xielu_cuda(self, x: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor: if self._xielu_cuda_obj is not None and input.is_cuda: - return self._xielu_cuda_fn(input) + if not torch._dynamo.is_compiling(): + return self._xielu_cuda_fn(input) + else: + logger.warning( + "torch._dynamo is compiling, using Python version of xIELU." + ) return self._xielu_python(input) From e5ec231621ff6343dcb741ff5d0d561dfdec6079 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Thu, 14 Aug 2025 06:02:31 +0200 Subject: [PATCH 24/44] ci fix --- src/transformers/activations.py | 4 ++-- src/transformers/models/apertus/modular_apertus.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 754ba184e6b7..70a222202e6e 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -216,12 +216,12 @@ def __init__( import xielu.ops # noqa: F401 self._xielu_cuda_obj = torch.classes.xielu.XIELU() - msg = f"Using experimental xIELU CUDA." + msg = "Using experimental xIELU CUDA." try: from torch._dynamo import allow_in_graph self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) - msg += f" Enabled torch._dynamo for xIELU CUDA." + msg += " Enabled torch._dynamo for xIELU CUDA." except Exception as err: msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance." self._xielu_cuda_fn = self._xielu_cuda diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 7cd443227d65..189ef33b5b89 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -22,6 +22,9 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging from ..llama.configuration_llama import LlamaConfig from ..llama.modeling_llama import ( @@ -33,12 +36,9 @@ LlamaPreTrainedModel, LlamaRMSNorm, LlamaRotaryEmbedding, - eager_attention_forward, apply_rotary_pos_emb, + eager_attention_forward, ) -from ...cache_utils import Cache -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...processing_utils import Unpack logger = logging.get_logger(__name__) @@ -407,4 +407,4 @@ class ApertusModel(LlamaModel): "ApertusForTokenClassification", "ApertusModel", "ApertusPreTrainedModel", -] \ No newline at end of file +] From 1de44fd5fd4fe1f852263349806fea3f128b7273 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Thu, 14 Aug 2025 06:33:30 +0200 Subject: [PATCH 25/44] ci fix --- src/transformers/activations.py | 4 +- .../models/apertus/modeling_apertus.py | 42 +++---------------- .../models/apertus/modular_apertus.py | 6 +-- 3 files changed, 9 insertions(+), 43 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 70a222202e6e..6d09bbc46a03 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -271,9 +271,7 @@ def forward(self, input: Tensor) -> Tensor: if not torch._dynamo.is_compiling(): return self._xielu_cuda_fn(input) else: - logger.warning( - "torch._dynamo is compiling, using Python version of xIELU." - ) + logger.warning("torch._dynamo is compiling, using Python version of xIELU.") return self._xielu_python(input) diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 305530395f40..a392286cf3ca 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -32,8 +32,8 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, TokenClassifierOutput +from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -321,7 +321,7 @@ def forward( @auto_docstring class ApertusPreTrainedModel(PreTrainedModel): - config_class = ApertusConfig + config: ApertusConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["ApertusDecoderLayer"] @@ -330,28 +330,14 @@ class ApertusPreTrainedModel(PreTrainedModel): _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True + + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": ApertusDecoderLayer, "attentions": ApertusAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, ApertusRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin): @@ -368,18 +354,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def set_decoder(self, decoder): self.model = decoder @@ -539,12 +513,6 @@ def __init__(self, config: ApertusConfig): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - @check_model_inputs @auto_docstring def forward( diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 189ef33b5b89..95babdf05ce9 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -154,6 +154,7 @@ class ApertusConfig(LlamaConfig): >>> # Accessing the model configuration >>> configuration = model.config ```""" + model_type = "apertus" base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k @@ -188,7 +189,7 @@ def __init__( "factor": 8.0, "original_max_position_embeddings": 8192, "low_freq_factor": 1.0, - "high_freq_factor": 4.0 + "high_freq_factor": 4.0, }, attention_bias=False, attention_dropout=0.0, @@ -388,6 +389,7 @@ def forward(self, **super_kwargs): "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" return super().forward(**super_kwargs) + pass @@ -399,8 +401,6 @@ class ApertusModel(LlamaModel): pass - - __all__ = [ "ApertusConfig", "ApertusForCausalLM", From 9c0cb617716bec339cc844188a23c832f0045264 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Thu, 14 Aug 2025 06:42:07 +0200 Subject: [PATCH 26/44] ci fix --- .../models/apertus/modeling_apertus.py | 69 +------------------ 1 file changed, 2 insertions(+), 67 deletions(-) diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index a392286cf3ca..f15b0e84a8c8 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -426,73 +426,8 @@ def forward( ) -@auto_docstring -class ApertusForTokenClassification(ApertusPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = ApertusModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - **kwargs, - ) -> TokenClassifierOutput: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - - outputs: BaseModelOutputWithPast = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - sequence_output = outputs.last_hidden_state - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) +class ApertusForTokenClassification(GenericForTokenClassification, ApertusPreTrainedModel): + pass @auto_docstring From 8f1c0817cad29563602344c34de5481926174fff Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Tue, 19 Aug 2025 22:34:29 +0200 Subject: [PATCH 27/44] Update license Co-authored-by: Cyril Vallez --- src/transformers/models/apertus/modular_apertus.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 95babdf05ce9..1e508ad1cff2 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -1,9 +1,6 @@ # coding=utf-8 -# Copyright 2025 EleutherAI, the HuggingFace Inc. team, and the Swiss AI Initiative. All rights reserved. +# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved. # -# This code is based on HuggingFace's LLaMA implementation in this library. -# It has been modified from its original forms to accommodate the architectural -# differences made by the Swiss AI Initiative that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From dad00ca960ab5b034610c9697ef4ad3dfc1d1093 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Tue, 19 Aug 2025 22:41:43 +0200 Subject: [PATCH 28/44] Update tests/models/apertus/test_modeling_apertus.py Co-authored-by: Cyril Vallez --- tests/models/apertus/test_modeling_apertus.py | 262 +----------------- 1 file changed, 2 insertions(+), 260 deletions(-) diff --git a/tests/models/apertus/test_modeling_apertus.py b/tests/models/apertus/test_modeling_apertus.py index 384461257db4..9b7848047199 100644 --- a/tests/models/apertus/test_modeling_apertus.py +++ b/tests/models/apertus/test_modeling_apertus.py @@ -94,267 +94,9 @@ class ApertusModelTest(CausalLMModelTest, unittest.TestCase): @require_torch_accelerator @require_read_token +@slow class ApertusIntegrationTest(unittest.TestCase): - def setup(self): - cleanup(torch_device, gc_collect=True) - - def tearDown(self): - # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves - # some memory allocated in the cache, which means some object is not being released properly. This causes some - # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU. - # Investigate the root cause. - cleanup(torch_device, gc_collect=True) - - @slow - def test_apertus_8b_hard(self): - """ - An integration test for apertus 8b. It tests against a long output to ensure the subtle numerical differences - from apertus 8b's RoPE can be detected - """ - expected_texts = Expectations( - { - ("rocm", (9, 5)): 'Tell me about the french revolution. The french revolution was a period of radical social and political upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative assembly that had not met since 1614. The Third Estate, which represented the common people, demanded greater representation and eventually broke away to form the National Assembly. This marked the beginning of the end of the absolute monarchy and the rise of the middle class.\n', - ("cuda", None): 'Tell me about the french revolution. The french revolution was a period of radical political and social upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative assembly that had not met since 1614. The Third Estate, which represented the common people, demanded greater representation and eventually broke away to form the National Assembly. The National Assembly adopted the Declaration of the Rights of Man and of the Citizen, which enshr', - } - ) # fmt: skip - EXPECTED_TEXT = expected_texts.get_expectation() - - tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B") - model = ApertusForCausalLM.from_pretrained( - "swiss-ai/Apertus-8B", device_map="auto", torch_dtype=torch.bfloat16 - ) - input_text = ["Tell me about the french revolution."] - model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) - - generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) - generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(generated_text, EXPECTED_TEXT) - - @slow - def test_model_8b_logits_bf16(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - - model = ApertusForCausalLM.from_pretrained( - "swiss-ai/Apertus-8B", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" - ) - - with torch.no_grad(): - out = model(torch.tensor([input_ids]).to(torch_device)) - # Expected mean on dim = -1 - - # fmt: off - expected_means = Expectations( - { - ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), - ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), - ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) - }) - - expected_mean = expected_means.get_expectation() - self.assertTrue( - torch.allclose( - expected_mean.to(torch_device), - out.logits.float().mean(-1), - atol=1e-2, - rtol=1e-2 - ) - ) - - # slicing logits[0, 0, 0:15] - expected_slices = Expectations( - { - ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), - ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), - ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) - }) - # fmt: on - expected_slice = expected_slices.get_expectation() - self.assertTrue( - torch.allclose( - expected_slice.to(torch_device), - out.logits[0, 0, :15].float(), - atol=1e-2, - rtol=1e-2, - ) - ) - - @slow - def test_model_8b_logits(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - - model = ApertusForCausalLM.from_pretrained( - "swiss-ai/Apertus-8B", device_map="auto", torch_dtype=torch.bfloat16 - ) - - with torch.no_grad(): - out = model(torch.tensor([input_ids]).to(torch_device)) - - # fmt: off - # Expected mean on dim = -1 - expected_means = Expectations( - { - ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), - ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), - ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), - }) - - expected_mean = expected_means.get_expectation() - self.assertTrue( - torch.allclose( - expected_mean.to(torch_device), - out.logits.float().mean(-1), - atol=1e-2, - rtol=1e-2 - ) - ) - - # slicing logits[0, 0, 0:15] - expected_slices = Expectations( - { - ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), - ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), - ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) - }) - # fmt: on - - expected_slice = expected_slices.get_expectation() - self.assertTrue( - torch.allclose( - expected_slice.to(torch_device), - out.logits[0, 0, :15].float(), - atol=1e-2, - rtol=1e-2, - ) - ) - - # TODO: check why we have the following strange situation. - # without running in subprocess, this test causes subsequent tests failing with `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!` - @run_test_using_subprocess - @slow - def test_model_8b_dola_generation(self): - # ground truth text generated with dola_layers="low", repetition_penalty=1.2 - EXPECTED_TEXT_COMPLETION = ( - "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " - "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " - "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " - "understanding of space and time." - ) - prompt = "Simply put, the theory of relativity states that " - tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B") - model = ApertusForCausalLM.from_pretrained( - "swiss-ai/Apertus-8B", device_map="sequential", torch_dtype=torch.bfloat16 - ) - model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - - # greedy generation outputs - generated_ids = model.generate( - **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" - ) - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - - @slow - @require_torch_accelerator - def test_compile_static_cache(self): - # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 - # work as intended. See https://github.com/pytorch/pytorch/issues/121943 - if version.parse(torch.__version__) < version.parse("2.3.0"): - self.skipTest(reason="This test requires torch >= 2.3 to run.") - - NUM_TOKENS_TO_GENERATE = 40 - # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test - # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. - EXPECTED_TEXT_COMPLETION = [ - "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " - "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " - "theory of relativ", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " - "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ] - - prompts = [ - "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", - ] - tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B", pad_token="", padding_side="right") - model = ApertusForCausalLM.from_pretrained( - "swiss-ai/Apertus-8B", device_map=torch_device, torch_dtype=torch.bfloat16 - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - - # Dynamic Cache - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) - dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) - - # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" - ) - static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) - - @slow - def test_export_static_cache(self): - if version.parse(torch.__version__) < version.parse("2.4.0"): - self.skipTest(reason="This test requires torch >= 2.4 to run.") - - from transformers.integrations.executorch import ( - TorchExportableModuleWithStaticCache, - ) - - apertus_models = { - "swiss-ai/Apertus-8B": [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all " - "observers, regardless of their location, and 2) the laws of physics are the same for all observers" - ], - } - - for apertus_model_ckp, EXPECTED_TEXT_COMPLETION in apertus_models.items(): - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(apertus_model_ckp, pad_token="", padding_side="right") - max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ - "input_ids" - ].shape[-1] - - # Load model - device = "cpu" - dtype = torch.bfloat16 - cache_implementation = "static" - attn_implementation = "sdpa" - batch_size = 1 - model = ApertusForCausalLM.from_pretrained( - apertus_model_ckp, - device_map=device, - torch_dtype=dtype, - attn_implementation=attn_implementation, - generation_config=GenerationConfig( - use_cache=True, - cache_implementation=cache_implementation, - max_length=max_generation_length, - cache_config={ - "batch_size": batch_size, - "max_cache_len": max_generation_length, - "device": device, - }, - ), - ) - - prompts = ["Simply put, the theory of relativity states that "] - prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - prompt_token_ids = prompt_tokens["input_ids"] - max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] - - # Static Cache + export - from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() - ep_generated_ids = TorchExportableModuleWithStaticCache.generate( - exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens - ) - ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + pass @slow From cd029ab3bceb0381089944b56a29ea71cf0d720a Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:10:32 +0200 Subject: [PATCH 29/44] `.utils.import_utils.is_torchdynamo_compiling` --- src/transformers/activations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 6d09bbc46a03..2d14d0f84311 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -19,6 +19,7 @@ from torch import Tensor, nn from .utils import logging +from .utils.import_utils import is_torchdynamo_compiling logger = logging.get_logger(__name__) @@ -268,7 +269,7 @@ def _xielu_cuda(self, x: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor: if self._xielu_cuda_obj is not None and input.is_cuda: - if not torch._dynamo.is_compiling(): + if not is_torchdynamo_compiling(): return self._xielu_cuda_fn(input) else: logger.warning("torch._dynamo is compiling, using Python version of xIELU.") From 250b43af28445336a4a6d656275db4272a4c68e1 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:12:36 +0200 Subject: [PATCH 30/44] `Apertus` class ordering --- src/transformers/models/apertus/modular_apertus.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 1e508ad1cff2..536922fe946f 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -361,6 +361,10 @@ class ApertusPreTrainedModel(LlamaPreTrainedModel): pass +class ApertusModel(LlamaModel): + pass + + class ApertusForCausalLM(LlamaForCausalLM): def forward(self, **super_kwargs): r""" @@ -394,14 +398,10 @@ class ApertusForTokenClassification(LlamaForTokenClassification): pass -class ApertusModel(LlamaModel): - pass - - __all__ = [ "ApertusConfig", + "ApertusModel", "ApertusForCausalLM", "ApertusForTokenClassification", - "ApertusModel", "ApertusPreTrainedModel", ] From c4b6d76cdbb92aa4c9f524b637489b0be813e0a7 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:31:37 +0200 Subject: [PATCH 31/44] `past_key_value{->s}`, `make fix-copies` --- .../models/apertus/configuration_apertus.py | 5 +- .../models/apertus/modeling_apertus.py | 183 +++++++++--------- .../models/apertus/modular_apertus.py | 10 +- 3 files changed, 98 insertions(+), 100 deletions(-) diff --git a/src/transformers/models/apertus/configuration_apertus.py b/src/transformers/models/apertus/configuration_apertus.py index 735bfc6140ee..40f14a26137c 100644 --- a/src/transformers/models/apertus/configuration_apertus.py +++ b/src/transformers/models/apertus/configuration_apertus.py @@ -5,11 +5,8 @@ # modular_apertus.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2025 EleutherAI, the HuggingFace Inc. team, and the Swiss AI Initiative. All rights reserved. +# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved. # -# This code is based on HuggingFace's LLaMA implementation in this library. -# It has been modified from its original forms to accommodate the architectural -# differences made by the Swiss AI Initiative that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index f15b0e84a8c8..881884150e0b 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -5,11 +5,8 @@ # modular_apertus.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2025 EleutherAI, the HuggingFace Inc. team, and the Swiss AI Initiative. All rights reserved. +# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved. # -# This code is based on HuggingFace's LLaMA implementation in this library. -# It has been modified from its original forms to accommodate the architectural -# differences made by the Swiss AI Initiative that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -38,6 +35,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_apertus import ApertusConfig @@ -85,6 +83,8 @@ def extra_repr(self): class ApertusRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: ApertusConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" @@ -222,12 +222,13 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): self.q_norm = nn.Identity() self.k_norm = nn.Identity() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -243,9 +244,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -280,12 +281,13 @@ def __init__(self, config: ApertusConfig, layer_idx: int): self.post_norm = config.post_norm + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, @@ -298,7 +300,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -326,8 +328,7 @@ class ApertusPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ApertusDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True @@ -339,6 +340,85 @@ class ApertusPreTrainedModel(PreTrainedModel): } +@auto_docstring +class ApertusModel(ApertusPreTrainedModel): + def __init__(self, config: ApertusConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = ApertusRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + @auto_docstring class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -430,83 +510,4 @@ class ApertusForTokenClassification(GenericForTokenClassification, ApertusPreTra pass -@auto_docstring -class ApertusModel(ApertusPreTrainedModel): - def __init__(self, config: ApertusConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = ApertusRotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - @check_model_inputs - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithPast: - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: torch.Tensor = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = create_causal_mask( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=position_ids, - ) - - hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - - hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - ) - - -__all__ = ["ApertusForCausalLM", "ApertusForTokenClassification", "ApertusModel", "ApertusPreTrainedModel"] +__all__ = ["ApertusModel", "ApertusForCausalLM", "ApertusForTokenClassification", "ApertusPreTrainedModel"] diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 536922fe946f..4cc7f5b70e68 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -267,7 +267,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -283,9 +283,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -323,7 +323,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, @@ -336,7 +336,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, From 9865539763094ad485b6cada0501d27349711b19 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Wed, 20 Aug 2025 01:50:33 +0200 Subject: [PATCH 32/44] ci fix --- tests/models/apertus/test_modeling_apertus.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/models/apertus/test_modeling_apertus.py b/tests/models/apertus/test_modeling_apertus.py index 9b7848047199..080051f19676 100644 --- a/tests/models/apertus/test_modeling_apertus.py +++ b/tests/models/apertus/test_modeling_apertus.py @@ -20,17 +20,12 @@ import unittest -from packaging import version - from transformers import AutoTokenizer, StaticCache, is_torch_available -from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( - Expectations, cleanup, require_read_token, require_torch, require_torch_accelerator, - run_test_using_subprocess, slow, torch_device, ) From a7abf5ebe6d21303990fc121370e1e145c7df865 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 24 Aug 2025 19:53:10 +0200 Subject: [PATCH 33/44] Remove unused configuration parameters --- src/transformers/activations.py | 2 +- .../models/apertus/configuration_apertus.py | 8 ---- .../models/apertus/modeling_apertus.py | 29 +++------------ .../models/apertus/modular_apertus.py | 37 +++---------------- 4 files changed, 11 insertions(+), 65 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 2d14d0f84311..1bfe8e488606 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -272,7 +272,7 @@ def forward(self, input: Tensor) -> Tensor: if not is_torchdynamo_compiling(): return self._xielu_cuda_fn(input) else: - logger.warning("torch._dynamo is compiling, using Python version of xIELU.") + logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.") return self._xielu_python(input) diff --git a/src/transformers/models/apertus/configuration_apertus.py b/src/transformers/models/apertus/configuration_apertus.py index 40f14a26137c..180ad756dc88 100644 --- a/src/transformers/models/apertus/configuration_apertus.py +++ b/src/transformers/models/apertus/configuration_apertus.py @@ -117,10 +117,6 @@ class ApertusConfig(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - qk_norm (`bool`, *optional*, defaults to `True`): - Whether to use a normalization on the query and key states during self-attention. - post_norm (`bool`, *optional*, defaults to `False`): - Whether to use a normalization on the output of the attention layer. ```python >>> from transformers import ApertusModel, ApertusConfig @@ -179,8 +175,6 @@ def __init__( }, attention_bias=False, attention_dropout=0.0, - qk_norm=True, - post_norm=False, **kwargs, ): super().__init__( @@ -215,8 +209,6 @@ def __init__( if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) - self.qk_norm = qk_norm - self.post_norm = post_norm __all__ = ["ApertusConfig"] diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 881884150e0b..c47fcdac3b37 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -49,16 +49,9 @@ def __init__(self, config): self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - if config.hidden_act != "xielu": - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) def forward(self, x): - if self.config.hidden_act == "xielu": - # in case of xielu, no gated MLP - down_proj = self.down_proj(self.act_fn(self.up_proj(x))) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj + return self.down_proj(self.act_fn(self.up_proj(x))) @use_kernel_forward_from_hub("RMSNorm") @@ -215,12 +208,8 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - if self.config.qk_norm: - self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) - self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) - else: - self.q_norm = nn.Identity() - self.k_norm = nn.Identity() + self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) + self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -279,8 +268,6 @@ def __init__(self, config: ApertusConfig, layer_idx: int): self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_norm = config.post_norm - @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, @@ -294,8 +281,7 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states - if not self.post_norm: - hidden_states = self.attention_layernorm(hidden_states) + hidden_states = self.attention_layernorm(hidden_states) hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, @@ -306,17 +292,12 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) - if self.post_norm: - hidden_states = self.attention_layernorm(hidden_states) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - if not self.post_norm: - hidden_states = self.feedforward_layernorm(hidden_states) + hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - if self.post_norm: - hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states return hidden_states diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 4cc7f5b70e68..58267f9f1ffe 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -134,10 +134,6 @@ class ApertusConfig(LlamaConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - qk_norm (`bool`, *optional*, defaults to `True`): - Whether to use a normalization on the query and key states during self-attention. - post_norm (`bool`, *optional*, defaults to `False`): - Whether to use a normalization on the output of the attention layer. ```python >>> from transformers import ApertusModel, ApertusConfig @@ -190,8 +186,6 @@ def __init__( }, attention_bias=False, attention_dropout=0.0, - qk_norm=True, - post_norm=False, **kwargs, ): super().__init__( @@ -216,8 +210,6 @@ def __init__( attention_dropout=attention_dropout, **kwargs, ) - self.qk_norm = qk_norm - self.post_norm = post_norm del self.pretraining_tp del self.mlp_bias del self.head_dim @@ -232,16 +224,9 @@ def __init__(self, config): self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - if config.hidden_act != "xielu": - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) def forward(self, x): - if self.config.hidden_act == "xielu": - # in case of xielu, no gated MLP - down_proj = self.down_proj(self.act_fn(self.up_proj(x))) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj + return self.down_proj(self.act_fn(self.up_proj(x))) class ApertusRMSNorm(LlamaRMSNorm): @@ -255,12 +240,8 @@ class ApertusRotaryEmbedding(LlamaRotaryEmbedding): class ApertusAttention(LlamaAttention): def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - if self.config.qk_norm: - self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) - self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) - else: - self.q_norm = nn.Identity() - self.k_norm = nn.Identity() + self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) + self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) def forward( self, @@ -313,8 +294,6 @@ def __init__(self, config: ApertusConfig, layer_idx: int): self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_norm = config.post_norm - del self.input_layernorm del self.post_attention_layernorm @@ -330,8 +309,7 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states - if not self.post_norm: - hidden_states = self.attention_layernorm(hidden_states) + hidden_states = self.attention_layernorm(hidden_states) hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, @@ -342,17 +320,12 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) - if self.post_norm: - hidden_states = self.attention_layernorm(hidden_states) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - if not self.post_norm: - hidden_states = self.feedforward_layernorm(hidden_states) + hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - if self.post_norm: - hidden_states = self.feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states return hidden_states From 273da51dfea2124e9d6fcd34af4ec5deaeb954a1 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Sun, 24 Aug 2025 19:55:01 +0200 Subject: [PATCH 34/44] `{beta,eps}` saved in checkpoint --- src/transformers/activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 1bfe8e488606..3d726e335664 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -208,8 +208,8 @@ def __init__( self.alpha_n = nn.Parameter( torch.log(torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1).unsqueeze(0) ) - self.register_buffer("beta", torch.tensor(beta, dtype=dtype), persistent=False) - self.register_buffer("eps", torch.tensor(eps, dtype=dtype), persistent=False) + self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) + self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) self.with_vector_loads = with_vector_loads self._xielu_cuda_obj = None From 29da453b82d972c935315cd52417ecc7c6b87a55 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Wed, 27 Aug 2025 20:32:26 +0200 Subject: [PATCH 35/44] `{beta,eps}` Temporarily on CPU --- src/transformers/activations.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 3d726e335664..813cd2c3c811 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -211,6 +211,9 @@ def __init__( self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) self.with_vector_loads = with_vector_loads + # Temporary until xIELU CUDA fully implemented + self._beta_scalar = float(self.beta.detach().cpu().float().item()) + self._eps_scalar = float(self.eps.detach().cpu().float().item()) self._xielu_cuda_obj = None try: @@ -261,8 +264,9 @@ def _xielu_cuda(self, x: Tensor) -> Tensor: x, self.alpha_p, self.alpha_n, - self.beta.item(), - self.eps.item(), + # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item() + self._beta_scalar, + self._eps_scalar, self.with_vector_loads, ) return result.view(original_shape) From 792b7de755722d234bce5148499d894696185351 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Wed, 27 Aug 2025 21:31:08 +0200 Subject: [PATCH 36/44] Suggestions Co-authored-by: Cyril Vallez --- src/transformers/models/apertus/modular_apertus.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 58267f9f1ffe..6eedca9fe514 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -36,6 +36,7 @@ apply_rotary_pos_emb, eager_attention_forward, ) +from ..nemotron.modeling_nemotron import NemotronMLP logger = logging.get_logger(__name__) @@ -215,18 +216,11 @@ def __init__( del self.head_dim -class ApertusMLP(nn.Module): +class ApertusMLP(NemotronMLP): def __init__(self, config): super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.up_proj(x))) class ApertusRMSNorm(LlamaRMSNorm): @@ -364,8 +358,6 @@ def forward(self, **super_kwargs): ```""" return super().forward(**super_kwargs) - pass - class ApertusForTokenClassification(LlamaForTokenClassification): pass From a5889da2f56000223c8e595e699f4d9ce8b75c07 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Wed, 27 Aug 2025 21:46:33 +0200 Subject: [PATCH 37/44] ci fix --- src/transformers/models/apertus/modeling_apertus.py | 6 ------ src/transformers/models/apertus/modular_apertus.py | 1 - 2 files changed, 7 deletions(-) diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index c47fcdac3b37..8ca279ded178 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -415,12 +415,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 6eedca9fe514..e8d1e3f815c0 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -18,7 +18,6 @@ import torch from torch import nn -from ...activations import ACT2FN from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack From e19d54360e92f59846985ba4cd8274bb306ae205 Mon Sep 17 00:00:00 2001 From: Dhia Garbaya <84809366+dhia680@users.noreply.github.com> Date: Wed, 27 Aug 2025 22:10:37 +0100 Subject: [PATCH 38/44] remove fx_compatible (deprecated) --- tests/models/apertus/test_modeling_apertus.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/apertus/test_modeling_apertus.py b/tests/models/apertus/test_modeling_apertus.py index 080051f19676..4369f47d2967 100644 --- a/tests/models/apertus/test_modeling_apertus.py +++ b/tests/models/apertus/test_modeling_apertus.py @@ -75,7 +75,6 @@ class ApertusModelTest(CausalLMModelTest, unittest.TestCase): ) test_headmasking = False test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez model_tester_class = ApertusModelTester rotary_embedding_layer = ApertusRotaryEmbedding # Enables RoPE tests if set From 69c46ed59a27ca336eb269adb11a892e4ebd386d Mon Sep 17 00:00:00 2001 From: Dhia Garbaya <84809366+dhia680@users.noreply.github.com> Date: Wed, 27 Aug 2025 23:21:07 +0100 Subject: [PATCH 39/44] remove `rotary_embedding_layer` As the tests are written for a config without default scaling (which is not the case in Apertus) - besides, rope scaling is tested in other models so it's all safe. --- tests/models/apertus/test_modeling_apertus.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/apertus/test_modeling_apertus.py b/tests/models/apertus/test_modeling_apertus.py index 4369f47d2967..424a86602bdd 100644 --- a/tests/models/apertus/test_modeling_apertus.py +++ b/tests/models/apertus/test_modeling_apertus.py @@ -42,7 +42,6 @@ ApertusForTokenClassification, ApertusModel, ) - from transformers.models.apertus.modeling_apertus import ApertusRotaryEmbedding class ApertusModelTester(CausalLMModelTester): @@ -76,7 +75,6 @@ class ApertusModelTest(CausalLMModelTest, unittest.TestCase): test_headmasking = False test_pruning = False model_tester_class = ApertusModelTester - rotary_embedding_layer = ApertusRotaryEmbedding # Enables RoPE tests if set # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer From 864c4ddc46501462127f822678f086406ef1e6fd Mon Sep 17 00:00:00 2001 From: Dhia Garbaya <84809366+dhia680@users.noreply.github.com> Date: Wed, 27 Aug 2025 23:26:09 +0100 Subject: [PATCH 40/44] fully removing `Mask4DTestHard` class Not needed (for now) --- tests/models/apertus/test_modeling_apertus.py | 250 ------------------ 1 file changed, 250 deletions(-) diff --git a/tests/models/apertus/test_modeling_apertus.py b/tests/models/apertus/test_modeling_apertus.py index 424a86602bdd..e033ad08c19f 100644 --- a/tests/models/apertus/test_modeling_apertus.py +++ b/tests/models/apertus/test_modeling_apertus.py @@ -89,253 +89,3 @@ class ApertusModelTest(CausalLMModelTest, unittest.TestCase): @slow class ApertusIntegrationTest(unittest.TestCase): pass - - -@slow -@require_torch_accelerator -class Mask4DTestHard(unittest.TestCase): - def tearDown(self): - cleanup(torch_device, gc_collect=True) - - def setUp(self): - cleanup(torch_device, gc_collect=True) - model_name = "swiss-ai/Apertus-8B" - self.model_dtype = torch.bfloat16 - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = ApertusForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) - - def get_test_data(self): - template = "my favorite {}" - items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item - - batch_separate = [template.format(x) for x in items] # 3 separate lines - batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated - - input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) - input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) - - mask_shared_prefix = torch.tensor( - [ - [ - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], - ] - ] - ], - device=torch_device, - ) - - position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) - - # building custom positions ids based on custom mask - position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) - # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) - - # inverting the mask - min_dtype = torch.finfo(self.model_dtype).min - mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype - - return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix - - def test_stacked_causal_mask(self): - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # 2 forward runs with custom 4D masks - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) - past_key_values_a = outs_1a["past_key_values"] - - # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - outs_1b = self.model.forward( - input_1b, - attention_mask=mask_1b, - position_ids=position_ids_1b, - past_key_values=past_key_values_a, - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) - - def test_stacked_causal_mask_static_cache(self): - """same as above but with StaticCache""" - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # upgrade the model with StaticCache - max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) - - padded_attention_mask = torch.nn.functional.pad( - input=mask_shared_prefix, - pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), - mode="constant", - value=torch.finfo(self.model_dtype).min, - ) - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, - attention_mask=padded_attention_mask, - position_ids=position_ids_shared_prefix, - cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), - past_key_values=past_key_values, - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask_static_cache(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # upgrade the model with StaticCache - max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) - - # forward run for the first part of input - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - padded_mask_1a = torch.nn.functional.pad( - input=mask_1a, - pad=(0, max_cache_len - mask_1a.shape[-1]), - mode="constant", - value=torch.finfo(self.model_dtype).min, - ) - - _ = self.model.forward( - input_1a, - attention_mask=padded_mask_1a, - position_ids=position_ids_1a, - cache_position=torch.arange(part_a, device=torch_device), - past_key_values=past_key_values, - ) - - # forward run for the second part of input - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - - padded_mask_1b = torch.nn.functional.pad( - input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 - ) - - outs_1b = self.model.forward( - input_1b, - attention_mask=padded_mask_1b, - position_ids=position_ids_1b, - cache_position=torch.arange( - part_a, - input_ids_shared_prefix.shape[-1], - device=torch_device, - ), - past_key_values=past_key_values, - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) From e7d03ad1d01d29cae6c6fe7104039026aa3372dd Mon Sep 17 00:00:00 2001 From: Dhia Garbaya <84809366+dhia680@users.noreply.github.com> Date: Wed, 27 Aug 2025 23:55:00 +0100 Subject: [PATCH 41/44] switch to `dtype` instead of `torch_dtype` Following this: https://github.com/huggingface/transformers/pull/39782 --- docs/source/en/model_doc/apertus.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/apertus.md b/docs/source/en/model_doc/apertus.md index edf91a6e0c3b..798843f24ddd 100644 --- a/docs/source/en/model_doc/apertus.md +++ b/docs/source/en/model_doc/apertus.md @@ -42,7 +42,7 @@ from transformers import pipeline pipeline = pipeline( task="text-generation", model="swiss-ai/Apertus-8B", - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, device=0 ) pipeline("Plants create energy through a process known as") @@ -60,7 +60,7 @@ tokenizer = AutoTokenizer.from_pretrained( ) model = AutoModelForCausalLM.from_pretrained( "swiss-ai/Apertus-8B", - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, device_map="auto", attn_implementation="sdpa" ) From c3944468b5868bd5b704c4b6f0870376dc2ce356 Mon Sep 17 00:00:00 2001 From: Dhia Garbaya <84809366+dhia680@users.noreply.github.com> Date: Thu, 28 Aug 2025 00:35:09 +0100 Subject: [PATCH 42/44] remove unused imports --- tests/models/apertus/test_modeling_apertus.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/models/apertus/test_modeling_apertus.py b/tests/models/apertus/test_modeling_apertus.py index e033ad08c19f..77769c430e08 100644 --- a/tests/models/apertus/test_modeling_apertus.py +++ b/tests/models/apertus/test_modeling_apertus.py @@ -20,22 +20,18 @@ import unittest -from transformers import AutoTokenizer, StaticCache, is_torch_available +from transformers import is_torch_available from transformers.testing_utils import ( - cleanup, require_read_token, require_torch, require_torch_accelerator, slow, - torch_device, ) from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): - import torch - from transformers import ( ApertusConfig, ApertusForCausalLM, From 68c6defc470b82dac8eedbf3fe17c1609b89f0e7 Mon Sep 17 00:00:00 2001 From: Dhia Garbaya <84809366+dhia680@users.noreply.github.com> Date: Thu, 28 Aug 2025 00:51:47 +0100 Subject: [PATCH 43/44] remove `cache_implementation="static"` --- docs/source/en/model_doc/apertus.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/apertus.md b/docs/source/en/model_doc/apertus.md index 798843f24ddd..670cf5c8a77b 100644 --- a/docs/source/en/model_doc/apertus.md +++ b/docs/source/en/model_doc/apertus.md @@ -66,7 +66,7 @@ model = AutoModelForCausalLM.from_pretrained( ) input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda") -output = model.generate(**input_ids, cache_implementation="static") +output = model.generate(**input_ids) print(tokenizer.decode(output[0], skip_special_tokens=True)) ``` From 227f026d35b23d54087bdfddc56e28dad281719c Mon Sep 17 00:00:00 2001 From: Dhia Garbaya <84809366+dhia680@users.noreply.github.com> Date: Thu, 28 Aug 2025 00:59:53 +0100 Subject: [PATCH 44/44] +Apertus to `docs/source/en/_toctree.yml` for the doc builder --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c62d3b763165..a1fe9212f7a7 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -373,6 +373,8 @@ - sections: - local: model_doc/albert title: ALBERT + - local: model_doc/apertus + title: Apertus - local: model_doc/arcee title: Arcee - local: model_doc/bamba