@@ -524,8 +524,8 @@ def from_pretrained(
 quantizer = AUTO_QUANTIZATION_CONFIG_MAPPING[quantization_config["quant_method"]]
 quantizer_kwargs = {}
 # We cannot dequantize since gpt-oss-20b MXFP4 will now be gpt-oss-20b-BF16
-# if "dequantize" in inspect.signature(quantizer).parameters:
-#     quantizer_kwargs["dequantize"] = True
+if load_in_16bit and "dequantize" in inspect.signature(quantizer).parameters:
+    quantizer_kwargs["dequantize"] = True
 quantization_config = quantizer.from_dict(quantization_config, **quantizer_kwargs)
 kwargs["quantization_config"] = quantization_config
 pass
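The added branch forwards `dequantize` only when 16-bit loading is requested and the quantizer's constructor actually accepts that keyword. A minimal standalone sketch of this `inspect.signature` gating pattern, using made-up config classes rather than the real Transformers quantization configs:

import inspect

class ToyDequantizableConfig:
    # Made-up config whose constructor accepts a `dequantize` flag.
    def __init__(self, dequantize = False): self.dequantize = dequantize

class ToyFixedConfig:
    # Made-up config with no `dequantize` parameter.
    def __init__(self): pass

def build_quantizer_kwargs(quantizer_cls, load_in_16bit):
    kwargs = {}
    # Only pass `dequantize` when the class can accept it, mirroring the hunk above.
    if load_in_16bit and "dequantize" in inspect.signature(quantizer_cls).parameters:
        kwargs["dequantize"] = True
    return kwargs

print(build_quantizer_kwargs(ToyDequantizableConfig, load_in_16bit = True))  # {'dequantize': True}
print(build_quantizer_kwargs(ToyFixedConfig, load_in_16bit = True))          # {}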
@@ -549,7 +549,7 @@ def from_pretrained(
     # attn_implementation = attn_implementation,
     **kwargs,
 )
-if hasattr(model, 'generate'):
+if hasattr(model, "generate"):
     model.fast_generate = model.generate
 model.fast_generate_batches = error_out_no_vllm
 if offload_embedding:
@@ -612,8 +612,17 @@ def from_pretrained(
 llm = load_vllm(**load_vllm_kwargs)

 # Convert to HF format
-_, quant_state_dict = get_vllm_state_dict(llm, config = model_config, is_vision_model = True)
-model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype, bnb_config, is_vision_model = True)
+_, quant_state_dict = get_vllm_state_dict(
+    llm,
+    config = model_config,
+    is_vision_model = True,
+)
+model = convert_vllm_to_huggingface(
+    quant_state_dict,
+    model_config,
+    dtype, bnb_config,
+    is_vision_model = True,
+)
 model.vllm_engine = llm
 model.fast_generate = model.vllm_engine.generate
 model.fast_generate_batches = functools.partial(generate_batches, model.vllm_engine)
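`functools.partial` pre-binds the vLLM engine, so `model.fast_generate_batches(prompts)` can later be called with prompts alone. A small illustrative sketch of that binding pattern with placeholder objects (not the actual vLLM or Unsloth helpers):

import functools

class FakeEngine:
    # Stand-in for an inference engine; only used to show the binding pattern.
    def generate(self, prompts):
        return [f"output for: {p}" for p in prompts]

def generate_batches(engine, prompts, batch_size = 2):
    # Split prompts into batches and push each batch through the engine.
    outputs = []
    for i in range(0, len(prompts), batch_size):
        outputs.extend(engine.generate(prompts[i : i + batch_size]))
    return outputs

engine = FakeEngine()
fast_generate_batches = functools.partial(generate_batches, engine)
# The engine is already bound, so call sites only pass prompts:
print(fast_generate_batches(["Hello", "1+1=", "Name a colour"]))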
@@ -753,52 +762,6 @@ def from_pretrained(
     return model, tokenizer
 pass

-@staticmethod
-def pre_compile_for_inference(model_type, model, tokenizer):
-    """
-    We need to invoke torch.compile to save VRAM usage and make it faster downstream.
-    Sometimes torch.compile can use 3GB weirdly on large batches, then it goes down to <1GB.
-    So we invoke torch.compile on short batches to reduce VRAM usage.
-    """
-    if model_type is None or model is None or tokenizer is None: return
-    if str(model_type).lower() not in PRE_COMPILE_INFERENCE: return
-    if getattr(tokenizer, "chat_template", None) is None: return
-    # Check if already compiled and exit
-    for module in model.modules():
-        if hasattr(module, "_pre_compiled_for_inference"): return
-    pass
-    print(f"🦥 Unsloth: Pre compiling {model_type.title()} model for faster inference - this might take 3 minutes or so!")
-    print("========= Pre compiling model for faster inference. Please be patient thank you! =========")
-    # Do single inference
-    messages = [
-        [
-            {"role": "user", "content": f"What is 1+1 equal to?"},
-        ],
-    ]*1
-    inputs = tokenizer.apply_chat_template(
-        messages,
-        add_generation_prompt = True,
-        return_tensors = "pt",
-        return_dict = True,
-    ).to(model.device)
-    _ = model.generate(**inputs, max_new_tokens = 1)
-    # Do batched inference
-    messages = [
-        [
-            {"role": "user", "content": f"1+1"},
-        ],
-    ]*4
-    inputs = tokenizer.apply_chat_template(
-        messages,
-        add_generation_prompt = True,
-        return_tensors = "pt",
-        return_dict = True,
-    ).to(model.device)
-    _ = model.generate(**inputs, max_new_tokens = 2)
-    # Set we already pre compiled
-    model._pre_compiled_for_inference = True
-pass
-
 @staticmethod
 def get_peft_model(
     model,
@@ -902,7 +865,11 @@ def get_peft_model(
 # Enable gradients on modules which are trainable
 requires_grad_for_gradient_checkpointing(model)
 trust_remote_code = getattr(model, "_unsloth_trust_remote_code", False)
-model = FastBaseModel.post_patch_model(model, use_gradient_checkpointing, trust_remote_code = trust_remote_code)
+model = FastBaseModel.post_patch_model(
+    model,
+    use_gradient_checkpointing = use_gradient_checkpointing,
+    trust_remote_code = trust_remote_code,
+)
 model.max_seq_length = max_seq_length
 # Save to modules as well
 for module in model.modules():
@@ -998,14 +965,15 @@ def post_patch_model(
 m.for_inference = functools.partial(FastBaseModel.for_inference, m)
 m = m.model
 # Set weight[padding_idx] = 0
-with torch.no_grad():
-    for name, module in model.named_modules():
-        if type(module) is torch.nn.Embedding:
-            if getattr(module, "weight", None) is not None and getattr(module, "padding_idx", None) is not None:
-                if module.padding_idx < module.weight.shape[0]:
-                    module.weight[module.padding_idx] = 0
-# Patch for torch.compiled inference
-# FastBaseModel.pre_compile_for_inference(model_type, model, tokenizer)
+# Only do this if tokenizer is defined since eos_token == pad_token sometimes!
+pad_token_id = getattr(tokenizer, "pad_token_id", None)
+if tokenizer is not None and getattr(tokenizer, "eos_token_id", None) != pad_token_id:
+    with torch.no_grad():
+        for name, module in model.named_modules():
+            if type(module) is torch.nn.Embedding:
+                if getattr(module, "weight", None) is not None and getattr(module, "padding_idx", None) is not None:
+                    if module.padding_idx == pad_token_id and module.padding_idx < module.weight.shape[0]:
+                        module.weight[module.padding_idx] = 0
 return model
 pass

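The new guard matters because many chat models reuse the EOS token as the pad token; zeroing `weight[padding_idx]` in that case would wipe the EOS embedding. A self-contained sketch of the same check on a toy `torch.nn.Embedding` (the token ids below are made up, not tied to any particular tokenizer):

import torch

# Toy setup: vocab of 10, pretend id 0 is <pad> and id 2 is <eos>.
embedding = torch.nn.Embedding(10, 4, padding_idx = 0)
pad_token_id = 0
eos_token_id = 2  # if this were also 0, we must NOT zero the row

if eos_token_id != pad_token_id:
    with torch.no_grad():
        if embedding.padding_idx == pad_token_id and embedding.padding_idx < embedding.weight.shape[0]:
            embedding.weight[embedding.padding_idx] = 0

print(embedding.weight[pad_token_id])  # all zeros, since pad != eos here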