-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Model] Add Gemma 2 #5908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] Add Gemma 2 #5908
Changes from 2 commits
7db6122
df2c007
a176803
a1ddec8
7fbcf48
8d5c6e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,3 +95,49 @@ def extra_repr(self) -> str: | |
| s = f"hidden_size={self.weight.data.size(0)}" | ||
| s += f", eps={self.variance_epsilon}" | ||
| return s | ||
|
|
||
|
|
||
| class GemmaRMSNorm(CustomOp): | ||
| """RMS normalization for Gemma. | ||
WoosukKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| Two differences from the above RMSNorm: | ||
| 1. x * (1 + w) instead of x * w. | ||
| 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| hidden_size: int, | ||
| eps: float = 1e-6, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.weight = nn.Parameter(torch.zeros(hidden_size)) | ||
| self.variance_epsilon = eps | ||
|
|
||
| def forward_native( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should try decorating this with
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I was thinking about it or writing a CUDA kernel. Let's discuss this in another PR?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good -- agree it should be in a different PR :) |
||
| self, | ||
| x: torch.Tensor, | ||
| residual: Optional[torch.Tensor] = None, | ||
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||
| """PyTorch-native implementation equivalent to forward().""" | ||
| orig_dtype = x.dtype | ||
| if residual is not None: | ||
| x = x + residual | ||
| residual = x | ||
|
|
||
| x = x.float() | ||
| variance = x.pow(2).mean(dim=-1, keepdim=True) | ||
| x = x * torch.rsqrt(variance + self.variance_epsilon) | ||
| # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) | ||
| # See https://github.com/huggingface/transformers/pull/29402 | ||
| x = x * (1.0 + self.weight.float()) | ||
| x = x.to(orig_dtype) | ||
| return x if residual is None else (x, residual) | ||
|
|
||
| def forward_cuda( | ||
| self, | ||
| x: torch.Tensor, | ||
| residual: Optional[torch.Tensor] = None, | ||
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||
| # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm. | ||
| return self.forward_native(x, residual) | ||
Uh oh!
There was an error while loading. Please reload this page.