Commit 2f1581b

Author: George Ohashi (committed)
Merge branch 'fix-test_compress-tensors-utils' of github.com:vllm-project/llm-compressor into fix-test_compress-tensors-utils
2 parents 83eac37 + c0a552a, commit 2f1581b

43 files changed: +367 -330 lines changed

examples/big_models_with_accelerate/mult_gpus_int8_device_map.py

Lines changed: 2 additions & 2 deletions
@@ -7,13 +7,13 @@
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.compression.helpers import calculate_offload_device_map
 
-MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
+MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"
 
 # adjust based off number of desired GPUs
 # reserve_for_hessians=True reserves memory which is required by
 # GPTQModifier and SparseGPTModifier
 device_map = calculate_offload_device_map(
-    MODEL_ID, num_gpus=2, reserve_for_hessians=True, torch_dtype=torch.bfloat16
+    MODEL_ID, num_gpus=1, reserve_for_hessians=True, torch_dtype=torch.bfloat16
 )
 
 model = AutoModelForCausalLM.from_pretrained(
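Note: the hunk ends mid-call, so the arguments to from_pretrained are not shown here. A self-contained sketch of how this example plausibly continues (the from_pretrained keyword arguments are an assumption, not part of this diff):

import torch
from transformers import AutoModelForCausalLM

from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"

# reserve_for_hessians=True sets aside memory needed by GPTQModifier / SparseGPTModifier
device_map = calculate_offload_device_map(
    MODEL_ID, num_gpus=1, reserve_for_hessians=True, torch_dtype=torch.bfloat16
)

# assumed continuation: hand the computed map straight to from_pretrained
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16
)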

src/llmcompressor/modifiers/obcq/utils/helpers.py

Lines changed: 0 additions & 202 deletions
This file was deleted.

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 17 additions & 0 deletions
@@ -2,6 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
+from compressed_tensors.utils.offload import is_module_offloaded
 from loguru import logger
 from torch.nn import Module
 
@@ -282,6 +283,10 @@ def _apply_smoothing(self, model: Module):
 
             @torch.no_grad()
             def smooth(module):
+                offloaded = is_module_offloaded(module)
+                if offloaded:
+                    module._hf_hook.pre_forward(module)
+
                 if module in balance_layers:
                     module.weight.mul_(scales.view(1, -1))
                 elif module == smooth_layer:
@@ -292,6 +297,9 @@ def smooth(module):
                 if hasattr(module, "bias") and module.bias is not None:
                     module.bias.div_(scales)
 
+                if offloaded:
+                    module._hf_hook.post_forward(module, None)
+
             parent = get_fsdp_parent(mapping.smooth_name, model)
             if parent is not None:
                 parent.apply(smooth)
@@ -318,8 +326,16 @@ def _calculate_smoothing_scales(
         # get the channel-wise dynamic range for each layer to be balanced
         weight_scales = []
         for layer in balance_layers:
+            offloaded = is_module_offloaded(layer)
+            if offloaded:
+                layer._hf_hook.pre_forward(layer)
+
             scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
             weight_scales.append(scale)
+
+            if offloaded:
+                layer._hf_hook.post_forward(layer, None)
+
         weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
 
         # calculate the amount of smoothing to apply
@@ -329,4 +345,5 @@ def _calculate_smoothing_scales(
             1 - self.smoothing_strength
         )
         scales = torch.where(weight_scales > 0.0, scales, activation_scales)
+
         return scales
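Both hunks bracket the weight access with the same onload/offload dance: check is_module_offloaded, call _hf_hook.pre_forward before touching module.weight, and call _hf_hook.post_forward afterwards. A rough sketch of how that pattern could be factored into a context manager (illustrative only, not part of this commit; it assumes accelerate-style _hf_hook attributes on offloaded modules):

from contextlib import contextmanager

import torch
from compressed_tensors.utils.offload import is_module_offloaded


@contextmanager
def onloaded(module: torch.nn.Module):
    """Temporarily bring an offloaded module's weights on-device, then restore them."""
    offloaded = is_module_offloaded(module)
    if offloaded:
        module._hf_hook.pre_forward(module)  # move weights to the execution device
    try:
        yield module
    finally:
        if offloaded:
            module._hf_hook.post_forward(module, None)  # return weights to the offload device


# hypothetical usage mirroring _calculate_smoothing_scales above:
# with onloaded(layer):
#     scale = layer.weight.abs().max(dim=0, keepdim=True)[0]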

src/llmcompressor/pytorch/model_load/helpers.py

Lines changed: 7 additions & 5 deletions
@@ -9,6 +9,7 @@
 
 from llmcompressor.core import active_session, create_session, pre_initialize_structure
 from llmcompressor.pytorch.utils import ModuleSparsificationInfo
+from llmcompressor.typing import Processor
 
 COMPLETED_STAGES_FILENAME = "completed_stages.json"
 
@@ -92,15 +93,16 @@ def initialize_recipe(model: Module, recipe_path: str):
 def save_model_and_recipe(
     model: Module,
     save_path: str,
-    tokenizer: Optional[Any] = None,
+    processor: Optional[Processor] = None,
     save_safetensors: bool = False,
     save_compressed: bool = False,
 ):
     """
-    Save a model, tokenizer and the currently loaded recipe to file
+    Save a model, processor and the currently loaded recipe to file
+
     :param model: pytorch model to save
     :param save_path: path to save output to
-    :param tokenizer: model tokenizer to save
+    :param processor: model processor or tokenizer to save
     :param save_safetensors: whether to save as safetensors or pickle (bin)
     :param save_compressed: whether to compress sparse weights on disk
     """
@@ -111,8 +113,8 @@ def save_model_and_recipe(
         save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
     )
 
-    if tokenizer is not None:
-        tokenizer.save_pretrained(save_path)
+    if processor is not None:
+        processor.save_pretrained(save_path)
 
     logger.info("Saving output to {}".format(os.path.abspath(save_path)))
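A hedged usage sketch of the renamed parameter; per the updated docstring, either a processor or a plain tokenizer can be passed (the model variable and output path below are assumed placeholders):

from transformers import AutoTokenizer

from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")

save_model_and_recipe(
    model=model,                      # assumed: model produced earlier in the pipeline
    save_path="./compressed-output",  # assumed output directory
    processor=tokenizer,              # a tokenizer satisfies the Processor parameter
    save_safetensors=True,
    save_compressed=True,
)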

src/llmcompressor/transformers/finetune/data/base.py

Lines changed: 31 additions & 13 deletions
@@ -3,14 +3,14 @@
 from compressed_tensors.registry import RegistryMixin
 from datasets import Dataset, IterableDataset
 from loguru import logger
-from transformers import AutoTokenizer
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     LABELS_MASK_VALUE,
     get_custom_datasets_from_path,
     get_raw_dataset,
 )
+from llmcompressor.typing import Processor
 
 
 class TextGenerationDataset(RegistryMixin):
@@ -30,10 +30,10 @@ def __init__(
         text_column: str,
         data_args: DataTrainingArguments,
         split: str,
-        tokenizer: AutoTokenizer,
+        processor: Processor,
     ):
         self.text_column = text_column
-        self.tokenizer = tokenizer
+        self.processor = processor
         self.data_args = data_args
         self.raw_kwargs = data_args.raw_kwargs or {}
         self.split = split
@@ -50,20 +50,38 @@ def __init__(
         else:
             self.padding = False
 
-        if self.tokenizer:
+        # get tokenizer
+        self.tokenizer = getattr(self.processor, "tokenizer", self.processor)
+
+        if self.tokenizer is not None:
+            # fill in pad token
             if not self.tokenizer.pad_token:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
 
-        # configure sequence length
-        max_seq_length = data_args.max_seq_length
-        model_max_length = tokenizer.model_max_length if tokenizer else max_seq_length
-        if self.tokenizer and max_seq_length > model_max_length:
-            logger.warning(
-                f"The max_seq_length passed ({max_seq_length}) is larger than "
-                f"the maximum length for the model ({tokenizer.model_max_length}). "
-                f"Using max_seq_length={tokenizer.model_max_length}."
+            # configure sequence length
+            max_seq_length = data_args.max_seq_length
+            if data_args.max_seq_length > self.tokenizer.model_max_length:
+                logger.warning(
+                    f"The max_seq_length passed ({max_seq_length}) is larger than "
+                    f"maximum length for model ({self.tokenizer.model_max_length}). "
+                    f"Using max_seq_length={self.tokenizer.model_max_length}."
+                )
+            self.max_seq_length = min(
+                data_args.max_seq_length, self.tokenizer.model_max_length
+            )
+
+            # configure padding
+            self.padding = (
+                False
+                if self.data_args.concatenate_data
+                else "max_length"
+                if self.data_args.pad_to_max_length
+                else False
             )
-        self.max_seq_length = min(data_args.max_seq_length, model_max_length)
+
+        else:
+            self.max_seq_length = None
+            self.padding = False
 
     def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset:
         """
