Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
77810b8
init swissai model
haeggee Feb 26, 2025
bcdaf70
AutoModelForCausalLM
haeggee Feb 26, 2025
53a3755
AutoModelForCausalLM mapping
haeggee Feb 26, 2025
7c648e7
qk norm and post ln optional
haeggee Feb 28, 2025
d9a923d
fix wrong shape of qk norm: megatron uses head_dim
haeggee Mar 2, 2025
f35ee01
automodel fixes
haeggee Mar 2, 2025
e6921f7
minor fix in forward
haeggee Mar 2, 2025
46ca1ae
fix rope validation to accept llama3 scaling
dhia680 May 30, 2025
994b1d7
`SwissAIForTokenClassification` support
EduardDurech May 23, 2025
8b38b5a
Align `SwissAI` to v4.52.4
EduardDurech Jun 8, 2025
0ffc9b9
Align `SwissAI` to v4.53.1
EduardDurech Jul 12, 2025
7793c87
Init CUDA xIELU
EduardDurech Jul 12, 2025
590957b
`SwissAI*`->`Apertus*`
EduardDurech Jul 12, 2025
353c6c0
ci fix
EduardDurech Jul 12, 2025
833f5fe
check_docstring ignore ApertusConfig
EduardDurech Jul 13, 2025
f0ec65c
Licensing and placeholder tests
EduardDurech Jul 13, 2025
1f4e715
Placeholder doc
EduardDurech Jul 13, 2025
cf12582
XIELU syntax
EduardDurech Jul 13, 2025
331fc0d
`_xielu_python` optimization
EduardDurech Jul 13, 2025
2728d3c
Fix xIELU
EduardDurech Jul 13, 2025
d0d42cd
[tmp] `{beta,eps}` persistent=False
EduardDurech Aug 14, 2025
543b343
Modular `Apertus`
EduardDurech Aug 14, 2025
4d436d0
CUDA xIELU logging
EduardDurech Aug 14, 2025
35d6bb3
Merge upstream/main into model/apertus
EduardDurech Aug 14, 2025
e5ec231
ci fix
EduardDurech Aug 14, 2025
1de44fd
ci fix
EduardDurech Aug 14, 2025
9c0cb61
ci fix
EduardDurech Aug 14, 2025
8f1c081
Update license
EduardDurech Aug 19, 2025
dad00ca
Update tests/models/apertus/test_modeling_apertus.py
EduardDurech Aug 19, 2025
cd029ab
`.utils.import_utils.is_torchdynamo_compiling`
EduardDurech Aug 19, 2025
250b43a
`Apertus` class ordering
EduardDurech Aug 19, 2025
c4b6d76
`past_key_value{->s}`, `make fix-copies`
EduardDurech Aug 19, 2025
9865539
ci fix
EduardDurech Aug 19, 2025
a7abf5e
Remove unused configuration parameters
EduardDurech Aug 24, 2025
273da51
`{beta,eps}` saved in checkpoint
EduardDurech Aug 24, 2025
29da453
`{beta,eps}` Temporarily on CPU
EduardDurech Aug 27, 2025
792b7de
Suggestions
EduardDurech Aug 27, 2025
c2b3de5
Merge branch 'main' into model/apertus
EduardDurech Aug 27, 2025
a5889da
ci fix
EduardDurech Aug 27, 2025
e19d543
remove fx_compatible (deprecated)
dhia680 Aug 27, 2025
69c46ed
remove `rotary_embedding_layer`
dhia680 Aug 27, 2025
864c4dd
fully removing `Mask4DTestHard` class
dhia680 Aug 27, 2025
e7d03ad
switch to `dtype` instead of `torch_dtype`
dhia680 Aug 27, 2025
c394446
remove unused imports
dhia680 Aug 27, 2025
68c6def
remove `cache_implementation="static"`
dhia680 Aug 27, 2025
227f026
+Apertus to `docs/source/en/_toctree.yml` for the doc builder
dhia680 Aug 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@
- sections:
- local: model_doc/albert
title: ALBERT
- local: model_doc/apertus
title: Apertus
- local: model_doc/arcee
title: Arcee
- local: model_doc/bamba
Expand Down
100 changes: 100 additions & 0 deletions docs/source/en/model_doc/apertus.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
<!--Copyright 2025 The HuggingFace Team and the Swiss AI Initiative. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
</div>
</div>

# Apertus

