@@ -10,7 +10,7 @@ use clap::Parser;
1010use mistralrs_core:: {
1111 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, initialize_logging,
1212 paged_attn_supported, parse_isq_value, BertEmbeddingModel , DefaultSchedulerMethod ,
13- DeviceLayerMapMetadata , DeviceMapMetadata , DeviceMapSetting , IsqType , Loader , LoaderBuilder ,
13+ DeviceLayerMapMetadata , DeviceMapMetadata , DeviceMapSetting , Loader , LoaderBuilder ,
1414 MemoryGpuConfig , MistralRs , MistralRsBuilder , ModelSelected , PagedAttentionConfig , Request ,
1515 SchedulerConfig , TokenSource ,
1616} ;
@@ -119,8 +119,8 @@ struct Args {
119119 num_device_layers : Option < Vec < String > > ,
120120
121121 /// In-situ quantization to apply.
122- #[ arg( long = "isq" , value_parser = parse_isq_value ) ]
123- in_situ_quant : Option < IsqType > ,
122+ #[ arg( long = "isq" ) ]
123+ in_situ_quant : Option < String > ,
124124
125125 /// GPU memory to allocate for KV cache with PagedAttention in MBs.
126126 /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
@@ -223,7 +223,7 @@ async fn re_isq(
223223) -> Result < String , String > {
224224 let repr = format ! ( "Re ISQ: {:?}" , request. ggml_type) ;
225225 MistralRs :: maybe_log_request ( state. clone ( ) , repr. clone ( ) ) ;
226- let request = Request :: ReIsq ( parse_isq_value ( & request. ggml_type ) ?) ;
226+ let request = Request :: ReIsq ( parse_isq_value ( & request. ggml_type , None ) ?) ;
227227 state. get_sender ( ) . unwrap ( ) . send ( request) . await . unwrap ( ) ;
228228 Ok ( repr)
229229}
@@ -300,7 +300,12 @@ async fn main() -> Result<()> {
300300 . build ( ) ?;
301301
302302 #[ cfg( feature = "metal" ) ]
303- let device = Device :: new_metal ( 0 ) ?;
303+ let device = if args. cpu {
304+ args. no_paged_attn = true ;
305+ Device :: Cpu
306+ } else {
307+ Device :: new_metal ( 0 ) ?
308+ } ;
304309 #[ cfg( not( feature = "metal" ) ) ]
305310 let device = if args. cpu {
306311 args. no_paged_attn = true ;
@@ -426,14 +431,19 @@ async fn main() -> Result<()> {
426431 ( _, _, _, _, _, _) => None ,
427432 } ;
428433
434+ let isq = args
435+ . in_situ_quant
436+ . as_ref ( )
437+ . and_then ( |isq| parse_isq_value ( isq, Some ( & device) ) . ok ( ) ) ;
438+
429439 let pipeline = loader. load_model_from_hf (
430440 None ,
431441 args. token_source ,
432442 & dtype,
433443 & device,
434444 false ,
435445 mapper,
436- args . in_situ_quant ,
446+ isq ,
437447 cache_config,
438448 ) ?;
439449 info ! ( "Model loaded." ) ;
0 commit comments