Skip to content

Commit 88fc1a5

Browse files
Merge pull request #2 from EduardDurech/v0.8.2
v0.8.2 vLLM + SwissLM
2 parents 72c8f1a + 304e7f4 commit 88fc1a5

File tree

5 files changed

+574
-0
lines changed

5 files changed

+574
-0
lines changed

tests/models/decoder_only/language/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
"ehristoforu/Falcon3-MoE-2x7B-Insruct", # mixtral
8484
marks=[pytest.mark.cpu_model],
8585
)
86+
pytest.param("Saesara/swissai"), # swissai
8687
])
8788
@pytest.mark.parametrize("dtype", ["half"])
8889
@pytest.mark.parametrize("max_tokens", [32])

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def check_available_online(
209209
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
210210
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
211211
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
212+
"SwissAIForCausalLM": _HfExamplesInfo("Saesara/swissai"), # TODO test 1.5B model
212213
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
213214
trust_remote_code=True),
214215
"TeleFLMForCausalLM": _HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407",

vllm/model_executor/layers/activation.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,43 @@
1515
from vllm.utils import LazyDict
1616

1717

18+
@CustomOp.register("xielu")
19+
class XIELU(CustomOp):
20+
"""
21+
Applies the xIELU activation function
22+
23+
Shapes:
24+
x: (num_tokens, d) or (batch_size, seq_len, d)
25+
return: (num_tokens, d) or (batch_size, seq_len, d)
26+
"""
27+
28+
def __init__(self, alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6):
29+
super().__init__()
30+
self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1.0).unsqueeze(0))
31+
self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - beta)) - 1.0).unsqueeze(0))
32+
self.beta = beta
33+
self.eps = torch.tensor(eps, dtype=torch.bfloat16, device='cuda')
34+
35+
if current_platform.is_cuda_alike():
36+
# TODO CUDA implementation under development, using forward_native for now
37+
self._forward_method = self.forward_native
38+
elif current_platform.is_cpu():
39+
self._forward_method = self.forward_native
40+
41+
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
42+
# TODO optimize to precompute
43+
alpha_p = F.softplus(self.alpha_p)
44+
alpha_n = self.beta + F.softplus(self.alpha_n)
45+
return torch.where(
46+
x > 0,
47+
alpha_p * x * x + self.beta * x,
48+
alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x
49+
)
50+
51+
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
52+
return
53+
54+
1855
@CustomOp.register("fatrelu_and_mul")
1956
class FatreluAndMul(CustomOp):
2057
"""An activation function for FATReLU.

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
106106
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
107107
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
108+
"SwissAIForCausalLM": ("swissai", "SwissAIForCausalLM"),
108109
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
109110
"TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
110111
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),

0 commit comments

Comments
 (0)