-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
Apertus #22810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Apertus #22810
Changes from all commits
304e7f4
88fc1a5
f03bedf
0d79f60
1323808
0457a41
2cf6bbf
6e44971
9c56ab9
99757fa
d225c42
fa815e6
cc49328
93ac959
40f8c36
0bf40d9
ae7df9b
37d11f7
55453ae
bc769a7
c9443f2
ce36321
4dae158
35ce6a2
e005580
6241bff
9da245d
cb326a1
4268ad8
7553f6a
874249d
89c05fa
55d7d8f
528f01a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,9 +16,84 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.utils import LazyDict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @CustomOp.register("xielu") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class XIELU(CustomOp): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Otherwise, we emit a single warning and use xIELU Python | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_p_init: float = 0.8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_n_init: float = 0.8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| beta: float = 0.5, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| eps: float = -1e-6, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with_vector_loads: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Initialize parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.alpha_p = nn.Parameter( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1).unsqueeze(0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.alpha_n = nn.Parameter( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.log(torch.exp(torch.tensor(alpha_n_init - beta)) - 1).unsqueeze(0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Register beta and eps as buffers (fixed tensors) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.register_buffer('beta', torch.tensor(beta), persistent=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.register_buffer('eps', torch.tensor(eps), persistent=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.with_vector_loads = with_vector_loads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._xielu_cuda_obj = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._xielu_cuda_fn = None # Will be set if CUDA available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import xielu.ops # noqa: F401 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._xielu_cuda_obj = torch.classes.xielu.XIELU() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch._dynamo import allow_in_graph | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as err: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as err: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"CUDA-fused xIELU not available ({err}) - using Python implementation. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Install with: pip install git+https://github.com/nickjbrowning/XIELU") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+46
to
+59
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are a few issues in this block that could lead to runtime errors and maintenance difficulties:
I've provided a suggestion to fix the critical bug by setting a fallback for
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _xielu_python(self, x: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_p = F.softplus(self.alpha_p) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_n = self.beta + F.softplus(self.alpha_n) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.where( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x > 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_p * x * x + self.beta * x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Firewall function to prevent torch.compile from seeing .item() calls""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| original_shape = x.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # CUDA kernel expects 3D tensors, reshape if needed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| while x.dim() < 3: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = x.unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if x.dim() > 3: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = x.view(-1, 1, x.size(-1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = self._xielu_cuda_obj.forward( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.alpha_p, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.alpha_n, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.beta.item(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.eps.item(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.with_vector_loads, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return result.view(original_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._xielu_cuda_obj is not None and input.is_cuda and not torch._dynamo.is_compiling(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a potential To fix this, the condition should check
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self._xielu_cuda_fn(input) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self._xielu_python(input) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @CustomOp.register("fatrelu_and_mul") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class FatreluAndMul(CustomOp): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """An activation function for FATReLU. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The function computes x -> FATReLU(x[:d]) * x[d:] where | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| d = x.shape[-1] // 2. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
printfor warnings in a library is discouraged as it can interfere with the logging configuration of downstream applications. It's better to use theloggingmodule for this.Please replace the
printcalls withlogger.warning. You'll need to add the following at the beginning of the file: