@@ -78,6 +78,9 @@ struct Cli {
7878 /// Number of layers to run on the GPU
7979 #[ arg( short = 'g' , long, default_value = "100" ) ]
8080 n_gpu_layers : u64 ,
81+ /// Split the model across multiple GPUs. Possible values: `none` (use one GPU only), `layer` (split layers and KV across GPUs, default), `row` (split rows across GPUs)
82+ #[ arg( long, default_value = "layer" ) ]
83+ split_mode : String ,
8184 /// The main GPU to use.
8285 #[ arg( long) ]
8386 main_gpu : Option < u64 > ,
@@ -246,6 +249,9 @@ async fn main() -> Result<(), ServerError> {
246249 // log n_gpu_layers
247250 info ! ( target: "stdout" , "n_gpu_layers: {}" , & cli. n_gpu_layers) ;
248251
252+ // log split_mode
253+ info ! ( target: "stdout" , "split_mode: {}" , cli. split_mode) ;
254+
249255 // log main GPU
250256 if let Some ( main_gpu) = & cli. main_gpu {
251257 info ! ( target: "stdout" , "main_gpu: {}" , main_gpu) ;
@@ -395,6 +401,7 @@ async fn main() -> Result<(), ServerError> {
395401 . with_batch_size ( cli. batch_size [ 0 ] )
396402 . with_n_predict ( cli. n_predict )
397403 . with_n_gpu_layers ( cli. n_gpu_layers )
404+ . with_split_mode ( cli. split_mode . clone ( ) )
398405 . with_main_gpu ( cli. main_gpu )
399406 . with_tensor_split ( cli. tensor_split . clone ( ) )
400407 . with_threads ( cli. threads )
@@ -418,6 +425,9 @@ async fn main() -> Result<(), ServerError> {
418425 repeat_penalty : chat_metadata. repeat_penalty ,
419426 presence_penalty : chat_metadata. presence_penalty ,
420427 frequency_penalty : chat_metadata. frequency_penalty ,
428+ split_mode : chat_metadata. split_mode . clone ( ) ,
429+ main_gpu : chat_metadata. main_gpu ,
430+ tensor_split : chat_metadata. tensor_split . clone ( ) ,
421431 } ;
422432
423433 // chat model
@@ -431,6 +441,7 @@ async fn main() -> Result<(), ServerError> {
431441 )
432442 . with_ctx_size ( cli. ctx_size [ 1 ] )
433443 . with_batch_size ( cli. batch_size [ 1 ] )
444+ . with_split_mode ( cli. split_mode )
434445 . with_main_gpu ( cli. main_gpu )
435446 . with_tensor_split ( cli. tensor_split )
436447 . with_threads ( cli. threads )
@@ -441,17 +452,20 @@ async fn main() -> Result<(), ServerError> {
441452 let embedding_model_info = ModelConfig {
442453 name : embedding_metadata. model_name . clone ( ) ,
443454 ty : "embedding" . to_string ( ) ,
455+ ctx_size : embedding_metadata. ctx_size ,
456+ batch_size : embedding_metadata. batch_size ,
444457 prompt_template : embedding_metadata. prompt_template ,
445458 n_predict : embedding_metadata. n_predict ,
446459 reverse_prompt : embedding_metadata. reverse_prompt . clone ( ) ,
447460 n_gpu_layers : embedding_metadata. n_gpu_layers ,
448- ctx_size : embedding_metadata. ctx_size ,
449- batch_size : embedding_metadata. batch_size ,
450461 temperature : embedding_metadata. temperature ,
451462 top_p : embedding_metadata. top_p ,
452463 repeat_penalty : embedding_metadata. repeat_penalty ,
453464 presence_penalty : embedding_metadata. presence_penalty ,
454465 frequency_penalty : embedding_metadata. frequency_penalty ,
466+ split_mode : embedding_metadata. split_mode . clone ( ) ,
467+ main_gpu : embedding_metadata. main_gpu ,
468+ tensor_split : embedding_metadata. tensor_split . clone ( ) ,
455469 } ;
456470
457471 // embedding model
@@ -686,18 +700,23 @@ pub(crate) struct ModelConfig {
686700 // type: chat or embedding
687701 #[ serde( rename = "type" ) ]
688702 ty : String ,
703+ pub ctx_size : u64 ,
704+ pub batch_size : u64 ,
689705 pub prompt_template : PromptTemplateType ,
690706 pub n_predict : i32 ,
691707 #[ serde( skip_serializing_if = "Option::is_none" ) ]
692708 pub reverse_prompt : Option < String > ,
693709 pub n_gpu_layers : u64 ,
694- pub ctx_size : u64 ,
695- pub batch_size : u64 ,
696710 pub temperature : f64 ,
697711 pub top_p : f64 ,
698712 pub repeat_penalty : f64 ,
699713 pub presence_penalty : f64 ,
700714 pub frequency_penalty : f64 ,
715+ pub split_mode : String ,
716+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
717+ pub main_gpu : Option < u64 > ,
718+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
719+ pub tensor_split : Option < String > ,
701720}
702721
703722#[ derive( Debug , Serialize , Deserialize ) ]
0 commit comments