Merged
Changes from 9 commits
168 changes: 63 additions & 105 deletions vllm/model_executor/layers/linear.py
@@ -238,6 +238,7 @@ class LinearBase(CustomOp):
params_dtype: Data type for the parameters.
quant_config: Quantization config.
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, tensor parallelism will be disabled for this layer.
"""

def __init__(
@@ -250,6 +251,7 @@ def __init__(
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
super().__init__()

@@ -269,6 +271,7 @@ def __init__(
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp


@CustomOp.register("replicated_linear")
@@ -285,6 +288,7 @@ class ReplicatedLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: Has no effect for replicated linear layers.
"""

def __init__(
@@ -298,26 +302,21 @@ def __init__(
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]

super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)

# All linear layers support quant methods.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
self.output_partition_sizes,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
@@ -373,73 +372,6 @@ def extra_repr(self) -> str:
return s


class MergedReplicatedLinear(ReplicatedLinear):
"""Replicated linear layer.

Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""

def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)

def weight_loader(self,
param: Union[Parameter, BasevLLMParameter],
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
assert loaded_shard_id is not None
assert loaded_shard_id < len(self.output_sizes)

if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n)
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n)
elif isinstance(param, PerTensorScaleParameter):
shard_offset = loaded_shard_id
shard_size = 1
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]

param[shard_offset:shard_offset + shard_size] = loaded_weight


@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
@@ -462,7 +394,9 @@ class ColumnParallelLinear(LinearBase):
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across tensor-parallel ranks.
"""

def __init__(
@@ -478,9 +412,13 @@ def __init__(
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
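For the column-parallel split above, a quick illustrative check of the partition arithmetic; divide() is re-implemented here as a stand-in for vLLM's helper and the shapes are made up:

# Stand-in for vLLM's divide(): integer division that asserts even divisibility.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator

output_size = 4096
assert divide(output_size, 2) == 2048   # tp_size=2: each rank owns half the output columns
assert divide(output_size, 1) == 4096   # disable_tp=True forces tp_size=1: full output per rank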
@@ -497,7 +435,8 @@ def __init__(
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)

self.gather_output = gather_output

@@ -526,8 +465,6 @@ def __init__(
else:
self.register_parameter("bias", None)

self.tp_rank = get_tensor_model_parallel_rank()

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):

output_dim = getattr(param, "output_dim", None)
@@ -568,13 +505,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
param.load_column_parallel_weight(loaded_weight=loaded_weight,
tp_rank=self.tp_rank)

def forward(
self, input_
Expand All @@ -598,7 +537,7 @@ def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", tp_size={self.tp_size}"
s += f", gather_output={self.gather_output}"
return s

@@ -625,6 +564,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, none of the weight matrices will be sharded; this layer
will be treated as a "Replicated" MergedLinear.
"""

def __init__(
@@ -639,10 +580,13 @@ def __init__(
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)

assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
@@ -654,7 +598,8 @@ def __init__(
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)

def weight_loader(self,
param: Parameter,
@@ -846,8 +791,6 @@ def weight_loader_v2(self,

assert loaded_shard_id < len(self.output_sizes)

tp_size = get_tensor_model_parallel_world_size()

if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
@@ -859,17 +802,19 @@ def weight_loader_v2(self,
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
block_n) // self.tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
block_n // self.tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
shard_offset = sum(
self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
shard_size=shard_size,
tp_rank=self.tp_rank)
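A worked example of the block-quant scale indexing above, with illustrative shapes only (two 14336-wide packed shards, 128x128 weight blocks, tp_size=2); in the real layer the offsets come from its own output_sizes and quant config:

# Illustrative numbers, not taken from this PR.
output_sizes = [14336, 14336]   # e.g. gate_proj and up_proj packed into one weight
block_n, tp_size = 128, 2
loaded_shard_id = 1             # loading the second shard's block scales

shard_offset = ((sum(output_sizes[:loaded_shard_id]) + block_n - 1) // block_n) // tp_size
shard_size = (output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
assert (shard_offset, shard_size) == (56, 56)   # scale rows per rank after the TP split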


class QKVParallelLinear(ColumnParallelLinear):
@@ -897,6 +842,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across tensor-parallel ranks.
"""

def __init__(
@@ -912,6 +858,7 @@ def __init__(
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.hidden_size = hidden_size
self.head_size = head_size
@@ -920,7 +867,8 @@ def __init__(
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
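The head partitioning above, restated as a small self-contained sketch with illustrative head counts, assuming the usual divide-or-replicate rule for grouped-query attention (the else branch dividing the KV heads is outside this hunk):

# Sketch of the GQA head split; counts are made up.
def split_heads(total_q_heads: int, total_kv_heads: int, tp_size: int) -> tuple[int, int]:
    num_heads = total_q_heads // tp_size
    if tp_size >= total_kv_heads:
        num_kv_heads = 1                      # each KV head replicated across ranks
    else:
        num_kv_heads = total_kv_heads // tp_size
    return num_heads, num_kv_heads

assert split_heads(32, 8, tp_size=4) == (8, 2)
assert split_heads(32, 8, tp_size=1) == (32, 8)   # disable_tp=True path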
@@ -946,7 +894,8 @@ def __init__(
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)

def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
@@ -1007,10 +956,13 @@ def weight_loader_v2(self,
loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
param.load_qkv_weight(loaded_weight=loaded_weight,
shard_id=0,
tp_rank=self.tp_rank)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
param.load_qkv_weight(loaded_weight=loaded_weight,
tp_rank=self.tp_rank)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -1034,7 +986,8 @@ def weight_loader_v2(self,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
shard_size=shard_size,
tp_rank=self.tp_rank)

def weight_loader(self,
param: Parameter,
@@ -1240,6 +1193,7 @@ class RowParallelLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.down_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across tensor-parallel ranks.
"""

def __init__(
@@ -1255,10 +1209,13 @@ def __init__(
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
@@ -1269,7 +1226,8 @@ def __init__(
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@@ -1345,18 +1303,18 @@ def weight_loader_v2(self, param: BasevLLMParameter,
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)

param.load_row_parallel_weight(loaded_weight=loaded_weight)
param.load_row_parallel_weight(loaded_weight=loaded_weight,
tp_rank=self.tp_rank)

def forward(
self, input_
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
input_parallel = splitted_input[self.tp_rank].contiguous()

# Matrix multiply.
assert self.quant_method is not None
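A standalone sketch of the non-parallel input path above, using torch.chunk as a stand-in for split_tensor_along_last_dim (assumed shapes; the point is only the indexing by self.tp_rank):

import torch

# Sketch: when the input is not already split, each rank takes its slice of the last dim.
x = torch.arange(8.0).reshape(1, 8)          # pretend hidden size 8
tp_size, tp_rank = 2, 1
chunks = torch.chunk(x, tp_size, dim=-1)     # stand-in for split_tensor_along_last_dim
input_parallel = chunks[tp_rank].contiguous()
assert input_parallel.shape == (1, 4)
assert torch.equal(input_parallel, torch.tensor([[4.0, 5.0, 6.0, 7.0]]))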