53 changes: 32 additions & 21 deletions vllm/model_executor/models/transformers.py
@@ -15,7 +15,7 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from typing import Iterable, Optional, Union
from typing import Iterable, Literal, Optional, Union

import torch
from torch import nn
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)


def replace_linear_class(
linear: nn.Linear,
style: str,
style: Literal["colwise", "rowwise"],
quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we use it to translate nn.Linear into vllm-style tp Linear.

Quant config is not supported yet
Replace nn.Linear with one of vLLM's tensor parallel linear classes.

`quant_config` is not yet supported.
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
quant_config (QuantConfig): Quantization config for the new linear.
Returns:
Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
"""

if not isinstance(style, str):
@@ -93,7 +102,10 @@ def replace_linear_class(
}.get(style)

if vllm_linear_cls is None:
raise ValueError(f"Unsupported parallel style value: {style}")
logger.warning(
"Unsupported parallel style value: %s. "
"This layer will not be tensor parallelized.", style)
return linear

class HFCompatibleLinear(vllm_linear_cls):
"""
@@ -119,25 +131,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
logger.info("Using Transformers backend.")

self.vllm_config = vllm_config
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config

self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size

self.model: PreTrainedModel = AutoModel.from_config(
self.config,
attn_implementation="vllm",
torch_dtype=vllm_config.model_config.dtype,
Member Author commented:
Note that the dtype of the loaded model is handled by a context manager:

with set_default_torch_dtype(model_config.dtype):
    with target_device:
        model = _initialize_model(vllm_config=vllm_config)
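(Side note for context, not part of this diff: vLLM's `set_default_torch_dtype` is, to the best of my understanding, a small context manager along the following lines. Treat this as an illustrative sketch rather than the exact source.)

import contextlib

import torch

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    # Temporarily make `dtype` the default so tensors and parameters created
    # inside the block (e.g. during model construction) use that dtype.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)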

trust_remote_code=vllm_config.model_config.trust_remote_code,
)
prefix = self.model.base_model_prefix

# MLP modifications
self.tensor_parallelize(self.model)
self.apply_base_model_tp_plan(self.model)

# Attention modifications (assumes 1 attention op per hidden layer)
tp_size = get_tensor_model_parallel_world_size()
@@ -170,13 +181,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config.vocab_size, logit_scale)
self.sampler = get_sampler()

def log_replacement(self, name: str, old_module: nn.Module,
new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)

def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
"""
Apply the base model tensor parallelization plan to a module.
Currently only supports linear layers.
"""
if (self.config.base_model_tp_plan is None
and self.vllm_config.parallel_config.tensor_parallel_size > 1):
and get_tensor_model_parallel_world_size() > 1):
raise ValueError(
"Trying to run tensor parallelization but the model does not "
"support it yet!")
@@ -189,9 +200,9 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
new_module = replace_linear_class(child_module, style,
self.quant_config)
setattr(module, child_name, new_module)
self.log_replacement(qual_name, child_module, new_module)
log_replacement(qual_name, child_module, new_module)
else:
self.tensor_parallelize(child_module, prefix=qual_name)
self.apply_base_model_tp_plan(child_module, prefix=qual_name)

def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
@@ -201,8 +212,8 @@ def replace_vocab_embed_class(self, module: nn.Module):
org_num_embeddings=self.config.vocab_size,
quant_config=None,
)
self.log_replacement("input embedding",
self.model.get_input_embeddings(), new_module)
log_replacement("input embedding", self.model.get_input_embeddings(),
new_module)
self.model.set_input_embeddings(new_module)

def forward(
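Background for readers skimming the diff: `base_model_tp_plan` is a dict on the Hugging Face config that maps parameter-name patterns to tensor-parallel styles (e.g. "colwise"/"rowwise"), and the renamed `apply_base_model_tp_plan` walks the module tree, handing matching nn.Linear layers to `replace_linear_class`. Below is a rough, self-contained sketch of that traversal idea; the toy plan, the `find_style` helper, and the print-based "replacement" are illustrative stand-ins, not the PR's actual code.

import re
from typing import Optional

from torch import nn

# Toy plan in the style of HF's base_model_tp_plan: qualified-name patterns
# (with "*" standing in for a layer index) mapped to a parallel style.
TOY_TP_PLAN = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
}

def find_style(qual_name: str, tp_plan: dict) -> Optional[str]:
    # Match the qualified module name against each pattern; treating "*" as
    # a layer index is a simplification of the real matching logic.
    for pattern, style in tp_plan.items():
        regex = pattern.replace(".", r"\.").replace("*", r"\d+")
        if re.fullmatch(regex, qual_name):
            return style
    return None

def walk(module: nn.Module, tp_plan: dict, prefix: str = "") -> None:
    # Recurse through named children, reporting which nn.Linear layers the
    # plan would hand to replace_linear_class in the real implementation.
    for child_name, child in module.named_children():
        qual_name = f"{prefix}.{child_name}" if prefix else child_name
        style = find_style(qual_name, tp_plan)
        if isinstance(child, nn.Linear) and style is not None:
            print(f"{qual_name}: nn.Linear -> {style}")
        else:
            walk(child, tp_plan, prefix=qual_name)

# Usage with a hypothetical toy model:
layer = nn.ModuleDict({
    "self_attn": nn.ModuleDict({
        "q_proj": nn.Linear(8, 8),
        "o_proj": nn.Linear(8, 8),
    })
})
model = nn.ModuleDict({"layers": nn.ModuleList([layer])})
walk(model, TOY_TP_PLAN)
# layers.0.self_attn.q_proj: nn.Linear -> colwise
# layers.0.self_attn.o_proj: nn.Linear -> rowwise

The real implementation additionally threads through `quant_config` and swaps in vLLM's ColumnParallelLinear/RowParallelLinear classes, but the pattern-matching traversal is the core idea.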