Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/features/quantization/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam
We recommend using the tokenizer from the base model instead of the GGUF model, because the tokenizer conversion from GGUF is time-consuming and unstable, especially for models with a large vocabulary size.
:::

GGUF assumes that Hugging Face can convert the metadata to a config file. If Hugging Face doesn't support your model, you can manually create a config and pass its path via `--hf-config-path`:

```console
# If your model is not supported by Hugging Face, you can manually provide a Hugging Face-compatible config path
vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path TinyLlama/TinyLlama-1.1B-Chat-v1.0
```

You can also use the GGUF model directly through the LLM entrypoint:

```python
Expand Down
18 changes: 14 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3
from vllm.transformers_utils.utils import check_gguf_file, is_s3
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, random_uuid, resolve_obj_by_qualname)

Expand Down Expand Up @@ -229,6 +229,7 @@ def __init__(
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
hf_config_path: Optional[str] = None,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
Expand Down Expand Up @@ -259,6 +260,7 @@ def __init__(
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = model
self.hf_config_path = hf_config_path
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
Expand Down Expand Up @@ -321,8 +323,16 @@ def __init__(
if self.enable_sleep_mode and not current_platform.is_cuda():
raise ValueError("Sleep mode is only supported on CUDA devices.")

hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format)
hf_config = get_config(self.hf_config_path or self.model,
trust_remote_code, revision, code_revision,
config_format)
if check_gguf_file(model) and hf_config.torch_dtype != torch.float16:
logger.warning(
"GGUF requires model dtype to be float16,"
" you are trying to run with %s"
" the dtype will be automatically changed to float16",
hf_config.torch_dtype)
hf_config.torch_dtype = torch.float16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just remove torch.bfloat16 (it's added by mistake) in gguf.py:

def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]

So that we don't need extra dtype handling here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated


if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
Expand Down Expand Up @@ -947,7 +957,7 @@ def get_multimodal_config(self) -> "MultiModalConfig":
def try_get_generation_config(self) -> Dict[str, Any]:
if self.generation_config is None or self.generation_config == "auto":
config = try_get_generation_config(
self.model,
self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
)
Expand Down
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class EngineArgs:
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
hf_config_path: Optional[str] = None
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
Expand Down Expand Up @@ -262,6 +263,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
"--hf-config-path",
type=nullable_str,
default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
Expand Down Expand Up @@ -1075,6 +1082,7 @@ def create_model_config(self) -> ModelConfig:

return ModelConfig(
model=self.model,
hf_config_path=self.hf_config_path,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
Expand Down
22 changes: 21 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable, List, Optional, Tuple

import torch
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.distributed import (get_tensor_model_parallel_rank,
Expand Down Expand Up @@ -504,7 +505,12 @@ def weight_loader(self, param: torch.nn.Parameter,
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

expert_data = param.data[expert_id]
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
param.data.copy_(loaded_weight)
return

# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
Expand All @@ -514,6 +520,20 @@ def weight_loader(self, param: torch.nn.Parameter,
if is_transposed:
shard_dim = int(not shard_dim)

full_load = len(loaded_weight.shape) == 3
if full_load:
shard_dim += 1

# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
final_shape = list(loaded_weight.shape)
if shard_id in ["w1", "w3"]:
final_shape[1] *= 2
final_shape[shard_dim] = final_shape[
shard_dim] // get_tensor_model_parallel_world_size()
param.materialize(final_shape, dtype=loaded_weight.dtype)

expert_data = param.data if full_load else param.data[expert_id]
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
Expand Down
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,23 @@ def __init__(self,
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    """Copy ``loaded_weight`` into ``param``, with GGUF special-casing.

    Args:
        param: Destination parameter. GGUF parameters opt in to the special
            handling below through the marker attributes
            ``is_gguf_weight`` / ``is_gguf_weight_type``.
        loaded_weight: Tensor read from the checkpoint.

    Raises:
        AssertionError: If the (possibly reshaped) ``loaded_weight`` does
            not have the same size as ``param``.
    """
    # Special case for GGUF: marker attributes default to False when absent.
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        # Record the scalar value on the parameter as a plain Python number.
        param.weight_type = loaded_weight.item()

    # Materialize GGUF UninitializedParameter using the shape and dtype of
    # the tensor that is about to be copied into it.
    if is_gguf_weight and isinstance(param, UninitializedParameter):
        param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

    # If the weight on disk does not have a shape, give it one
    # (such as scales for AutoFp8).
    if len(loaded_weight.shape) == 0:
        loaded_weight = loaded_weight.reshape(1)

    # Single size check that carries a readable message; the two f-string
    # pieces are separated by a space so the sizes do not run into the text.
    assert param.size() == loaded_weight.size(), (
        f"Tried to load weights of size {loaded_weight.size()} "
        f"to a parameter of size {param.size()}")
    param.data.copy_(loaded_weight)

def forward(self,
Expand Down
128 changes: 127 additions & 1 deletion vllm/model_executor/layers/quantization/gguf.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import gguf
import torch
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
Expand Down Expand Up @@ -49,6 +52,8 @@ def get_quant_method(self, layer: torch.nn.Module,
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self)
return None


Expand Down Expand Up @@ -184,6 +189,127 @@ def apply(self,
return out


class GGUFMoEMethod(FusedMoEMethodBase):
"""MoE method for GGUF.

Args:
quant_config: The GGUF quantization config.
"""

def __init__(self, quant_config: GGUFConfig):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):

assert intermediate_size_per_partition % 512 == 0, \
"GGUF requires a block size of 512, you are running with " \
f"{intermediate_size_per_partition}, please " \
"adjust your tensor parallel size"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to add this strict assertion here? I think there should be no block size limitation here since we will do padding in the kernel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, the input is padded but the weights are not, so we can read weights out of memory

Copy link
Member

@Isotr0py Isotr0py Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I can still generate reasonable outputs on DeepSeek-V2-Lite-Chat.IQ4_XS.gguf after removing this assertion.

$ VLLM_MLA_DISABLE=1 python examples/offline_inference/basic/basic.py
INFO 02-25 14:43:11 [__init__.py:207] Automatically detected platform cuda.
INFO 02-25 14:43:12 [config.py:208] Replacing legacy 'type' key with 'rope_type'
WARNING 02-25 14:43:12 [config.py:330] GGUF requires model dtype to be float16, you are trying to run with torch.bfloat16 the dtype will be automatically changed to float16
INFO 02-25 14:43:21 [config.py:579] This model supports multiple tasks: {'classify', 'embed', 'reward', 'score', 'generate'}. Defaulting to 'generate'.
WARNING 02-25 14:43:21 [config.py:658] gguf quantization is not fully optimized yet. The speed can be slower than non-quantized models.
WARNING 02-25 14:43:21 [cuda.py:95] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 02-25 14:43:23 [llm_engine.py:234] Initializing a V0 LLM engine (v0.1.dev4784+g18e5059) with config: model='../DeepSeek-V2-Lite-Chat.IQ4_XS.gguf', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-V2-Lite', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.GGUF, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gguf, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=../DeepSeek-V2-Lite-Chat.IQ4_XS.gguf, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}, use_cached_outputs=False, 
INFO 02-25 14:43:24 [cuda.py:178] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 02-25 14:43:24 [cuda.py:226] Using XFormers backend.
INFO 02-25 14:43:34 [parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 02-25 14:43:34 [model_runner.py:1110] Starting to load model ../DeepSeek-V2-Lite-Chat.IQ4_XS.gguf...
WARNING 02-25 14:43:48 [utils.py:168] The model class DeepseekV2ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/nested/__init__.py:226: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
  return _nested.nested_tensor(
INFO 02-25 14:44:04 [model_runner.py:1117] Loading model weights took 8.0463 GB and 30.105321 seconds
INFO 02-25 14:45:42 [worker.py:267] Memory profiling takes 97.83 seconds
INFO 02-25 14:45:42 [worker.py:267] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.90) = 13.27GiB
INFO 02-25 14:45:42 [worker.py:267] model weights take 8.05GiB; non_torch_memory takes 0.03GiB; PyTorch activation peak memory takes 0.95GiB; the rest of the memory reserved for KV Cache is 4.24GiB.
INFO 02-25 14:45:43 [executor_base.py:111] # cuda blocks: 858, # CPU blocks: 809
INFO 02-25 14:45:43 [executor_base.py:116] Maximum concurrency for 4096 tokens per request: 3.35x
INFO 02-25 14:45:49 [llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 104.76 seconds
Processed prompts: 100%|██████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.02it/s, est. speed input: 6.63 toks/s, output: 16.33 toks/s]
Prompt: 'Hello, my name is', Generated text: ' Aiden. I was born on the 1st of April, 2'
Prompt: 'The president of the United States is', Generated text: ', obviously, a very high-profile job. But it’s also a'
Prompt: 'The capital of France is', Generated text: ' Paris, the city of lights and love. Paris is also a popular destination for'
Prompt: 'The future of AI is', Generated text: ' fascinating and full of potential. As AI technology advances, it will increasingly play a'

If there is out of memory allocation issue, this should be handled and fixed at kernel side by adding corresponding shape check there, especially each quantization type has different block_size exactly. (We can leave it to be done in a following PR, since it will modify the kernel)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the assertion for now, can confirm that I get reasonable outputs for block size of 256, might have ran the code against my custom kernel when evaluating it by mistake. My bad


tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
hidden_size)
#gate up proj
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w13_qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
})
set_weight_attrs(w13_qweight, extra_weight_attrs)
layer.register_parameter("w13_qweight", w13_qweight)

w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(w13_qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"ignore_warning": True
})
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
layer.register_parameter("w13_qweight_type", w13_qweight_type)

tensor_shape = (num_experts, intermediate_size_per_partition,
hidden_size)
#gate down proj
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w2_qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
})
set_weight_attrs(w2_qweight, extra_weight_attrs)
layer.register_parameter("w2_qweight", w2_qweight)

w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(w2_qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"ignore_warning": True
})

set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
self.act = SiluAndMul()

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
):
topk_weights, topk_ids = FusedMoE.select_experts(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After #13795 got merged, FusedMoEMethod.apply got a new keyword argument activation, c.f. this diff at vllm/model_executor/layers/fused_moe/layer.py, which cause unexpected keyword argument error:

ERROR 02-26 22:24:02 [engine.py:409]     return forward_call(*args, **kwargs)
ERROR 02-26 22:24:02 [engine.py:409]   File "/home/jovyan/git/vllm/.venv/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 673, in forward
ERROR 02-26 22:24:02 [engine.py:409]     final_hidden_states = self.quant_method.apply(
ERROR 02-26 22:24:02 [engine.py:409] TypeError: GGUFMoEMethod.apply() got an unexpected keyword argument 'activation'

We may adapt what awq_marlin does after rebase:

Suggested change
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
):
topk_weights, topk_ids = FusedMoE.select_experts(
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
):
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, thanks for catching and pointing to a solution!

hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
final_hidden_states = torch.empty_like(x)
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1, ) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = layer.w13_qweight[ii]

out = _fuse_mul_mat(inp, expert_up,
layer.w13_qweight_type.weight_type)
out = self.act(out)

expert_down = layer.w2_qweight[ii]
current_state = _fuse_mul_mat(
out, expert_down,
layer.w2_qweight_type.weight_type).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
final_hidden_states[tok] = current_hidden_state
return final_hidden_states


class GGUFEmbeddingMethod(GGUFLinearMethod):
"""Embedding method for GGUF.

Expand Down
19 changes: 17 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,9 +1245,24 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
config = model_config.hf_config
model_type = config.model_type
gguf_to_hf_name_map = {}
# hack: ggufs have a different name than transformers
if model_type == "cohere":
model_type = "command-r"
if model_type in ("deepseek_v3", "deepseek_v2"):
model_type = "deepseek2"
# GGUF layer map assumes that we will have a merged expert weights
# so we need to map them manually
for idx in range(config.num_hidden_layers):
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"

arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
Expand All @@ -1258,10 +1273,10 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
num_layers = config.num_hidden_layers
name_map = gguf.get_tensor_name_map(arch, num_layers)
with torch.device("meta"):
dummy_model = AutoModelForCausalLM.from_config(config)
dummy_model = AutoModelForCausalLM.from_config(
config, trust_remote_code=model_config.trust_remote_code)
state_dict = dummy_model.state_dict()

gguf_to_hf_name_map = {}
for hf_name in state_dict:
name, suffix = hf_name.rsplit(".", 1)
gguf_name = name_map.get_name(name)
Expand Down
1 change: 0 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,6 @@ def gguf_quant_weights_iterator(
weight = tensor.data
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]

if weight_type.name != "F32":
name = name.replace("weight", "qweight")
param = torch.tensor(weight)
Expand Down