[Apertus](https://www.swiss-ai.org) is a family of large language models from the Swiss AI Initiative.

> [!TIP]
> Coming soon

The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`], and from the command line.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

pipeline = pipeline(
task="text-generation",
model="swiss-ai/Apertus-8B",
dtype=torch.bfloat16,
device=0
)
pipeline("Plants create energy through a process known as")
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
"swiss-ai/Apertus-8B",
)
model = AutoModelForCausalLM.from_pretrained(
"swiss-ai/Apertus-8B",
dtype=torch.bfloat16,
device_map="auto",
attn_implementation="sdpa"
)
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")

output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="transformers CLI">

```bash
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model swiss-ai/Apertus-8B --device 0
```

</hfoption>
</hfoptions>

## ApertusConfig

[[autodoc]] ApertusConfig

## ApertusModel

[[autodoc]] ApertusModel
- forward

## ApertusForCausalLM

[[autodoc]] ApertusForCausalLM
- forward

## ApertusForTokenClassification

[[autodoc]] ApertusForTokenClassification
- forward
96 changes: 96 additions & 0 deletions src/transformers/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch import Tensor, nn

from .utils import logging
from .utils.import_utils import is_torchdynamo_compiling


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -185,6 +186,100 @@ def __getitem__(self, key):
return cls(**kwargs)


class XIELUActivation(nn.Module):
"""
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010

If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA
Otherwise, we emit a single warning and use xIELU Python
"""

def __init__(
self,
alpha_p_init=0.8,
alpha_n_init=0.8,
beta=0.5,
eps=-1e-6,
dtype=torch.bfloat16,
with_vector_loads=False,
):
super().__init__()
self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(0))
self.alpha_n = nn.Parameter(
torch.log(torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1).unsqueeze(0)
)
self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
self.with_vector_loads = with_vector_loads
# Temporary until xIELU CUDA fully implemented
self._beta_scalar = float(self.beta.detach().cpu().float().item())
self._eps_scalar = float(self.eps.detach().cpu().float().item())

self._xielu_cuda_obj = None
try:
import xielu.ops # noqa: F401

self._xielu_cuda_obj = torch.classes.xielu.XIELU()
msg = "Using experimental xIELU CUDA."
try:
from torch._dynamo import allow_in_graph

self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
msg += " Enabled torch._dynamo for xIELU CUDA."
except Exception as err:
msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance."
self._xielu_cuda_fn = self._xielu_cuda
logger.warning_once(msg)
except Exception as err:
logger.warning_once(
"CUDA-fused xIELU not available (%s) – falling back to a Python version.\n"
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
str(err),
)

def _xielu_python(self, x: Tensor) -> Tensor:
alpha_p = nn.functional.softplus(self.alpha_p)
alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
return torch.where(
x > 0,
alpha_p * x * x + self.beta * x,
(torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
)

def _xielu_cuda(self, x: Tensor) -> Tensor:
"""Firewall function to prevent torch.compile from seeing .item() calls"""
original_shape = x.shape
# CUDA kernel expects 3D tensors, reshape if needed
while x.dim() < 3:
x = x.unsqueeze(0)
if x.dim() > 3:
x = x.view(-1, 1, x.size(-1))
if original_shape != x.shape:
logger.warning_once(
"Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
original_shape,
x.shape,
)
result = self._xielu_cuda_obj.forward(
x,
self.alpha_p,
self.alpha_n,
# Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
self._beta_scalar,
self._eps_scalar,
self.with_vector_loads,
)
return result.view(original_shape)

def forward(self, input: Tensor) -> Tensor:
if self._xielu_cuda_obj is not None and input.is_cuda:
if not is_torchdynamo_compiling():
return self._xielu_cuda_fn(input)
else:
logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.")
return self._xielu_python(input)


ACT2CLS = {
"gelu": GELUActivation,
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
Expand All @@ -206,6 +301,7 @@ def __getitem__(self, key):
"swish": nn.SiLU,
"tanh": nn.Tanh,
"prelu": nn.PReLU,
"xielu": XIELUActivation,
}
ACT2FN = ClassInstantier(ACT2CLS)

Expand Down
32 changes: 32 additions & 0 deletions src/transformers/models/apertus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
#
# This code is based on HuggingFace's LLaMA implementation in this library.
# It has been modified from its original forms to accommodate the architectural
# differences made by the Swiss AI Initiative that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_apertus import *
from .modeling_apertus import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Loading