Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/models/language/generation/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@
pytest.param(
"allenai/OLMoE-1B-7B-0924-Instruct",
marks=[pytest.mark.cpu_model],
)
),
pytest.param("swiss-ai/Apertus-8B"), # apertus
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
Expand Down
2 changes: 2 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def check_available_online(
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B",
trust_remote_code=True),
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
Expand Down
111 changes: 111 additions & 0 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict

logger = init_logger(__name__)


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
Expand Down Expand Up @@ -363,6 +366,112 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)


@CustomOp.register("xielu")
class XIELU(CustomOp):
"""
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
Otherwise, we emit a single warning and use xIELU Python
"""

def __init__(
self,
alpha_p_init: float = 0.8,
alpha_n_init: float = 0.8,
beta: float = 0.5,
eps: float = -1e-6,
dtype: torch.dtype = torch.bfloat16,
with_vector_loads: bool = False,
):
super().__init__()
self.alpha_p = nn.Parameter(
torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) -
1).unsqueeze(0))
self.alpha_n = nn.Parameter(
torch.log(
torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) -
1).unsqueeze(0))
self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
self.with_vector_loads = with_vector_loads
# Temporary until xIELU CUDA fully implemented
self._beta_scalar = float(self.beta.detach().cpu().float().item())
self._eps_scalar = float(self.eps.detach().cpu().float().item())

self._xielu_cuda_obj = None
try:
import xielu.ops # noqa: F401

self._xielu_cuda_obj = torch.classes.xielu.XIELU()
msg = "Using experimental xIELU CUDA."
try:
from torch._dynamo import allow_in_graph

self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
msg += " Enabled torch._dynamo for xIELU CUDA."
except Exception as err:
msg += (f" Could not enable torch._dynamo for xIELU ({err}) - "
"this may result in slower performance.")
self._xielu_cuda_fn = self._xielu_cuda
logger.warning_once(msg)
except Exception as err:
logger.warning_once(
"CUDA-fused xIELU not available (%s) –"
" falling back to a Python version.\n"
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
str(err),
)

def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
alpha_p = nn.functional.softplus(self.alpha_p)
alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
return torch.where(
x > 0,
alpha_p * x * x + self.beta * x,
(torch.expm1(torch.min(x, self.eps)) - x) * alpha_n +
self.beta * x,
)

def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
"""Firewall function to prevent torch.compile from seeing .item()"""
assert self._xielu_cuda_obj is not None, (
"XIELU CUDA object must not be None")
original_shape = x.shape
# CUDA kernel expects 3D tensors, reshape if needed
while x.dim() < 3:
x = x.unsqueeze(0)
if x.dim() > 3:
x = x.view(-1, 1, x.size(-1))
if original_shape != x.shape:
logger.warning_once(
"Warning: xIELU input tensor expects 3 dimensions"
" but got (shape: %s). Reshaping to (shape: %s).",
original_shape,
x.shape,
)
result = self._xielu_cuda_obj.forward(
x,
self.alpha_p,
self.alpha_n,
# Temporary until xIELU CUDA fully implemented ->
# self.{beta,eps}.item()
self._beta_scalar,
self._eps_scalar,
self.with_vector_loads,
)
return result.view(original_shape)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self._xielu_cuda_obj is not None and input.is_cuda:
if not torch._dynamo.is_compiling():
return self._xielu_cuda_fn(input)
else:
logger.warning_once(
"torch._dynamo is compiling, using Python version of xIELU."
)
return self._xielu_python(input)


class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.

Expand Down Expand Up @@ -426,6 +535,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
lambda: nn.Tanh(),
"sigmoid":
lambda: nn.Sigmoid(),
"xielu":
lambda: XIELU(),
})


Expand Down
Loading