
Commit a3e809c

Apertus and XIELU
Co-authored-by: AllenHaoHuang <[email protected]>
Signed-off-by: EduardDurech <[email protected]>
1 parent 321938e

File tree

5 files changed: 690 additions & 1 deletion


tests/models/language/generation/test_common.py

Lines changed: 2 additions & 1 deletion
@@ -92,7 +92,8 @@
     pytest.param(
         "allenai/OLMoE-1B-7B-0924-Instruct",
         marks=[pytest.mark.cpu_model],
-    )
+    ),
+    pytest.param("swiss-ai/Apertus-8B"),  # apertus
 ])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -137,6 +137,8 @@ def check_available_online(
 # yapf: disable
 _TEXT_GENERATION_EXAMPLE_MODELS = {
     # [Decoder-only]
+    "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B",
+                                          trust_remote_code=True),
     "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
                                    trust_remote_code=True),
     "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",

vllm/model_executor/layers/activation.py

Lines changed: 109 additions & 0 deletions
@@ -10,11 +10,14 @@
 
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
+from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import LazyDict
 
+logger = init_logger(__name__)
+
 
 @CustomOp.register("fatrelu_and_mul")
 class FatreluAndMul(CustomOp):
@@ -363,6 +366,110 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return self.forward_native(x)
 
 
+@CustomOp.register("xielu")
+class XIELU(CustomOp):
+    """
+    Applies the xIELU activation function introduced in
+    https://arxiv.org/abs/2411.13010.
+    If the nickjbrowning/XIELU package is installed, the xIELU CUDA
+    implementation is used; otherwise, a single warning is emitted and the
+    Python implementation is used as a fallback.
+    """
+
+    def __init__(
+        self,
+        alpha_p_init: float = 0.8,
+        alpha_n_init: float = 0.8,
+        beta: float = 0.5,
+        eps: float = -1e-6,
+        dtype: torch.dtype = torch.bfloat16,
+        with_vector_loads: bool = False,
+    ):
+        super().__init__()
+        # Parameters are stored in inverse-softplus form so that
+        # softplus(alpha_p) == alpha_p_init at initialization.
+        self.alpha_p = nn.Parameter(
+            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) -
+                      1).unsqueeze(0))
+        self.alpha_n = nn.Parameter(
+            torch.log(
+                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) -
+                1).unsqueeze(0))
+        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
+        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
+        self.with_vector_loads = with_vector_loads
+        # Temporary until xIELU CUDA is fully implemented.
+        self._beta_scalar = float(self.beta.detach().cpu().float().item())
+        self._eps_scalar = float(self.eps.detach().cpu().float().item())
+
+        self._xielu_cuda_obj = None
+        try:
+            import xielu.ops  # noqa: F401
+
+            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
+            msg = "Using experimental xIELU CUDA."
+            try:
+                from torch._dynamo import allow_in_graph
+
+                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
+                msg += " Enabled torch._dynamo for xIELU CUDA."
+            except Exception as err:
+                msg += (f" Could not enable torch._dynamo for xIELU ({err}) -"
+                        " this may result in slower performance.")
+                self._xielu_cuda_fn = self._xielu_cuda
+            logger.warning_once(msg)
+        except Exception as err:
+            logger.warning_once(
+                "CUDA-fused xIELU not available (%s) -"
+                " falling back to the Python version.\n"
+                "For CUDA xIELU (experimental), "
+                "`pip install git+https://github.com/nickjbrowning/XIELU`",
+                str(err),
+            )
+
+    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
+        alpha_p = nn.functional.softplus(self.alpha_p)
+        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
+        return torch.where(
+            x > 0,
+            alpha_p * x * x + self.beta * x,
+            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n +
+            self.beta * x,
+        )
+
+    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        """Firewall function to prevent torch.compile from seeing .item()."""
+        original_shape = x.shape
+        # The CUDA kernel expects 3D tensors; reshape if needed.
+        while x.dim() < 3:
+            x = x.unsqueeze(0)
+        if x.dim() > 3:
+            x = x.view(-1, 1, x.size(-1))
+        if original_shape != x.shape:
+            logger.warning_once(
+                "xIELU expects a 3-dimensional input tensor"
+                " but got shape %s; reshaping to %s.",
+                original_shape,
+                x.shape,
+            )
+        result = self._xielu_cuda_obj.forward(
+            x,
+            self.alpha_p,
+            self.alpha_n,
+            # Temporary until xIELU CUDA is fully implemented ->
+            # self.{beta,eps}.item()
+            self._beta_scalar,
+            self._eps_scalar,
+            self.with_vector_loads,
+        )
+        return result.view(original_shape)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._xielu_cuda_obj is not None and input.is_cuda:
+            if not torch._dynamo.is_compiling():
+                return self._xielu_cuda_fn(input)
+            else:
+                logger.warning_once(
+                    "torch._dynamo is compiling; using the Python version"
+                    " of xIELU.")
+        return self._xielu_python(input)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
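Editor's note: for reference, the _xielu_python fallback above implements the following piecewise function, where alpha_p = softplus(a_p) and alpha_n = beta + softplus(a_n) are recovered from the stored parameters (notation chosen here to match the code, not any canonical statement from the paper):

\mathrm{xIELU}(x) =
\begin{cases}
\alpha_p\, x^2 + \beta x, & x > 0 \\
\alpha_n \left( e^{\min(x,\,\varepsilon)} - 1 - x \right) + \beta x, & x \le 0
\end{cases}

Since torch.where evaluates both branches elementwise, the min(x, ε) clamp (ε = -1e-6 by default) keeps the exponent non-positive, so expm1 is never evaluated on large positive inputs that would otherwise overflow.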
@@ -426,6 +533,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
     lambda: nn.Tanh(),
     "sigmoid":
     lambda: nn.Sigmoid(),
+    "xielu":
+    lambda: XIELU(),
 })
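Editor's note: with the "xielu" key registered above, the activation can be constructed by name like the other entries. A quick, illustrative smoke test of the new module (not part of the commit; shapes and dtype are arbitrary, and on a CPU tensor forward() takes the Python path because input.is_cuda is False):

import torch

from vllm.model_executor.layers.activation import XIELU

act = XIELU()  # defaults: alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5
x = torch.randn(2, 8, dtype=torch.bfloat16)
y = act(x)  # CPU input -> _xielu_python branch inside forward()
assert y.shape == x.shape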