@@ -685,8 +685,14 @@ def _prepare_weights(self, model_name_or_path: str,
 
         return hf_weights_files, matched_pattern == "*.safetensors"
 
+    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
+        if use_safetensors:
+            return safetensors_weights_iterator(hf_weights_files)
+        else:
+            return pt_weights_iterator(hf_weights_files)
+
     def _get_quantized_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str]
+        self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
     ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                      Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
@@ -695,6 +701,7 @@ def _get_quantized_weights_iterator(
         # only load the bitsandbytes module when needed
         try:
             import bitsandbytes
+            from bitsandbytes.functional import QuantState
             if bitsandbytes.__version__ < "0.42.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
                                   "install bitsandbytes>=0.42.0.")
@@ -708,17 +715,62 @@ def _get_quantized_weights_iterator(
             model_name_or_path, revision)
 
         quant_state_dict = {}
-        if use_safetensors:
-            weight_iterator = safetensors_weights_iterator(hf_weights_files)
-        else:
-            weight_iterator = pt_weights_iterator(hf_weights_files)
 
-        def generator():
+        def quantized_checkpoint() -> Generator:
+            # First pass: collect all quant-state tensors.
+            weight_iterator = self._hf_weight_iter(hf_weights_files,
+                                                   use_safetensors)
+            temp_state_dict = {}
             for weight_name, weight_tensor in weight_iterator:
+                if weight_name.endswith(".weight"):
+                    continue
+                # TODO: only nf4 quantization is supported for now
+                if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
+                    raise NotImplementedError(
+                        "Only bitsandbytes_nf4 quantization "
+                        f"is supported for now. {weight_name} is fp4 quantized"
+                    )
+                temp_state_dict[weight_name] = weight_tensor
+
+            # Closure to parse the quant_state for each pre-quantized weight.
+            def _parse_quant_state(param_name: str,
+                                   temp_state_dict: Dict) -> QuantState:
+                quant_state = {}
+                for k in temp_state_dict:
+                    if param_name + "." in k:
+                        quant_state[k] = temp_state_dict[k]
+                # bitsandbytes requires the serialized
+                # weight.quant_state.bitsandbytes__nf4 tensor to be on CPU.
+                quant_state[param_name +
+                            ".quant_state.bitsandbytes__nf4"] = quant_state[
+                                param_name +
+                                ".quant_state.bitsandbytes__nf4"].cpu().data
+                return QuantState.from_dict(quant_state, device="cuda")
+
+            # Second pass: iterate over all pre-quantized and normal weights
+            # (pre-quantized weights carry a quant_state).
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
+                # Filter out all weights whose suffix is not ".weight".
+                if not weight_name.endswith(".weight"):
+                    continue
+                if weight_name + ".quant_state.bitsandbytes__nf4" \
+                        in temp_state_dict:
+                    quant_state = _parse_quant_state(weight_name,
+                                                     temp_state_dict)
+                    weight_name = weight_name.replace(".weight", ".qweight")
+                    quant_state_dict[weight_name] = quant_state
+                    yield weight_name, weight_tensor
+                else:
+                    yield weight_name, weight_tensor
+
+        def generator() -> Generator:
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
                 if any(target_module in weight_name
                        for target_module in self.target_modules):
                     weight_name = weight_name.replace(".weight", ".qweight")
-                # bitsandbytes requires data in GPU
+                    # bitsandbytes requires data in GPU
                     loaded_weight = weight_tensor.cuda().data
                     with set_default_torch_dtype(torch.float32):
                         processed_weight, quant_state = quantize_4bit(
@@ -732,6 +784,10 @@ def generator():
 
                 yield weight_name, processed_weight
 
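+        # A pre-quantized checkpoint already carries packed weights and their
+        # quant states, so it is streamed through without re-quantizing.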
+        if pre_quant:
+            return quantized_checkpoint(), quant_state_dict
         return generator(), quant_state_dict
 
     def _load_weights(self, model_config: ModelConfig,
@@ -749,12 +805,22 @@ def _load_weights(self, model_config: ModelConfig,
         logger.info("Loading weights with BitsAndBytes quantization. "
                     " May take a while ...")
 
-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision))
+        is_quantized_checkpoint = False
+        quant_config = getattr(model_config.hf_config, "quantization_config",
+                               None)
+        if quant_config is not None and quant_config.get(
+                'quant_method') == "bitsandbytes":
+            is_quantized_checkpoint = True
+
+        qweight_iterator, quant_state_dict = \
+            self._get_quantized_weights_iterator(
+                model_config.model, model_config.revision, is_quantized_checkpoint)
 
         model.load_weights(qweight_iterator)
 
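+        # Release any GPU memory cached while materializing the weights.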
+        torch.cuda.empty_cache()
+
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
         for quant_param_name in quant_state_dict:
@@ -792,9 +858,9 @@ def _load_weights(self, model_config: ModelConfig,
                     f"pack_factor not set for parameter {param_name}.")
 
             num_elements = [0] * len(quant_states)
-            for seq, quant_state in enumerate(quant_states.items()):
+            for seq, quant_state in quant_states.items():
                 num_elements[seq] = math.prod(
-                    quant_state[1].shape) // pack_ratio
+                    quant_state.shape) // pack_ratio
 
             offsets = np.concatenate(([0], np.cumsum(num_elements)))
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
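
For context, a rough sketch of the inputs the two new code paths consume. The tensor names, shapes, and config values below are illustrative assumptions for a bitsandbytes-NF4 checkpoint serialized by recent transformers/bitsandbytes, not taken from this PR:

    import torch

    # Hypothetical entries in a pre-quantized state dict: every packed
    # ".weight" tensor is accompanied by prefix-matched state tensors that
    # _parse_quant_state() gathers for QuantState.from_dict().
    state_dict = {
        # packed 4-bit payload that becomes the ".qweight" parameter
        "layers.0.mlp.down_proj.weight":
            torch.zeros(4096, 1, dtype=torch.uint8),
        "layers.0.mlp.down_proj.weight.absmax":
            torch.zeros(128, dtype=torch.uint8),
        "layers.0.mlp.down_proj.weight.quant_map":
            torch.zeros(16),
        # serialized QuantState; the loader moves it to CPU before from_dict()
        "layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4":
            torch.zeros(1, dtype=torch.uint8),
    }

    # The pre-quant gate mirrors what _load_weights() reads from
    # model_config.hf_config:
    quantization_config = {"quant_method": "bitsandbytes",
                           "bnb_4bit_quant_type": "nf4"}
    is_quantized_checkpoint = (
        quantization_config.get("quant_method") == "bitsandbytes")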