
Commit 1d40a25

Cache the user prompt states too.
1 parent: 9cced37

6 files changed (+137, -111 lines)

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [package]
 name = "ai00_server"
-version = "0.3.14"
+version = "0.3.15"
 edition = "2021"
 authors = ["Gu ZhenNiu <[email protected]>", "Zhang Zhenyuan <[email protected]>"]
 license = "MIT OR Apache-2.0"

assets/configs/Config.toml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ quant_type = "Int8" # Quantization t
 turbo = true              # Whether to use alternative GEMM kernel to speed-up long prompts.
 token_chunk_size = 32     # Size of token chunk that is inferred at once. For high end GPUs, this could be 64 or 128 (faster).
 head_chunk_size = 8192    # DO NOT modify this if you don't know what you are doing.
+state_chunk_size = 4      # The chunk size of layers in model state.
 max_runtime_batch = 8     # The maximum batches that can be scheduled for inference at the same time.
 max_batch = 16            # The maximum batches that are cached on GPU.
 embed_layer = 2           # The (reversed) layer number whose output is used as embedding.
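
The new state_chunk_size key slots into the existing model section of the config. As a minimal sketch of how that key maps into Rust (assuming the config is deserialized with serde and the toml crate, and trimming the struct to just the neighboring keys; the real struct lives in src/config.rs and carries more fields):

use serde::Deserialize;

// Hypothetical, trimmed-down mirror of the model config section.
#[derive(Debug, Deserialize)]
struct ModelConfig {
    token_chunk_size: usize,
    head_chunk_size: usize,
    state_chunk_size: usize, // newly added knob
    max_runtime_batch: usize,
}

fn main() -> Result<(), toml::de::Error> {
    let text = r#"
        token_chunk_size = 32
        head_chunk_size = 8192
        state_chunk_size = 4
        max_runtime_batch = 8
    "#;
    let cfg: ModelConfig = toml::from_str(text)?;
    assert_eq!(cfg.state_chunk_size, 4);
    println!("{cfg:?}");
    Ok(())
}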

src/config.rs

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,7 @@ impl From<Config> for ReloadRequest {
             turbo,
             token_chunk_size,
             head_chunk_size,
+            state_chunk_size,
             max_runtime_batch,
             max_batch,
             embed_layer,
@@ -45,6 +46,7 @@ impl From<Config> for ReloadRequest {
             turbo,
             token_chunk_size,
             head_chunk_size,
+            state_chunk_size,
             max_runtime_batch,
             max_batch,
             embed_layer,
@@ -70,6 +72,8 @@ pub struct Model {
     pub token_chunk_size: usize,
     /// The chunk size for each split of the head matrix.
     pub head_chunk_size: usize,
+    /// The chunk size of layers in model state.
+    pub state_chunk_size: usize,
     /// Maximum number of batches that are active at once.
     pub max_runtime_batch: usize,
     /// Number of states that are cached on GPU.
@@ -89,6 +93,7 @@ impl Default for Model {
             turbo: true,
             token_chunk_size: 32,
             head_chunk_size: 8192,
+            state_chunk_size: 4,
             max_runtime_batch: 8,
             max_batch: 16,
             embed_layer: 2,
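
Because Default for Model sets state_chunk_size: 4, the value matches the hard-coded STATE_CHUNK_SIZE constant removed from src/main.rs below, so configs that omit the new key keep their previous behavior. A reduced sketch of the pattern (struct names and fields trimmed for illustration; the real conversion is From<Config> for ReloadRequest):

// Illustrative only: the reload request simply forwards the configured value.
#[derive(Debug, Clone)]
struct Model {
    head_chunk_size: usize,
    state_chunk_size: usize,
}

impl Default for Model {
    fn default() -> Self {
        Self {
            head_chunk_size: 8192,
            state_chunk_size: 4, // same value the removed constant used
        }
    }
}

#[derive(Debug)]
struct ReloadRequest {
    head_chunk_size: usize,
    state_chunk_size: usize,
}

impl From<Model> for ReloadRequest {
    fn from(model: Model) -> Self {
        let Model { head_chunk_size, state_chunk_size } = model;
        Self { head_chunk_size, state_chunk_size }
    }
}

fn main() {
    let request = ReloadRequest::from(Model::default());
    assert_eq!(request.state_chunk_size, 4);
}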

src/main.rs

Lines changed: 8 additions & 2 deletions
@@ -45,7 +45,6 @@ mod sampler;
 mod utils;
 
 pub const MAX_TOKENS: usize = 4096;
-pub const STATE_CHUNK_SIZE: usize = 4;
 
 #[derive(Debug)]
 pub enum Token {
@@ -182,6 +181,8 @@ pub struct ReloadRequest {
     pub token_chunk_size: usize,
     /// The chunk size for each split of the head matrix.
     pub head_chunk_size: usize,
+    /// The chunk size of layers in model state.
+    pub state_chunk_size: usize,
     /// Maximum number of batches that are active at once.
     pub max_runtime_batch: usize,
     /// Number of states that are cached on GPU.
@@ -290,7 +291,7 @@ where
 
     let state: S = StateBuilder::new(context, model.info())
         .with_num_batch(request.max_batch)
-        .with_chunk_size(STATE_CHUNK_SIZE)
+        .with_chunk_size(request.state_chunk_size)
         .build();
     Ok((model, state))
 }
@@ -397,6 +398,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
     let reload = async move {
         let sender = sender.clone();
         let max_runtime_batch = request.max_runtime_batch;
+        let state_chunk_size = request.state_chunk_size;
         let embed_layer = request.embed_layer;
 
         let file = File::open(&request.model_path)?;
@@ -419,6 +421,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
             model,
             state,
             max_runtime_batch,
+            state_chunk_size,
             embed_layer,
         ))
     }
@@ -429,6 +432,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
             model,
             state,
             max_runtime_batch,
+            state_chunk_size,
             embed_layer,
         ))
     }
@@ -439,6 +443,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
             model,
             state,
             max_runtime_batch,
+            state_chunk_size,
             embed_layer,
         ))
     }
@@ -492,6 +497,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
 
         let context = GenerateContext {
             prompt_tokens: tokens.to_vec(),
+            prompt_cached: false,
             prefix: Default::default(),
             suffix: tokens,
             penalties,
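
Two things happen in src/main.rs: the configured state_chunk_size now reaches StateBuilder::with_chunk_size instead of the removed constant, and every fresh GenerateContext starts out with prompt_cached: false. The cache lookup that flips this flag lives in a file not rendered in this diff, so the following is only an illustrative sketch of the idea behind "cache the user prompt states too" (a cache keyed by prompt tokens; all types and names are hypothetical, not the repo's API):

use std::collections::HashMap;

// Hypothetical stand-in for a backed-up model state.
type CachedState = Vec<f32>;

#[derive(Default)]
struct StateCache {
    // Keyed by the exact prompt tokens whose state was computed earlier.
    states: HashMap<Vec<u16>, CachedState>,
}

struct GenerateContext {
    prompt_tokens: Vec<u16>,
    prompt_cached: bool,
}

impl StateCache {
    // Try to reuse a previously computed prompt state; mark the context so the
    // runtime knows it can skip re-inferring that prompt.
    fn checkout(&self, context: &mut GenerateContext) -> Option<CachedState> {
        let hit = self.states.get(&context.prompt_tokens).cloned();
        context.prompt_cached = hit.is_some();
        hit
    }

    // Store the state produced after a prompt was inferred, for later reuse.
    fn store(&mut self, tokens: Vec<u16>, state: CachedState) {
        self.states.insert(tokens, state);
    }
}

fn main() {
    let mut cache = StateCache::default();
    cache.store(vec![1, 2, 3], vec![0.0; 8]);

    let mut context = GenerateContext { prompt_tokens: vec![1, 2, 3], prompt_cached: false };
    let reused = cache.checkout(&mut context);
    assert!(context.prompt_cached && reused.is_some());
}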
