[Core] Support loading GGUF model #5191
Merged
Commits (76 total; the diff below shows changes from 12 commits)
1ffda2e init gguf loading support (Isotr0py)
f3058b1 add gguf running support (Isotr0py)
259d5b5 Fix numpy warning (Isotr0py)
0035bdf Merge remote-tracking branch 'upstream/main' into gguf (Isotr0py)
995f98e fix gguf load format (Isotr0py)
d116f2e add more example prompts (Isotr0py)
f387f9e update requirements.txt (Isotr0py)
516552a add dequant runtime (Isotr0py)
de5950d remove debug code (Isotr0py)
5bda5f0 format code (Isotr0py)
980c018 update gguf example (Isotr0py)
f969b36 Merge branch 'main' into gguf (Isotr0py)
e99f521 Merge branch 'vllm-project:main' into gguf (Isotr0py)
9d36996 Fix requirements.txt (Isotr0py)
3a18502 rename ggml -> gguf (Isotr0py)
e194e28 auto detect gguf quant and format (Isotr0py)
164b643 use autotokenizer to load gguf tokenizer (Isotr0py)
b055fb3 Add runtime dequantization for all layers (Isotr0py)
c93c44e Merge branch 'main' into gguf (Isotr0py)
8960270 port gguf cuda kernel (Isotr0py)
1d0c6a4 add qwen2 support and gguf mmq for linear (Isotr0py)
957faec remove transformers load_dequant_gguf_tensor (Isotr0py)
4555cf5 reorder gguf weight iterator (Isotr0py)
7f7af2b fix imatrix (Isotr0py)
87078be fix imatrix (Isotr0py)
ca39edf refactor, fix column parallel (Isotr0py)
cf03757 refactor gguf_kernel and remove dmmv (Isotr0py)
c2524a8 refactor to unmerge weights for gguf (Isotr0py)
446c64a revert get_quantization_config (Isotr0py)
dc43654 revert get_quantization_config (Isotr0py)
2861670 revert qwen2 (Isotr0py)
1622966 add quant vocal embeddings (Isotr0py)
c4d4f96 support quantized parallelhead (Isotr0py)
9a99252 revert qwen2 (Isotr0py)
bc1ab48 Merge remote-tracking branch 'upstream/main' into gguf (Isotr0py)
3fad5bd rebase gguf support (Isotr0py)
409bed3 format code (Isotr0py)
b38bd1d format code (Isotr0py)
3586f12 support qwen2 gguf (Isotr0py)
8a56d55 Merge branch 'main' into gguf (Isotr0py)
defe23f fix gguf loader (Isotr0py)
6c4300e add gguf test (Isotr0py)
266447b format code (Isotr0py)
d5a7e2f format code (Isotr0py)
6026e02 remove archs<7.0 in cmakelists (Isotr0py)
9dc8794 fix a typo (Isotr0py)
ef9b8a3 format code (Isotr0py)
b708ce6 format code (Isotr0py)
be51a27 fix failed model test (Isotr0py)
1bd7d16 Merge branch 'vllm-project:main' into gguf (Isotr0py)
c155f74 Merge branch 'main' into gguf (Isotr0py)
e49f96e add imatrix and qwen2 test (Isotr0py)
af0c051 reorganize gguf kernel (Isotr0py)
0ce3961 exclude gguf copied code (Isotr0py)
e599b07 refactor to merge weights (Isotr0py)
25dcc08 forma code (Isotr0py)
eed9a23 format code (Isotr0py)
6e5330d import gguf (Isotr0py)
e5a61be import gguf (Isotr0py)
64c5375 refactor quantized vocal embedding (Isotr0py)
86ef2b5 optimize docs (Isotr0py)
7ccfacb add docs (Isotr0py)
28dc7b6 Merge remote-tracking branch 'upstream/main' into gguf (Isotr0py)
1b39fbc fix llama embed quant (Isotr0py)
d413f60 Fix CUDA graph with gguf (Isotr0py)
1868a94 Merge remote-tracking branch 'upstream/main' into gguf (Isotr0py)
b4e2f29 fix quant embeddings (Isotr0py)
2cc6753 Merge branch 'main' into gguf (mgoin)
db54a19 Fix embedding method and format (mgoin)
2549c3e Cleanup linear comments (mgoin)
0890fa9 move gguf to cuda requirements (Isotr0py)
5166ac9 raise error for gguf when tp>1 (Isotr0py)
26349db Merge branch 'main' into gguf (mgoin)
73da240 Last round of cleanup (mgoin)
1c83d63 Improve qweight_type size calc (mgoin)
1139e7b Fix lm head tests (mgoin)
Files changed
New file (+35 lines):

```python
from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams


def run_gguf_inference(model_path):
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Create an LLM.
    llm = LLM(model=model_path,
              tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              load_format="gguf",
              quantization="ggml")

    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
    filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
    model = hf_hub_download(repo_id, filename=filename)
    run_gguf_inference(model)
```
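As a quick way to see what the downloaded checkpoint actually contains before handing it to vLLM, the sketch below lists a few tensor names and quantization types. It is illustrative only, not part of this PR, and assumes the `gguf` Python package's `GGUFReader` API:

```python
from gguf import GGUFReader
from huggingface_hub import hf_hub_download

# Not part of this PR: inspect the tensors inside the GGUF checkpoint.
model = hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
                        filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")
reader = GGUFReader(model)
for tensor in reader.tensors[:8]:
    # tensor_type distinguishes e.g. Q4_0 block weights from F32 norms.
    print(tensor.name, tensor.tensor_type, list(tensor.shape))
```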
New file (+100 lines):

```python
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs


class GGMLConfig(QuantizationConfig):
    """Config class for GGML."""

    def __init__(self) -> None:
        pass

    def __repr__(self) -> str:
        return "GGMLConfig()"

    def get_name(self) -> str:
        return "ggml"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGMLConfig":
        return cls()

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["GGMLLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GGMLLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class GGMLLinearMethod(LinearMethodBase):
    """Linear method for GGML.

    Args:
        quant_config: The GGML quantization config.
    """

    def __init__(self, quant_config: GGMLConfig):
        self.quant_config = quant_config
        self.block_size = 32

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        output_size_per_partition = sum(output_partition_sizes)
        quants = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=torch.int8),
                           requires_grad=False)
        set_weight_attrs(quants, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(quants, extra_weight_attrs)
        layer.register_parameter("quants", quants)

        scales = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.block_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 1,
            "output_dim": 0,
            "ggml_scales": True
        })
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("scales", scales)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dequantize Q4_0/Q8_0 weights block by block, then matmul.
        shape = layer.quants.shape
        out = layer.quants.reshape(-1, self.block_size) * layer.scales.reshape(
            -1, 1)
        out = torch.matmul(x, out.reshape(shape).T)
        if bias is not None:
            out.add_(bias)
        return out
```
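To make the runtime path easier to follow, here is a toy reproduction of the blockwise dequantization that `apply` performs before its matmul. The sizes are invented for illustration and the tensors use default float32 rather than the layer's `params_dtype`; this is a sketch, not code from the PR:

```python
import torch

# Invented sizes: 4 output features, 64 input features, so layer.scales
# would hold 64 / 32 = 2 scales per output row.
block_size = 32
quants = torch.randint(-8, 8, (4, 64), dtype=torch.int8)  # plays layer.quants
scales = torch.rand(4, 64 // block_size)                  # plays layer.scales

# Flattening both tensors lines up each run of 32 quants with its scale.
weight = (quants.reshape(-1, block_size) *
          scales.reshape(-1, 1)).reshape(quants.shape)

x = torch.rand(1, 64)
out = torch.matmul(x, weight.T)  # the apply() path, minus the optional bias
print(out.shape)  # torch.Size([1, 4])
```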
New file (+76 lines):

```python
# adapted from
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/integrations/ggml.py
"""
Integration with GGML / The file is copied and adapted from
https://github.com/99991/pygguf with extra methods being exposed
"""
import numpy as np
import torch
from transformers.integrations.ggml import (GGML_BLOCK_SIZES, GGML_TYPES,
                                            load_dequant_gguf_tensor)


def convert_tensor_q4_0(data):
    # C implementation
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11
    block_size = GGML_BLOCK_SIZES["Q4_0"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data,
                             dtype=np.float16).reshape(num_blocks,
                                                       block_size // 2)
    data_u8 = np.frombuffer(data,
                            dtype=np.uint8).reshape(num_blocks, block_size)

    # The scales are stored in the first 2 bytes of each block.
    scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)

    # The rest of the bytes corresponds to the quants,
    # so we discard the first two bytes.
    quants = data_u8[:, 2:]

    # Unpack the low and high nibbles into signed 4-bit values.
    ql = (quants[:, :] & 0xF).astype(np.int8) - 8
    qr = (quants[:, :] >> 4).astype(np.int8) - 8

    quants = np.hstack([ql, qr])

    return scales, quants


def convert_tensor_q8_0(data):
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
    block_size = GGML_BLOCK_SIZES["Q8_0"]
    num_blocks = len(data) // block_size

    scales = (np.frombuffer(data, dtype=np.float16).reshape(
        num_blocks, 1 + 16)[:, :1].astype(np.float32))
    quants = np.frombuffer(data, dtype=np.int8).reshape(num_blocks,
                                                        2 + 32)[:, 2:]

    return scales, quants


def load_gguf_tensor(tensor):
    shape, ggml_type, data = tensor.shape, tensor.tensor_type, tensor.data

    scales = None
    if ggml_type == GGML_TYPES["Q8_0"] and "blk" in tensor.name:
        scales, quants = convert_tensor_q8_0(data)
    elif ggml_type == GGML_TYPES["Q4_0"] and "blk" in tensor.name:
        scales, quants = convert_tensor_q4_0(data)
    else:
        quants = load_dequant_gguf_tensor(shape, ggml_type, data)
        quants = torch.from_numpy(quants.copy())
        return scales, quants

    scales_shape = (int(shape[0] // 32), shape[1])
    scales = scales.reshape(scales_shape[::-1])
    quants = quants.reshape(shape[::-1])
    scales = torch.from_numpy(scales.copy())
    quants = torch.from_numpy(quants.copy())
    return scales, quants
```
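To sanity-check the Q8_0 layout these converters assume (34-byte blocks: one little-endian float16 scale followed by 32 int8 quants), here is a minimal round trip on a synthetic block. It is a standalone sketch, not part of the PR, and assumes `GGML_BLOCK_SIZES["Q8_0"] == 34` and a little-endian platform:

```python
import numpy as np

# Build one synthetic Q8_0 block: a float16 scale then 32 int8 quants.
scale = np.float16(0.125)
q = np.arange(-16, 16, dtype=np.int8)
block = scale.tobytes() + q.tobytes()  # 2 + 32 = 34 bytes

scales, quants = convert_tensor_q8_0(block)
assert np.allclose(scales, np.float32(scale))
assert np.array_equal(quants[0], q)

# Dequantization is then a per-block scale multiply, mirroring the
# runtime path in GGMLLinearMethod.apply above.
dequant = scales * quants.astype(np.float32)
```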