@@ -133,6 +133,7 @@ def from_pretrained(
133133        revision                    =  None ,
134134        use_exact_model_name        =  False ,
135135        offload_embedding           =  False ,
136+         float32_mixed_precision     =  None , # Forces float32 mixed precision 
136137
137138        fast_inference              =  False , # uses vLLM 
138139        gpu_memory_utilization      =  0.5 ,
@@ -172,7 +173,7 @@ def from_pretrained(
172173                fullgraph                   =  True , # No graph breaks 
173174                use_exact_model_name        =  use_exact_model_name ,
174175                offload_embedding           =  offload_embedding ,
175- 
176+                  float32_mixed_precision      =   float32_mixed_precision , 
176177                # Pass vLLM/inference parameters 
177178                fast_inference              =  fast_inference ,
178179                gpu_memory_utilization      =  gpu_memory_utilization ,
@@ -449,7 +450,7 @@ def from_pretrained(
449450                fullgraph                   =  True , # No graph breaks 
450451                use_exact_model_name        =  use_exact_model_name ,
451452                offload_embedding           =  offload_embedding ,
452- 
453+                  float32_mixed_precision      =   float32_mixed_precision , 
453454                # Pass vLLM/inference parameters 
454455                fast_inference              =  fast_inference ,
455456                gpu_memory_utilization      =  gpu_memory_utilization ,
@@ -594,7 +595,7 @@ def from_pretrained(
594595        whisper_task                =  None ,
595596        unsloth_force_compile       =  False ,
596597        offload_embedding           =  False ,
597- 
598+          float32_mixed_precision      =   None ,  # Forces float32 mixed precision 
598599        # Add the missing vLLM/inference parameters 
599600        fast_inference              =  False , # uses vLLM 
600601        gpu_memory_utilization      =  0.5 ,
@@ -1008,7 +1009,7 @@ def from_pretrained(
10081009            whisper_task       =  whisper_task ,
10091010            auto_config        =  model_config ,
10101011            offload_embedding  =  offload_embedding ,
1011- 
1012+              float32_mixed_precision   =   float32_mixed_precision , 
10121013            # Pass vLLM/inference parameters 
10131014            fast_inference          =  fast_inference ,
10141015            gpu_memory_utilization  =  gpu_memory_utilization ,
0 commit comments