From 528b35885c17f29b9d6911510ccddb1264f668d1 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 9 Aug 2024 15:33:53 +0000 Subject: [PATCH 1/7] remove gguf tp raise --- vllm/model_executor/layers/quantization/gguf.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index a4e0a4d50960..9a4a9f122f31 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -39,9 +39,6 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": - if get_tensor_model_parallel_world_size() > 1: - raise ValueError( - "GGUF quantization hasn't supported tensor parallelism yet.") return cls() def get_quant_method(self, layer: torch.nn.Module, From 462783e6ad567781fadaa91762162288c121eaeb Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 14 Aug 2024 14:10:03 +0000 Subject: [PATCH 2/7] save draft --- vllm/model_executor/layers/linear.py | 16 ++++++++++++---- vllm/model_executor/layers/quantization/gguf.py | 3 +++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 646839ff303e..23dd027c459a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -503,9 +503,9 @@ def weight_loader(self, loaded_shard_id if is_gguf_weight: - shard_size = loaded_weight.shape[output_dim] + shard_size = loaded_weight.shape[output_dim] // tp_size shard_offset = loaded_weight.shape[output_dim] * \ - loaded_shard_id + loaded_shard_id // tp_size param.shard_id.append(loaded_shard_id) param.shard_size[loaded_shard_id] = loaded_weight.shape @@ -855,8 +855,13 @@ def weight_loader(self, param, orig_qkv_offsets, loaded_shard_id) if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = loaded_weight.shape + param.shard_size[loaded_shard_id] = shard_shape + input_dim = getattr(param, "input_dim", None) input_size = loaded_weight.shape[input_dim] param_data = param_data.narrow(input_dim, 0, input_size) @@ -968,6 +973,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) # Special case for GGUF @@ -978,7 +984,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Materialize GGUF UninitializedParameter if is_gguf_weight and isinstance(param, UninitializedParameter): - param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + weight_shape = list(loaded_weight.shape) + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(weight_shape, dtype=loaded_weight.dtype) param_data = param.data if input_dim is not None: diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9a4a9f122f31..09c42ef627e3 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -116,6 +116,7 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: shard_size = 
getattr(layer.qweight, "shard_size", None) shard_id = getattr(layer.qweight, "shard_id", None) + print(shard_id, shard_size, layer.qweight.shape) if shard_id and shard_size: result = [] @@ -126,9 +127,11 @@ def apply(self, shard_weight = layer.qweight[ offset:offset + shard_size[id][0], :shard_size[id][1]].contiguous() + qweight_type = layer.qweight_type.shard_weight_type[id] result.append(_fuse_mul_mat(x, shard_weight, qweight_type)) offset += shard_size[id][0] + print(_fuse_mul_mat(x, shard_weight, qweight_type).shape) out = torch.cat(result, axis=1) else: qweight = layer.qweight From c8ffa77a3e649e3a4e4315617ec28d002369fda7 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 14 Aug 2024 15:01:50 +0000 Subject: [PATCH 3/7] fix tp=2 inference --- vllm/model_executor/layers/linear.py | 13 +++++++++---- vllm/model_executor/layers/quantization/gguf.py | 3 --- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 23dd027c459a..1b095b78ae8f 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -503,11 +503,16 @@ def weight_loader(self, loaded_shard_id if is_gguf_weight: - shard_size = loaded_weight.shape[output_dim] // tp_size - shard_offset = loaded_weight.shape[output_dim] * \ - loaded_shard_id // tp_size + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = loaded_weight.shape + param.shard_size[loaded_shard_id] = shard_shape + + input_dim = getattr(param, "input_dim", None) + input_size = loaded_weight.shape[input_dim] + param_data = param_data.narrow(input_dim, 0, input_size) param_data = param_data.narrow(output_dim, shard_offset, shard_size) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 09c42ef627e3..9a4a9f122f31 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -116,7 +116,6 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: shard_size = getattr(layer.qweight, "shard_size", None) shard_id = getattr(layer.qweight, "shard_id", None) - print(shard_id, shard_size, layer.qweight.shape) if shard_id and shard_size: result = [] @@ -127,11 +126,9 @@ def apply(self, shard_weight = layer.qweight[ offset:offset + shard_size[id][0], :shard_size[id][1]].contiguous() - qweight_type = layer.qweight_type.shard_weight_type[id] result.append(_fuse_mul_mat(x, shard_weight, qweight_type)) offset += shard_size[id][0] - print(_fuse_mul_mat(x, shard_weight, qweight_type).shape) out = torch.cat(result, axis=1) else: qweight = layer.qweight From ebf53788fa48527790709df7c8e6d0621542b99e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 15 Aug 2024 05:23:51 +0000 Subject: [PATCH 4/7] code format --- vllm/model_executor/layers/linear.py | 5 +++-- vllm/model_executor/layers/quantization/gguf.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1b095b78ae8f..ca09894c2d4e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -990,8 +990,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Materialize 
GGUF UninitializedParameter if is_gguf_weight and isinstance(param, UninitializedParameter): weight_shape = list(loaded_weight.shape) - weight_shape[input_dim] = weight_shape[input_dim] // tp_size - param.materialize(weight_shape, dtype=loaded_weight.dtype) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data if input_dim is not None: diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9a4a9f122f31..a6a1ed5b0dee 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -5,7 +5,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) From ac9dbe3d707d120609a2c9f139a0b0159c8737d0 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 16 Aug 2024 12:54:27 +0800 Subject: [PATCH 5/7] add gguf tp test --- tests/models/test_gguf.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/models/test_gguf.py b/tests/models/test_gguf.py index 5971179f0121..ce5843acc1a7 100644 --- a/tests/models/test_gguf.py +++ b/tests/models/test_gguf.py @@ -37,16 +37,22 @@ reason="gguf is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tp_size", [1, 2]) def test_models( + num_gpus_available, vllm_runner, example_prompts, model, dtype: str, max_tokens: int, num_logprobs: int, + tp_size: int, ) -> None: + if num_gpus_available < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + original_model, gguf_model = model # Run unquantized model. 
@@ -54,7 +60,7 @@ def test_models( dtype=dtype, max_model_len=MAX_MODEL_LEN, enforce_eager=True, - tensor_parallel_size=1) as original_model: + tensor_parallel_size=tp_size) as original_model: original_outputs = original_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) @@ -64,7 +70,7 @@ def test_models( dtype=dtype, max_model_len=MAX_MODEL_LEN, enforce_eager=True, - tensor_parallel_size=1) as gguf_model: + tensor_parallel_size=tp_size) as gguf_model: gguf_outputs = gguf_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) From 3e3dc91b0a17ab0eb485120683c1aa488822fa6f Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 18 Aug 2024 05:22:38 +0000 Subject: [PATCH 6/7] fix gguf test prompt format --- tests/models/test_gguf.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/models/test_gguf.py b/tests/models/test_gguf.py index ce5843acc1a7..f3a5171eee46 100644 --- a/tests/models/test_gguf.py +++ b/tests/models/test_gguf.py @@ -7,6 +7,7 @@ import pytest from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer from tests.quantization.utils import is_quant_method_supported @@ -20,7 +21,7 @@ MODELS = [ ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", - filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")), + filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")), ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF", filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")), @@ -37,7 +38,7 @@ reason="gguf is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("tp_size", [1, 2]) def test_models( @@ -55,11 +56,14 @@ def test_models( original_model, gguf_model = model + tokenizer = AutoTokenizer.from_pretrained(original_model) + example_prompts = [tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + for prompt in example_prompts] + # Run unquantized model. 
with vllm_runner(model_name=original_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, - enforce_eager=True, tensor_parallel_size=tp_size) as original_model: original_outputs = original_model.generate_greedy_logprobs( @@ -69,7 +73,6 @@ def test_models( with vllm_runner(model_name=gguf_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, - enforce_eager=True, tensor_parallel_size=tp_size) as gguf_model: gguf_outputs = gguf_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) From bc61e56d4672a5f5f9720ba78d557b1a859139ca Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 18 Aug 2024 13:41:27 +0800 Subject: [PATCH 7/7] code format --- tests/models/test_gguf.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/models/test_gguf.py b/tests/models/test_gguf.py index f3a5171eee46..196cd88e039a 100644 --- a/tests/models/test_gguf.py +++ b/tests/models/test_gguf.py @@ -57,8 +57,13 @@ def test_models( original_model, gguf_model = model tokenizer = AutoTokenizer.from_pretrained(original_model) - example_prompts = [tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - for prompt in example_prompts] + messages = [[{ + 'role': 'user', + 'content': prompt + }] for prompt in example_prompts] + example_prompts = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) # Run unquantized model. with vllm_runner(model_name=original_model,
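
The core idea running through this series — record per-shard shapes (already divided by the tensor-parallel world size) on the fused GGUF qweight at load time, then slice and multiply each shard separately inside GGUFLinearMethod.apply() — can be sketched in isolation. The following is an illustrative Python sketch only, not part of the patches: torch.matmul stands in for the real _fuse_mul_mat / ggml dequant kernels, and all shapes, names, and values below are invented for the example.

# --- Illustrative sketch (not part of the patch series) ---
import torch

def apply_fused_gguf(x, qweight, shard_id, shard_size):
    """Mimic GGUFLinearMethod.apply() for a fused (e.g. QKV) GGUF weight.

    shard_size[sid] holds the (rows, cols) recorded per logical shard by the
    weight loader, with the output dimension already divided by the TP world
    size, so each rank only multiplies its own slice of every shard.
    """
    results = []
    offset = 0
    for sid in shard_id:
        rows, cols = shard_size[sid]
        # Slice this shard out of the fused weight, mirroring the patch's
        # layer.qweight[offset:offset + shard_size[id][0], :shard_size[id][1]].
        shard_weight = qweight[offset:offset + rows, :cols].contiguous()
        # Stand-in for _fuse_mul_mat(x, shard_weight, qweight_type); the real
        # kernel dequantizes the GGUF blocks while multiplying.
        results.append(x @ shard_weight.t())
        offset += rows
    # Per-shard outputs are concatenated along the output dimension,
    # as in the patched apply() (torch.cat(result, axis=1)).
    return torch.cat(results, dim=-1)

# Toy usage: two 8-row shards packed into one 16x16 "fused" weight.
x = torch.randn(2, 16)
qweight = torch.randn(16, 16)
out = apply_fused_gguf(x, qweight, shard_id=[0, 1],
                       shard_size={0: (8, 16), 1: (8, 16)})
print(out.shape)  # torch.Size([2, 16])

On the loading side, the same division by get_tensor_model_parallel_world_size() is applied to the relevant dimension before materializing the UninitializedParameter (patches 2–4), so the per-rank qweight buffer matches the shard shapes the apply-time loop above relies on.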