Apertus #22810
Conversation
v0.8.2 vLLM + SwissLM
* Update swissai.py: replaced LlamaConfig with SwissAIConfig; changed up_proj from RowParallelLinear to ColumnParallelLinear
* `ColumnParallelLinear` import
* `LLaMa` -> `SwissAI`

Co-authored-by: EduardDurech <[email protected]>
Bugfixes in swiss-ai/main
vllm uses float16, and recasts bfloat16 to float16
Removed unnecessary comments
compatible with torch.dynamo; passes tests, matches HF and vLLM outputs
temporarily commenting out tests for other models
vllm xielu fix mirrors transformers with additional torch._dynamo.is_compiling() check
updating vllm
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Code Review
This pull request introduces the Apertus model, which features a new xIELU activation function and QK normalization. The implementation of the Apertus model and its registration are well-structured and align with the existing codebase. My review identifies a critical bug in the XIELU activation function that could cause a TypeError, and a high-severity issue regarding the use of print instead of a proper logger, which is a better practice for library code.
```python
        return result.view(original_shape)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self._xielu_cuda_obj is not None and input.is_cuda and not torch._dynamo.is_compiling():
```
There is a potential TypeError in the forward method. If torch._dynamo.allow_in_graph is not available or fails to be imported, self._xielu_cuda_fn will remain None. However, self._xielu_cuda_obj could be non-None if xielu.ops was imported successfully. In this case, the condition self._xielu_cuda_obj is not None would pass, and the code would attempt to call self._xielu_cuda_fn(input), which would result in None(input), raising a TypeError.
To fix this, the condition should check self._xielu_cuda_fn's availability instead of self._xielu_cuda_obj.
```diff
-        if self._xielu_cuda_obj is not None and input.is_cuda and not torch._dynamo.is_compiling():
+        if self._xielu_cuda_fn is not None and input.is_cuda and not torch._dynamo.is_compiling():
```
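To make the failure mode concrete, here is a minimal, self-contained sketch (all names are hypothetical stand-ins, not the actual vLLM code): when the guard tests one attribute but the call goes through another, a `None` function handle turns into `None(x)` and raises `TypeError`.

```python
class Activation:
    """Toy stand-in for the dispatch described above (hypothetical names)."""

    def __init__(self, have_obj: bool, have_fn: bool):
        # _cuda_obj plays the role of self._xielu_cuda_obj,
        # _cuda_fn plays the role of self._xielu_cuda_fn.
        self._cuda_obj = object() if have_obj else None
        self._cuda_fn = (lambda x: x * 2) if have_fn else None

    def forward_buggy(self, x):
        if self._cuda_obj is not None:   # guards the wrong attribute
            return self._cuda_fn(x)      # None(x) -> TypeError when _cuda_fn is None
        return x

    def forward_fixed(self, x):
        if self._cuda_fn is not None:    # guard the attribute that is actually called
            return self._cuda_fn(x)
        return x                         # clean fallback to the eager path
```

With `have_obj=True, have_fn=False` the buggy guard passes and the call raises `TypeError`, while the fixed guard falls back cleanly.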
```python
            except Exception as err:
                print(f"Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance.")
        except Exception as err:
            print(f"CUDA-fused xIELU not available ({err}) - using Python implementation. "
                  "Install with: pip install git+https://github.com/nickjbrowning/XIELU")
```
Using print for warnings in a library is discouraged as it can interfere with the logging configuration of downstream applications. It's better to use the logging module for this.
Please replace the print calls with logger.warning. You'll need to add the following at the beginning of the file:
from vllm.logger import init_logger
```python
logger = init_logger(__name__)
```

```diff
             except Exception as err:
-                print(f"Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance.")
+                logger.warning(f"Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance.")
         except Exception as err:
-            print(f"CUDA-fused xIELU not available ({err}) - using Python implementation. "
-                  "Install with: pip install git+https://github.com/nickjbrowning/XIELU")
+            logger.warning(f"CUDA-fused xIELU not available ({err}) - using Python implementation. "
+                           "Install with: pip install git+https://github.com/nickjbrowning/XIELU")
```
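As a self-contained illustration of why the logger is preferable, here is a sketch using the stdlib `logging` module as a stand-in for vLLM's `init_logger` (the function name `load_cuda_extension` and the simulated failure are hypothetical): warnings routed through a logger respect the host application's handlers, levels, and formatting, whereas `print` always writes to stdout.

```python
import logging

# Stand-in for vLLM's init_logger(__name__); plain stdlib logging keeps
# the sketch self-contained.
logger = logging.getLogger(__name__)

def load_cuda_extension():
    """Sketch: report an optional-dependency failure via the logger, not print."""
    try:
        raise ImportError("xielu.ops not installed")  # simulated import failure
    except ImportError as err:
        # Lazy %-formatting defers interpolation until the record is emitted
        # and lets downstream logging config filter or redirect the message.
        logger.warning(
            "CUDA-fused xIELU not available (%s) - using Python implementation.", err)
        return None
```

A downstream application can then silence or capture these messages by configuring the `logging` hierarchy, which is impossible with bare `print` calls.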
Code Review
This pull request introduces support for the Apertus model, which includes a new xIELU activation function and QK normalization. The implementation looks solid, but I've identified a critical bug in the XIELU activation layer that could lead to a TypeError at runtime. My review includes a fix for this issue along with suggestions for improving logging and exception handling.
```python
        self._xielu_cuda_obj = None
        self._xielu_cuda_fn = None  # Will be set if CUDA available
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            try:
                from torch._dynamo import allow_in_graph
                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
            except Exception as err:
                print(f"Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance.")
        except Exception as err:
            print(f"CUDA-fused xIELU not available ({err}) - using Python implementation. "
                  "Install with: pip install git+https://github.com/nickjbrowning/XIELU")
```
There are a few issues in this block that could lead to runtime errors and maintenance difficulties:

- **Critical Bug:** If `torch._dynamo.allow_in_graph` cannot be imported or fails (e.g., on older PyTorch versions), `self._xielu_cuda_fn` remains `None`. However, the `forward` method will still attempt to call it if the CUDA object `self._xielu_cuda_obj` was successfully created, leading to a `TypeError`.
- **Logging:** Using `print()` for warnings is not ideal in a library. It's better to use the `logging` module for better control over verbosity and output streams. I recommend replacing `print` with `logger.warning`.
- **Exception Handling:** Catching the broad `Exception` can hide unexpected errors. It's better to catch more specific exceptions like `ImportError` and `AttributeError`.

I've provided a suggestion to fix the critical bug by setting a fallback for `self._xielu_cuda_fn`.
```diff
         self._xielu_cuda_obj = None
         self._xielu_cuda_fn = None  # Will be set if CUDA available
         try:
             import xielu.ops  # noqa: F401

             self._xielu_cuda_obj = torch.classes.xielu.XIELU()
+            self._xielu_cuda_fn = self._xielu_cuda
             try:
                 from torch._dynamo import allow_in_graph
                 self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
             except Exception as err:
                 print(f"Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance.")
         except Exception as err:
             print(f"CUDA-fused xIELU not available ({err}) - using Python implementation. "
                   "Install with: pip install git+https://github.com/nickjbrowning/XIELU")
```
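The suggested fix follows an "assign a safe default, then try to upgrade" pattern. Here is a minimal sketch of that pattern (the function names are hypothetical, and the `torch._dynamo` import is wrapped so the sketch runs with or without PyTorch installed):

```python
def build_activation_fn():
    """Sketch: assign the eager implementation first, then try to swap in the
    dynamo-wrapped version, so the callable attribute is never left as None."""
    def eager(x):
        return x + 1  # stand-in for the eager call into the CUDA object

    fn = eager  # safe fallback assigned up front
    try:
        # May fail on installs without PyTorch or on older versions.
        from torch._dynamo import allow_in_graph
        fn = allow_in_graph(eager)
    except Exception:
        pass  # keep the eager fallback; a logger.warning would go here
    return fn
```

Whether or not the optional upgrade succeeds, the caller always gets a working callable, which is exactly what the `forward` method relies on.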
Pre-release of Apertus from the Swiss AI Initiative
Main modifications from Llama: the xIELU activation function and QK normalization.
Corresponding transformers PR that the vLLM matches outputs with: huggingface/transformers#39381
The code passes the following tests with an early 1B checkpoint uploaded to Saesara/swissai. It is not included as part of the PR as the naming is still being finalized and will eventually be replaced with a checkpoint uploaded to https://huggingface.co/swiss-ai.
tests/models/registry.py

```python
"ApertusForCausalLM": _HfExamplesInfo("Saesara/swissai", trust_remote_code=True),
```

tests/models/language/generation/test_common.py

```python
pytest.param("Saesara/swissai"),
```