recipes/generate.py (22 additions, 4 deletions)

@@ -15,6 +15,7 @@
 from torchtune import config, generation, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import ChatFormat, InstructTemplate, Message
+from torchtune.training import FullModelTorchTuneCheckpointer

 logger = utils.get_logger("DEBUG")

@@ -44,12 +45,26 @@ def __init__(self, cfg: DictConfig) -> None:

     def setup(self, cfg: DictConfig) -> None:
         checkpointer = config.instantiate(cfg.checkpointer)
+
+        if self._quantization_mode is not None:
+            if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
+                raise ValueError(
+                    "Quantization is only supported for models quantized and saved with the "
+                    "FullModelTorchTuneCheckpointer - please ensure you have quantized your "
+                    "model and are using the quantized weights!"
+                )
+            if "qat" in self._quantization_mode:
+                raise ValueError(
+                    "You have specified a quantizer with 'QAT' - "
+                    "QAT quantizers should only be used during quantization aware training "
+                    "and when quantizing models. Please use the corresponding post-training "
+                    "quantizer e.g. Int8DynActInt4WeightQuantizer for Int8DynActInt4WeightQATQuantizer."
+                )
+
         if self._quantization_mode is None:
             ckpt_dict = checkpointer.load_checkpoint()
         else:
             # weights_only needs to be False when loading a quantized model
             # currently loading a quantized model is only supported with the
             # FullModelTorchTuneCheckpointer
             ckpt_dict = checkpointer.load_checkpoint(weights_only=False)

         self._model = self._setup_model(
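For context on the weights_only flag: torch.load with weights_only=True only unpickles plain tensors and a small safelist of types, and the tensor subclasses that quantizers serialize into a checkpoint fall outside that safelist, so the quantized path has to opt out of it. A minimal sketch of the failure mode, using a hypothetical FancyQuantTensor wrapper as a stand-in for a quantized tensor subclass (it is not a torchtune or torchao API):

import torch

# Hypothetical stand-in for a quantized tensor subclass; not a real API.
class FancyQuantTensor:
    def __init__(self, packed):
        self.packed = packed

torch.save({"w": FancyQuantTensor(torch.randn(2, 2))}, "ckpt.pt")

# torch.load("ckpt.pt", weights_only=True) would raise an UnpicklingError,
# since the safe unpickler refuses the custom class.
state = torch.load("ckpt.pt", weights_only=False)  # full unpickler; only use
                                                   # on checkpoints you trust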
@@ -69,8 +84,11 @@ def _setup_model(
         if self._quantization_mode is not None:
             model = self._quantizer.quantize(model)
             model = model.to(device=self._device, dtype=self._dtype)
-
-        model.load_state_dict(model_state_dict)
+            for k, v in model_state_dict.items():
+                model_state_dict[k] = v.to(self._device)
+            model.load_state_dict(model_state_dict, assign=True)
+        else:
+            model.load_state_dict(model_state_dict)

         # Validate model was loaded in with the expected dtype.
         training.validate_expected_param_dtype(
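Two things happen in this hunk: load_state_dict(..., assign=True) (added in PyTorch 2.1) makes the module adopt the checkpoint tensors outright rather than copying them into the existing parameters, which is what lets quantized tensor subclasses survive the load; and because assignment preserves each source tensor's properties, including its device, the state dict is moved to self._device first. A minimal sketch of the assign semantics (my illustration, not from the PR), using a meta-initialized module as a stand-in for parameters that cannot be copied into:

import torch
import torch.nn as nn

with torch.device("meta"):
    m = nn.Linear(4, 4)  # parameters are meta tensors with no storage

state = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# m.load_state_dict(state) would fail here: there is no real storage to
# copy into. With assign=True the module takes the checkpoint tensors as-is.
m.load_state_dict(state, assign=True)

# Assignment preserves the source tensors' properties, including device --
# hence the recipe moves every entry to self._device before loading.
print(m.weight.device)  # cpu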
torchtune/generation/_generation.py (1 addition, 1 deletion)

@@ -366,7 +366,7 @@ def generate(
         tokens, logits = custom_generate_next_token(
             model,
             input_pos=curr_input_pos,
-            x=tokens,
+            x=tokens.clone(),
             mask=curr_masks,
             temperature=temperature,
             top_k=top_k,

The author left an inline review comment on the x=tokens.clone() line: "This is needed as cudagraphs is complaining about tensors being overwritten from previous graphs."
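To make that comment concrete: under torch.compile(mode="reduce-overhead"), the outputs of a captured CUDA graph live in static buffers that the next replay reuses. In the generation loop, tokens is itself the output of the previous step, so passing it straight back in would let a replay overwrite memory the graph still reads, which cudagraphs detects and complains about; the clone() hands each step an independent tensor. A minimal sketch of the pattern (an assumed illustration, not a repro from the PR; requires a CUDA device):

import torch

# next_token_step stands in for custom_generate_next_token.
@torch.compile(mode="reduce-overhead")
def next_token_step(x):
    return x * 2

x = torch.ones(4, device="cuda")
for _ in range(3):
    # x may alias the static CUDA-graph output buffer of the previous
    # iteration; cloning hands the graph a fresh, independent tensor.
    x = next_token_step(x.clone())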