Merged
Changes from all commits (32 commits):
c47ca6a  fix conseq onehsot (Dec 11, 2024)
4492e53  fix logic in is_model_quat_from_path (Dec 11, 2024)
5ba651a  Merge branch 'main' into fix-test-conseq-oneshot (Dec 23, 2024)
0e141d8  Merge branch 'main' into fix-test-conseq-oneshot (Dec 23, 2024)
8a986d4  Merge branch 'main' into fix-test-conseq-oneshot (Jan 10, 2025)
45f7163  Merge branch 'main' into fix-test-conseq-oneshot (Jan 10, 2025)
a8d1178  Merge branch 'main' into fix-test-conseq-oneshot (Jan 10, 2025)
7ccd115  Merge branch 'main' into fix-test-conseq-oneshot (dsikka, Jan 10, 2025)
59fa764  rahul comments (Jan 10, 2025)
d2b0955  Merge branch 'fix-test-conseq-oneshot' of github.com:vllm-project/llm… (Jan 10, 2025)
76b658d  Merge branch 'main' into fix-test-conseq-oneshot (Jan 13, 2025)
2be6306  Merge branch 'main' into fix-test-conseq-oneshot (dsikka, Jan 15, 2025)
78a12ee  Merge branch 'main' into fix-test-conseq-oneshot (Jan 20, 2025)
c601bb3  comment (Jan 20, 2025)
ebbf182  comment (Jan 20, 2025)
78ea0a8  comment (Jan 20, 2025)
7d3b390  Merge branch 'main' into fix-test-conseq-oneshot (Jan 20, 2025)
adbee68  comment (Jan 20, 2025)
11fd1f0  Merge branch 'main' into fix-test-conseq-oneshot (rahul-tuli, Jan 22, 2025)
14d5c49  remove redudant code (Jan 22, 2025)
6fcc92b  Merge branch 'fix-test-conseq-oneshot' of github.com:vllm-project/llm… (Jan 22, 2025)
d47322e  Merge branch 'main' into fix-test-conseq-oneshot (Jan 22, 2025)
e7a26b0  change from is_quantized_from_config to is_ct_quantized_from_config (Jan 22, 2025)
2851766  Merge branch 'fix-test-conseq-oneshot' of github.com:vllm-project/llm… (Jan 22, 2025)
597ce59  use head, not get, verified only download recipe (Jan 22, 2025)
679b0a2  use create session (Jan 22, 2025)
d584c16  comments (Jan 23, 2025)
daa8f69  Merge branch 'main' into fix-test-conseq-oneshot (Jan 23, 2025)
e4aace1  revert not using create_session (Jan 23, 2025)
766fe0f  comment (Jan 23, 2025)
faacff7  Merge branch 'fix-test-conseq-oneshot' of github.com:vllm-project/llm… (Jan 23, 2025)
54f9f5c  Merge branch 'main' into fix-test-conseq-oneshot (Jan 23, 2025)
13 changes: 12 additions & 1 deletion src/llmcompressor/transformers/finetune/text_generation.py
@@ -30,6 +30,7 @@
PreTrainedModel,
set_seed,
)
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.core import pre_initialize_structure, reset_session
from llmcompressor.pytorch.model_load.helpers import (
@@ -52,7 +53,10 @@
from llmcompressor.transformers.sparsification.sparse_model import (
get_shared_processor_src,
)
from llmcompressor.transformers.utils.helpers import detect_last_checkpoint
from llmcompressor.transformers.utils.helpers import (
detect_last_checkpoint,
is_model_ct_quantized_from_path,
)
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import is_fsdp_model

@@ -224,6 +228,13 @@ def initialize_model_from_path(
"trust_remote_code": model_args.trust_remote_code_model,
}
# this calls from_pretrained under the hood so should be FSDP safe

# optimized models must be decompressed to carry out oneshot/train/etc
if is_model_ct_quantized_from_path(model_path):
model_kwargs["quantization_config"] = CompressedTensorsConfig(
run_compressed=False
)

model = AutoModelForCausalLM.from_pretrained(
model_path,
**model_kwargs,
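For reference, a minimal sketch of what this change enables: a compressed-tensors checkpoint is loaded in decompressed form so that oneshot/train can modify the underlying weights. The model ID below is a hypothetical placeholder, not one used in this PR.

```python
# Hedged sketch: decompress a compressed-tensors checkpoint on load so its
# weights can be modified by a later oneshot/train run.
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

model = AutoModelForCausalLM.from_pretrained(
    "org/some-compressed-tensors-model",  # hypothetical HF stub
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)
```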
105 changes: 104 additions & 1 deletion src/llmcompressor/transformers/utils/helpers.py
@@ -4,9 +4,13 @@
"""

import os
from typing import TYPE_CHECKING, Optional
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union

import requests
from huggingface_hub import HUGGINGFACE_CO_URL_HOME, hf_hub_download
from loguru import logger
from transformers import AutoConfig
from transformers.trainer_utils import get_last_checkpoint

if TYPE_CHECKING:
@@ -15,6 +19,7 @@
__all__ = [
"RECIPE_FILE_NAME",
"detect_last_checkpoint",
"is_model_ct_quantized_from_path",
]

RECIPE_FILE_NAME = "recipe.yaml"
@@ -54,3 +59,101 @@ def detect_last_checkpoint(
)

return last_checkpoint


def is_model_ct_quantized_from_path(path: str) -> bool:
"""
Determine whether the model at the given path is quantized with
compressed-tensors, based on its config.

:param path: path to the model or HF stub
:return: True if the config at the given path contains a compressed-tensors
quantization_config, False otherwise
"""
config = AutoConfig.from_pretrained(path)
if config is not None:
if (
hasattr(config, "quantization_config")
and config.quantization_config["quant_method"] == "compressed-tensors"
):
return True
return False
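A rough standalone equivalent of the check above, shown for illustration; the stub is a placeholder, and it assumes (as the helper does) that the loaded config exposes quantization_config as a dict.

```python
# Illustrative equivalent of is_model_ct_quantized_from_path; the stub below
# is hypothetical and quantization_config is assumed to be a dict.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("org/some-model")  # hypothetical stub
quant_config = getattr(config, "quantization_config", None)
is_ct_quantized = (
    quant_config is not None
    and quant_config.get("quant_method") == "compressed-tensors"
)
```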


def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]:
"""
Infer the recipe from the model_path.

:param model_path: The path to the model to load. It can be one of the following:
- a path to the model directory
- a path to the model file
- Hugging Face model ID
:return: The path to the recipe file if found, None otherwise.
"""
model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path

if os.path.isdir(model_path) or os.path.isfile(model_path):
# Model path is a local path to the model directory or file
model_path = (
os.path.dirname(model_path) if os.path.isfile(model_path) else model_path
)
recipe = os.path.join(model_path, RECIPE_FILE_NAME)

if os.path.isfile(recipe):
logger.info(f"Found recipe in the model_path: {recipe}")
return recipe
logger.debug(f"No recipe found in the model_path: {model_path}")
return None

# If the model path is a Hugging Face model ID
recipe = recipe_from_huggingface_model_id(hf_stub=model_path)

if recipe is None:
logger.info("Failed to infer the recipe from the model_path")

return recipe
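A short usage sketch of the helper above; the output directory is a placeholder and may or may not contain a recipe.yaml.

```python
# Usage sketch (placeholder path): look for a recipe.yaml saved alongside a
# previously compressed model; returns None when no recipe is found.
from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path

recipe_path = infer_recipe_from_model_path("./oneshot_output")
if recipe_path is not None:
    print(f"Found recipe at {recipe_path}")
```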


def recipe_from_huggingface_model_id(
hf_stub: str, recipe_file_name: str = RECIPE_FILE_NAME
) -> Optional[str]:
"""
Attempts to download the recipe from the Hugging Face model ID.

:param hf_stub: Assumed to be the Hugging Face model ID.
:param recipe_file_name: The name of the recipe file to download.
Defaults to RECIPE_FILE_NAME.
:return: The path to the downloaded recipe file if found, None otherwise.
"""
model_id_url = os.path.join(HUGGINGFACE_CO_URL_HOME, hf_stub)
request = requests.head(model_id_url)

if request.status_code != 200:
logger.debug(
"hf_stub is not a valid Hugging Face model ID. "
"Skipping recipe resolution."
)
return None

try:
logger.info(
f"Attempting to download a recipe for {hf_stub} "
f"from {HUGGINGFACE_CO_URL_HOME}"
)
recipe = hf_hub_download(repo_id=hf_stub, filename=recipe_file_name)
logger.info(f"Found recipe: {recipe_file_name} for model ID: {hf_stub}.")
except Exception as e:
logger.error(
(
f"Unable to find recipe {recipe_file_name} "
f"for model ID: {hf_stub}: {e}."
"Skipping recipe resolution."
)
)
recipe = None

return recipe
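The same lookup, sketched outside the helper: a HEAD request checks that the stub resolves to a model page before attempting the download, and only the recipe file (never the weights) is fetched. The model ID is hypothetical and assumes the repo actually hosts a recipe.yaml.

```python
# Standalone sketch of the recipe lookup (hypothetical model ID); HEAD is used
# to validate the stub, then only recipe.yaml is downloaded.
import os

import requests
from huggingface_hub import HUGGINGFACE_CO_URL_HOME, hf_hub_download

hf_stub = "org/compressed-model"  # hypothetical HF stub
if requests.head(os.path.join(HUGGINGFACE_CO_URL_HOME, hf_stub)).status_code == 200:
    recipe = hf_hub_download(repo_id=hf_stub, filename="recipe.yaml")
```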
39 changes: 30 additions & 9 deletions tests/llmcompressor/transformers/obcq/test_consecutive_runs.py
@@ -5,7 +5,10 @@
import pytest
import yaml
from parameterized import parameterized_class
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path
from tests.testing_utils import parse_params, requires_gpu

CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs"
@@ -15,13 +18,15 @@


class TestConsecutiveRuns(unittest.TestCase):
quantization_config = CompressedTensorsConfig(run_compressed=False)

def _test_consecutive_runs(
self, tolerance: float, num_calibration_samples: int = 16
):
import math

from llmcompressor.core import active_session
from llmcompressor.pytorch.model_load.helpers import get_session_model
from llmcompressor.pytorch.model_load.helpers import initialize_recipe
from llmcompressor.pytorch.utils.helpers import tensor_sparsity
from llmcompressor.transformers import oneshot
from llmcompressor.utils.pytorch import qat_active
@@ -36,19 +41,29 @@ def _test_consecutive_runs(
oneshot_device=self.device,
clear_sparse_session=False,
)
first_tiny_model = get_session_model()

first_model = AutoModelForCausalLM.from_pretrained(
self.output_first,
device_map="auto",
quantization_config=self.quantization_config,
)

layer_0_sparse = tensor_sparsity(
first_tiny_model.model.layers[0].self_attn.k_proj.weight
first_model.model.layers[0].self_attn.k_proj.weight
)
assert math.isclose(layer_0_sparse.item(), 0.5, rel_tol=tolerance)
assert qat_active(first_tiny_model)
assert qat_active(first_model)

session = active_session()
session_recipe = session.lifecycle.recipe_container.compiled_recipe
stages = [stage.group for stage in session_recipe.stages]
self.assertEqual(len(stages), 1)
session.reset()

recipe = infer_recipe_from_model_path(model_path=self.output_first)
if recipe:
initialize_recipe(model=first_model, recipe_path=recipe)

# reload saved model and up sparsity to 0.7
oneshot(
model=self.output_first,
@@ -57,15 +72,19 @@
recipe=self.second_recipe,
output_dir=self.output_second,
oneshot_device=self.device,
clear_sparse_session=False,
)

second_tiny_model = get_session_model()
second_model = AutoModelForCausalLM.from_pretrained(
self.output_second,
device_map="auto",
quantization_config=self.quantization_config,
)

layer_0_sparse = tensor_sparsity(
second_tiny_model.model.layers[0].self_attn.k_proj.weight
second_model.model.layers[0].self_attn.k_proj.weight
)
assert math.isclose(layer_0_sparse.item(), 0.7, rel_tol=tolerance)
assert qat_active(second_tiny_model)
assert qat_active(second_model)

session = active_session()
session_recipe = session.lifecycle.recipe_container.compiled_recipe
@@ -119,7 +138,9 @@ def setUp(self):
from transformers import AutoModelForCausalLM

self.model = AutoModelForCausalLM.from_pretrained(
self.model, device_map=self.device
self.model,
device_map=self.device,
quantization_config=self.quantization_config,
)

self.output = "./oneshot_output"
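Put together, the reload pattern the updated test exercises looks roughly like this; the paths, dataset, and recipe names are placeholders rather than values from the test configs.

```python
# Sketch of a consecutive-runs flow: reload the first oneshot output in
# decompressed form, re-apply its saved recipe, then run oneshot again.
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.pytorch.model_load.helpers import initialize_recipe
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path

first_output = "./oneshot_output_first"  # placeholder output dir
first_model = AutoModelForCausalLM.from_pretrained(
    first_output,
    device_map="auto",
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)

recipe = infer_recipe_from_model_path(model_path=first_output)
if recipe:
    initialize_recipe(model=first_model, recipe_path=recipe)

oneshot(
    model=first_output,
    dataset="open_platypus",  # placeholder calibration dataset
    num_calibration_samples=16,
    recipe="second_recipe.yaml",  # placeholder recipe
    output_dir="./oneshot_output_second",
)
```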