Support OLMo models. #2832
Merged
Commits
985c87d  initialize olmo model (Isotr0py)
ed801bc  add olmo loader (Isotr0py)
b470160  fix model running (Isotr0py)
342394e  Fix activation error (Isotr0py)
e471d57  add olmo config (Isotr0py)
5bd459e  fix a mistake (Isotr0py)
070597c  fix weight tying (Isotr0py)
e6292eb  fix config import (Isotr0py)
ddc5b25  Merge branch 'vllm-project:main' into olmo (Isotr0py)
8b4a899  format olmo code (Isotr0py)
b56ff03  Merge branch 'vllm-project:main' into olmo (Isotr0py)
55d1251  Merge branch 'main' into olmo (zhuohan123)
4ee4c16  change docs and readme (zhuohan123)
@@ -0,0 +1,378 @@
# coding=utf-8
# Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, [email protected].
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
| """Inference-only OLMo model compatible with HuggingFace weights.""" | ||
| from typing import List, Optional, Tuple | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from torch import nn | ||
|
|
||
| from vllm.model_executor.input_metadata import InputMetadata | ||
| from vllm.model_executor.layers.attention import PagedAttention | ||
| from vllm.model_executor.layers.linear import ( | ||
| ColumnParallelLinear, | ||
| LinearMethodBase, | ||
| QKVParallelLinear, | ||
| RowParallelLinear, | ||
| ) | ||
| from vllm.model_executor.layers.rotary_embedding import get_rope | ||
| from vllm.model_executor.layers.sampler import Sampler | ||
| from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding | ||
| from vllm.model_executor.parallel_utils.parallel_state import ( | ||
| get_tensor_model_parallel_world_size, ) | ||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
| from vllm.model_executor.weight_utils import ( | ||
| default_weight_loader, | ||
| hf_model_weights_iterator, | ||
| ) | ||
| from vllm.sequence import SamplerOutput | ||
| from vllm.transformers_utils.configs.olmo import OLMoConfig | ||
|
|
||
| KVCache = Tuple[torch.Tensor, torch.Tensor] | ||
|
|
||
|
|
||
class SwiGLU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

    @property
    def output_multiplier(self) -> float:
        return 0.5


class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.d_model
        assert config.d_model % config.n_heads == 0
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = self.config.n_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.total_num_heads

        # Layer norms.
        self.attn_norm = nn.LayerNorm(config.d_model,
                                      elementwise_affine=False,
                                      bias=False)
        # Attention input projection. Projects x -> (q, k, v)
        self.att_proj = QKVParallelLinear(
            config.d_model,
            self.head_dim,
            self.total_num_heads,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Rotary embeddings.
        if self.config.rope:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
        self.scaling = self.head_dim**-0.5
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

        # Attention output projection.
        self.attn_out = RowParallelLinear(
            config.d_model,
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.attn_norm(hidden_states)
        qkv, _ = self.att_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.config.rope:
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.attn_out(attn_output)
        return output


class OlmoMLP(nn.Module):
    """
    This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
                            is not None else config.mlp_ratio * config.d_model)

        # Layer norms.
        self.ff_norm = nn.LayerNorm(config.d_model,
                                    elementwise_affine=False,
                                    bias=False)

        # Feed-forward input projection.
        self.ff_proj = ColumnParallelLinear(
            config.d_model,
            self.hidden_size,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Activation function.
        # self.act = SiluAndMul()
        # self.act.output_multiplier = 0.5
        self.act = SwiGLU()
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Feed-forward output projection.
        self.ff_out = RowParallelLinear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # Add feed-forward projection.
        # shape: (batch_size, seq_len, d_model)
        og_x = x
        x = self.ff_norm(x)
        x, _ = self.ff_proj(x)
        x = self.act(x)
        x, _ = self.ff_out(x)
        x = og_x + x

        return x


class OlmoBlock(nn.Module):
    """
    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        # Attention block.
        self.attn = OlmoAttention(config, linear_method)

        # MLP block.
        self.mlp = OlmoMLP(config, linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Attention block.
        og_x = hidden_states
        x = self.attn(positions, hidden_states, kv_cache, input_metadata)
        x = x + og_x

        # MLP block.
        hidden_states = self.mlp(x)
        return hidden_states


class OlmoModel(nn.Module):

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=VocabParallelEmbedding(
                    config.embedding_size or config.vocab_size,
                    config.d_model,
                ),
                ln_f=nn.LayerNorm(config.d_model,
                                  elementwise_affine=False,
                                  bias=False),
            ))

        blocks = [
            OlmoBlock(config, linear_method) for i in range(config.n_layers)
        ]
        if self.config.block_group_size > 1:
            raise NotImplementedError("Block group size > 1 not supported yet")
        else:
            self.transformer.update({"blocks": nn.ModuleList(blocks)})

        if not config.weight_tying:
            self.transformer.update({
                "ff_out":
                ColumnParallelLinear(
                    config.d_model,
                    config.embedding_size or config.vocab_size,
                    bias=config.include_bias,
                    linear_method=linear_method,
                )
            })

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        """
        # Get embeddings of input.
        # shape: (batch_size, seq_len, d_model)
        x = self.transformer.wte(input_ids)  # type: ignore

        # Apply blocks one-by-one.
        for block_idx, block in enumerate(self.transformer.blocks):
            # shape: (batch_size, seq_len, d_model)
            x = block(
                positions,
                x,
                kv_caches[block_idx],
                input_metadata,
            )

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        x = self.transformer.ln_f(x)  # type: ignore
        return x


class OLMoForCausalLM(nn.Module):
    """
    Extremely barebones HF model wrapper.
    """

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OlmoModel(config, linear_method)
        self.lm_head_weight = (self.model.transformer.wte.weight
                               if config.weight_tying else
                               self.model.transformer.ff_out.weight)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # attention
            if ".att" in name:
                name = name.replace(".att", ".attn.att")
            # mlp
            if ".ff" in name and "transformer.ff_out" not in name:
                name = name.replace(".ff", ".mlp.ff")
            # there is no bias in olmo
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
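
For context, a minimal usage sketch of running an OLMo checkpoint through vLLM's offline API once this support lands; it is not part of the diff, and the model identifier and the need for trust_remote_code are assumptions for illustration.

from vllm import LLM, SamplingParams

# Checkpoint name is assumed for illustration; any OLMo model in the expected
# HF layout should work. trust_remote_code=True is assumed because OLMo repos
# ship a custom tokenizer/config on the Hub.
llm = LLM(model="allenai/OLMo-7B", trust_remote_code=True)

params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["OLMo is an open language model that"], params)
print(outputs[0].outputs[0].text)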
It seems that the SwiGLU activation used in OLMo is different from the SiluAndMul in vLLM:
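
A minimal sketch of that difference, assuming SiluAndMul's fused form silu(x[..., :d]) * x[..., d:] (shapes are made up for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)      # last dim = 2 * intermediate size
a, b = x.chunk(2, dim=-1)

# SiluAndMul (as assumed above): SiLU on the first half, gated by the second.
silu_and_mul = F.silu(a) * b
# SwiGLU in this PR: SiLU on the second half ("gate"), multiplied by the first.
olmo_swiglu = F.silu(b) * a

print(torch.allclose(silu_and_mul, olmo_swiglu))  # False in general: the halves are swapped.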