Skip to content

Commit ccae7a3

Browse files
committed
feat: add --split-mode CLI option
Signed-off-by: Xin Liu <[email protected]>
1 parent f4f7c15 commit ccae7a3

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

src/main.rs

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,9 @@ struct Cli {
7878
/// Number of layers to run on the GPU
7979
#[arg(short = 'g', long, default_value = "100")]
8080
n_gpu_layers: u64,
81+
/// Split the model across multiple GPUs. Possible values: `none` (use one GPU only), `layer` (split layers and KV across GPUs, default), `row` (split rows across GPUs)
82+
#[arg(long, default_value = "layer")]
83+
split_mode: String,
8184
/// The main GPU to use.
8285
#[arg(long)]
8386
main_gpu: Option<u64>,
@@ -246,6 +249,9 @@ async fn main() -> Result<(), ServerError> {
246249
// log n_gpu_layers
247250
info!(target: "stdout", "n_gpu_layers: {}", &cli.n_gpu_layers);
248251

252+
// log split_mode
253+
info!(target: "stdout", "split_mode: {}", cli.split_mode);
254+
249255
// log main GPU
250256
if let Some(main_gpu) = &cli.main_gpu {
251257
info!(target: "stdout", "main_gpu: {}", main_gpu);
@@ -395,6 +401,7 @@ async fn main() -> Result<(), ServerError> {
395401
.with_batch_size(cli.batch_size[0])
396402
.with_n_predict(cli.n_predict)
397403
.with_n_gpu_layers(cli.n_gpu_layers)
404+
.with_split_mode(cli.split_mode.clone())
398405
.with_main_gpu(cli.main_gpu)
399406
.with_tensor_split(cli.tensor_split.clone())
400407
.with_threads(cli.threads)
@@ -418,6 +425,9 @@ async fn main() -> Result<(), ServerError> {
418425
repeat_penalty: chat_metadata.repeat_penalty,
419426
presence_penalty: chat_metadata.presence_penalty,
420427
frequency_penalty: chat_metadata.frequency_penalty,
428+
split_mode: chat_metadata.split_mode.clone(),
429+
main_gpu: chat_metadata.main_gpu,
430+
tensor_split: chat_metadata.tensor_split.clone(),
421431
};
422432

423433
// chat model
@@ -431,6 +441,7 @@ async fn main() -> Result<(), ServerError> {
431441
)
432442
.with_ctx_size(cli.ctx_size[1])
433443
.with_batch_size(cli.batch_size[1])
444+
.with_split_mode(cli.split_mode)
434445
.with_main_gpu(cli.main_gpu)
435446
.with_tensor_split(cli.tensor_split)
436447
.with_threads(cli.threads)
@@ -441,17 +452,20 @@ async fn main() -> Result<(), ServerError> {
441452
let embedding_model_info = ModelConfig {
442453
name: embedding_metadata.model_name.clone(),
443454
ty: "embedding".to_string(),
455+
ctx_size: embedding_metadata.ctx_size,
456+
batch_size: embedding_metadata.batch_size,
444457
prompt_template: embedding_metadata.prompt_template,
445458
n_predict: embedding_metadata.n_predict,
446459
reverse_prompt: embedding_metadata.reverse_prompt.clone(),
447460
n_gpu_layers: embedding_metadata.n_gpu_layers,
448-
ctx_size: embedding_metadata.ctx_size,
449-
batch_size: embedding_metadata.batch_size,
450461
temperature: embedding_metadata.temperature,
451462
top_p: embedding_metadata.top_p,
452463
repeat_penalty: embedding_metadata.repeat_penalty,
453464
presence_penalty: embedding_metadata.presence_penalty,
454465
frequency_penalty: embedding_metadata.frequency_penalty,
466+
split_mode: embedding_metadata.split_mode.clone(),
467+
main_gpu: embedding_metadata.main_gpu,
468+
tensor_split: embedding_metadata.tensor_split.clone(),
455469
};
456470

457471
// embedding model
@@ -686,18 +700,23 @@ pub(crate) struct ModelConfig {
686700
// type: chat or embedding
687701
#[serde(rename = "type")]
688702
ty: String,
703+
pub ctx_size: u64,
704+
pub batch_size: u64,
689705
pub prompt_template: PromptTemplateType,
690706
pub n_predict: i32,
691707
#[serde(skip_serializing_if = "Option::is_none")]
692708
pub reverse_prompt: Option<String>,
693709
pub n_gpu_layers: u64,
694-
pub ctx_size: u64,
695-
pub batch_size: u64,
696710
pub temperature: f64,
697711
pub top_p: f64,
698712
pub repeat_penalty: f64,
699713
pub presence_penalty: f64,
700714
pub frequency_penalty: f64,
715+
pub split_mode: String,
716+
#[serde(skip_serializing_if = "Option::is_none")]
717+
pub main_gpu: Option<u64>,
718+
#[serde(skip_serializing_if = "Option::is_none")]
719+
pub tensor_split: Option<String>,
701720
}
702721

703722
#[derive(Debug, Serialize, Deserialize)]

0 commit comments

Comments (0)