diff --git a/recipes/generate.py b/recipes/generate.py
index d7ea12cc2f..fea44ddace 100644
--- a/recipes/generate.py
+++ b/recipes/generate.py
@@ -15,6 +15,7 @@
 from torchtune import config, generation, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import ChatFormat, InstructTemplate, Message
+from torchtune.training import FullModelTorchTuneCheckpointer
 
 logger = utils.get_logger("DEBUG")
 
@@ -44,12 +45,26 @@ def __init__(self, cfg: DictConfig) -> None:
 
     def setup(self, cfg: DictConfig) -> None:
         checkpointer = config.instantiate(cfg.checkpointer)
+
+        if self._quantization_mode is not None:
+            if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
+                raise ValueError(
+                    "Quantization is only supported for models quantized and saved with the "
+                    "FullModelTorchTuneCheckpointer - please ensure you have quantized your "
+                    "model and are using the quantized weights!"
+                )
+            if "qat" in self._quantization_mode:
+                raise ValueError(
+                    "You have specified a quantizer with 'QAT' - "
+                    "QAT quantizers should only be used during quantization aware training "
+                    "and when quantizing models. Please use the corresponding post-training "
+                    "quantizer e.g. Int8DynActInt4WeightQuantizer for Int8DynActInt4WeightQATQuantizer."
+                )
+
         if self._quantization_mode is None:
             ckpt_dict = checkpointer.load_checkpoint()
         else:
             # weights_only needs to be False when loading a quantized model
-            # currently loading a quantized model is only supported with the
-            # FullModelTorchTuneCheckpointer
             ckpt_dict = checkpointer.load_checkpoint(weights_only=False)
 
         self._model = self._setup_model(
@@ -69,8 +84,11 @@ def _setup_model(
         if self._quantization_mode is not None:
             model = self._quantizer.quantize(model)
             model = model.to(device=self._device, dtype=self._dtype)
-
-        model.load_state_dict(model_state_dict)
+            for k, v in model_state_dict.items():
+                model_state_dict[k] = v.to(self._device)
+            model.load_state_dict(model_state_dict, assign=True)
+        else:
+            model.load_state_dict(model_state_dict)
 
         # Validate model was loaded in with the expected dtype.
         training.validate_expected_param_dtype(
diff --git a/torchtune/generation/_generation.py b/torchtune/generation/_generation.py
index 9241b6061b..c2d60a7373 100644
--- a/torchtune/generation/_generation.py
+++ b/torchtune/generation/_generation.py
@@ -366,7 +366,7 @@ def generate(
             tokens, logits = custom_generate_next_token(
                 model,
                 input_pos=curr_input_pos,
-                x=tokens,
+                x=tokens.clone(),
                 mask=curr_masks,
                 temperature=temperature,
                 top_k=top_k,
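
Note on the assign=True load path (a sketch, not part of the patch): quantized checkpoints typically contain tensor subclasses that cannot be copied element-wise into the module's pre-existing parameters, which is what the default load_state_dict does. With assign=True the checkpoint tensors replace the module's tensors outright, so they must already live on the target device -- hence the v.to(self._device) loop above. A minimal illustration of the difference, using a toy nn.Linear (illustrative only, not the recipe's model):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
state_dict = {k: v.clone() for k, v in model.state_dict().items()}

# Default (assign=False): values are copied *into* the existing
# parameters; the module keeps its own storage.
model.load_state_dict(state_dict)
print(model.weight.data_ptr() == state_dict["weight"].data_ptr())  # False

# assign=True: the checkpoint tensors *become* the parameters, which is
# what quantized tensor subclasses need to survive loading intact.
model.load_state_dict(state_dict, assign=True)
print(model.weight.data_ptr() == state_dict["weight"].data_ptr())  # True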
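
Note on x=tokens.clone() (again a sketch, not part of the patch): if the next-token function mutates or captures its input buffer -- plausible once it is compiled or replayed through a CUDA graph -- passing the same tokens tensor that the generation loop also updates in place lets those writes leak across iterations. Cloning hands the callee its own buffer. A toy reproduction of the aliasing hazard, with a hypothetical step function standing in for custom_generate_next_token:

import torch

def step(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a next-token function that writes into its input,
    # e.g. through an in-place kernel or a captured CUDA graph.
    x.add_(1)
    return x[:, -1:]

tokens = torch.tensor([[1, 2, 3]])
_ = step(tokens)          # the caller's buffer is silently mutated
print(tokens)             # tensor([[2, 3, 4]])

tokens = torch.tensor([[1, 2, 3]])
_ = step(tokens.clone())  # the clone absorbs the mutation
print(tokens)             # tensor([[1, 2, 3]]) -- caller's buffer intact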