diff --git a/Cargo.lock b/Cargo.lock
index 40cee1b7..9bb3c35b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -96,6 +96,7 @@ dependencies = [
  "derivative",
  "fastrand",
  "flume",
+ "futures",
  "half",
  "itertools 0.14.0",
  "kbnf",
diff --git a/crates/ai00-core/Cargo.toml b/crates/ai00-core/Cargo.toml
index 161619b8..ffd22e55 100644
--- a/crates/ai00-core/Cargo.toml
+++ b/crates/ai00-core/Cargo.toml
@@ -15,6 +15,7 @@ version.workspace = true
 bytemuck = "1"
 cbor4ii = { version = "1.0.0", features = ["serde1"] }
 fastrand = "2"
+futures = "0.3"
 half = "2.4"
 kbnf = "0.5.6"
 qp-trie = "0.8"
diff --git a/crates/ai00-core/src/lib.rs b/crates/ai00-core/src/lib.rs
index ea9f39f1..c103c72c 100644
--- a/crates/ai00-core/src/lib.rs
+++ b/crates/ai00-core/src/lib.rs
@@ -7,6 +7,7 @@ use std::{
 use anyhow::{bail, Result};
 use derivative::Derivative;
 use flume::{Receiver, Sender};
+use futures::future::join_all;
 use half::f16;
 use itertools::Itertools;
 use memmap2::Mmap;
@@ -17,25 +18,26 @@ use serde::{de::DeserializeSeed, Deserialize, Serialize};
 use tokio::{
     fs::File,
     io::{AsyncReadExt, BufReader},
-    sync::{Mutex, RwLock},
+    sync::RwLock,
     time::Duration,
 };
 use web_rwkv::{
     context::{Context, ContextBuilder, InstanceExt},
     runtime::{
+        infer::{InferInput, InferOutput},
         loader::{Loader, Lora, LoraBlend, Reader},
-        model::{ContextAutoLimits, EmbedDevice, ModelBuilder, ModelInfo, ModelVersion, Quant},
-        v4, v5, v6, v7,
+        model::{
+            Bundle, ContextAutoLimits, EmbedDevice, ModelBuilder, ModelInfo, ModelVersion, Quant,
+            State,
+        },
+        v4, v5, v6, v7, Runtime, TokioRuntime,
     },
     tensor::{serialization::Seed, TensorCpu},
     tokenizer::Tokenizer,
     wgpu::{Backends, PowerPreference},
 };
-use crate::{
-    run::{GenerateContext, InitState, Runtime, StateId, Tokens},
-    sampler::Sampler,
-};
+use crate::{run::GenerateContext, sampler::Sampler};
 
 pub mod reload;
 pub mod run;
@@ -115,19 +117,52 @@ pub enum ThreadRequest {
 
 #[derive(Default)]
 pub enum Environment {
-    Loaded(Runtime),
+    Loaded {
+        info: RuntimeInfo,
+        runtime: Arc<dyn Runtime + Send + Sync>,
+        model: Arc<dyn ModelSerialize + Send + Sync>,
+        sender: Sender<GenerateContext>,
+    },
     #[default]
     None,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Derivative, Clone)]
+#[derivative(Debug)]
 pub struct RuntimeInfo {
-    pub reload: ReloadRequest,
-    pub model: ModelInfo,
-    pub states: Vec<(StateId, InitState)>,
+    pub reload: Arc<ReloadRequest>,
+    pub info: ModelInfo,
+    pub states: Vec<InitState>,
     pub tokenizer: Arc<Tokenizer>,
 }
 
+struct Model<M>(M);
+
+pub trait ModelSerialize {
+    fn serialize(&self, file: std::fs::File) -> Result<()>;
+}
+
+impl<M: Serialize> ModelSerialize for Model<M> {
+    fn serialize(&self, file: std::fs::File) -> Result<()> {
+        use cbor4ii::{core::enc::Write, serde::Serializer};
+        use std::{fs::File, io::Write as _};
+
+        struct FileWriter(File);
+        impl Write for FileWriter {
+            type Error = std::io::Error;
+            fn push(&mut self, input: &[u8]) -> Result<(), Self::Error> {
+                self.0.write_all(input)
+            }
+        }
+
+        let file = FileWriter(file);
+        let mut serializer = Serializer::new(file);
+        self.0.serialize(&mut serializer)?;
+
+        Ok(())
+    }
+}
+
 #[derive(Debug, Default, Clone)]
 pub struct AdapterList(pub Vec<String>);
 
@@ -228,6 +263,29 @@ enum LoadType {
     Prefab,
 }
 
+#[derive(
+    Derivative, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, ToSchema,
+)]
+#[derivative(Debug = "transparent")]
+#[serde(transparent)]
+pub struct StateId(uuid::Uuid);
+
+impl StateId {
+    pub fn new() -> Self {
+        Self(uuid::Uuid::new_v4())
+    }
+}
+
+#[derive(Derivative, Clone)]
+#[derivative(Debug)]
+pub struct InitState {
+    pub name: String,
+    pub id: StateId,
+    pub default: bool,
+    #[derivative(Debug = "ignore")]
+    pub data: TensorCpu<f32>,
+} + fn list_adapters() -> AdapterList { let backends = Backends::all(); let instance = web_rwkv::wgpu::Instance::default(); @@ -267,51 +325,50 @@ async fn load_tokenizer(path: impl AsRef) -> Result { Ok(Tokenizer::new(&contents)?) } -async fn load_init_state( +async fn load_model_state( context: &Context, info: &ModelInfo, model: R, ) -> Result> { - let state = match info.version { + match info.version { ModelVersion::V4 => bail!("v4 does not support init state yet"), ModelVersion::V5 => v5::read_state(context, info, model).await, ModelVersion::V6 => v6::read_state(context, info, model).await, ModelVersion::V7 => v7::read_state(context, info, model).await, - }; - state.map_err(Into::into) + } } async fn load_runtime( context: &Context, - reload: &ReloadRequest, - info: ModelInfo, + info: &ModelInfo, + request: &ReloadRequest, load: LoadType, -) -> Result { +) -> Result<( + Vec, + Arc, + Arc, + Arc, +)> { let ReloadRequest { model_path, lora, state, quant, quant_type, + precision, max_batch, embed_device, - tokenizer_path, .. - } = reload.clone(); - - let tokenizer = load_tokenizer(tokenizer_path).await?; - - let file = File::open(model_path).await?; - let data = unsafe { Mmap::map(&file) }?; - - let mut states = vec![]; - for reload::State { - path, - name, - id, - default, - } in state - { + } = request.clone(); + + let mut states = Vec::with_capacity(state.len()); + for state in state.into_iter() { + let reload::State { + path, + name, + id, + default, + } = state; let name = match name { Some(name) => name, None => match path.file_name() { @@ -322,7 +379,7 @@ async fn load_runtime( let file = File::open(path).await?; let data = unsafe { Mmap::map(&file) }?; let model = SafeTensors::deserialize(&data)?; - match load_init_state(context, &info, model).await { + match load_model_state(context, info, model).await { Ok(data) => { let state = InitState { name, @@ -337,10 +394,13 @@ async fn load_runtime( } } - let runtime = match load { + let file = File::open(model_path).await?; + let data = unsafe { Mmap::map(&file) }?; + + match load { LoadType::SafeTensors => { let model = SafeTensors::deserialize(&data)?; - if let Ok(data) = load_init_state(context, &info, model).await { + if let Ok(data) = load_model_state(context, info, model).await { let name = "internal".into(); let id = StateId::new(); let state = InitState { @@ -354,16 +414,15 @@ async fn load_runtime( let model = SafeTensors::deserialize(&data)?; let quant = (0..quant).map(|layer| (layer, quant_type)).collect(); - let lora = { - let mut x = Vec::with_capacity(lora.len()); - for lora in lora.into_iter() { - let file = File::open(lora.path).await?; - let data = unsafe { Mmap::map(&file) }?; - let blend = LoraBlend::full(lora.alpha); - x.push((data, blend)) - } - x - }; + let lora: Vec> = join_all(lora.iter().map(|lora| async move { + let reload::Lora { path, alpha } = lora; + let file = File::open(path).await?; + let data = unsafe { Mmap::map(&file)? }; + let blend = LoraBlend::full(*alpha); + Ok((data, blend)) + })) + .await; + let lora: Vec<_> = lora.into_iter().try_collect()?; let lora: Vec<_> = lora .iter() .map(|(data, blend)| -> Result<_> { @@ -376,10 +435,7 @@ async fn load_runtime( let builder = ModelBuilder::new(context, model) .quant(quant) .embed_device(embed_device); - let builder = lora.into_iter().fold(builder, |b, x| b.lora(x)); - - let context = context.clone(); - let reload = reload.clone(); + let builder = lora.into_iter().fold(builder, |builder, x| builder.lora(x)); macro_rules! 
match_safe_tensors { (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $build:ident, $bundle:ty)),+ }) => { @@ -388,14 +444,17 @@ async fn load_runtime( ($version, $precision) => { let model = builder.$build().await?; let bundle = <$bundle>::new(model, max_batch); - Runtime::new(context, bundle, reload, states, tokenizer).await + let state = Arc::new(bundle.state()); + let model = Arc::new(Model(bundle.model())); + let runtime = Arc::new(TokioRuntime::::new(bundle).await); + Ok((states, runtime, state, model)) } )+ } } } match_safe_tensors!( - (info.version, reload.precision), + (info.version, precision), { (ModelVersion::V4, Precision::Fp16, v4::Model, build_v4, v4::Bundle::), (ModelVersion::V5, Precision::Fp16, v5::Model, build_v5, v5::Bundle::), @@ -414,26 +473,25 @@ async fn load_runtime( let reader = SliceReader::new(&data); let mut deserializer = Deserializer::new(reader); - let context = context.clone(); - let reload = reload.clone(); - macro_rules! match_prefab { (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $bundle:ty)),+ }) => { match ($v, $p) { $( ($version, $precision) => { - let seed: Seed<_, $model> = Seed::new(&context); + let seed: Seed<_, $model> = Seed::new(context); let model = seed.deserialize(&mut deserializer)?; - let bundle = <$bundle>::new(model, reload.max_batch); - Runtime::new(context, bundle, reload, states, tokenizer).await + let bundle = <$bundle>::new(model, max_batch); + let state = Arc::new(bundle.state()); + let model = Arc::new(Model(bundle.model())); + let runtime = Arc::new(TokioRuntime::::new(bundle).await); + Ok((states, runtime, state, model)) } )+ - // (version, _) => bail!("unsupported version: {:?}", version) } } } match_prefab!( - (info.version, reload.precision), + (info.version, precision), { (ModelVersion::V4, Precision::Fp16, v4::Model, v4::Bundle::), (ModelVersion::V5, Precision::Fp16, v5::Model, v5::Bundle::), @@ -446,292 +504,140 @@ async fn load_runtime( } ) } - }; - - Ok(runtime) + } } -pub async fn model_route(receiver: Receiver) -> Result<()> { - let env: Arc> = Default::default(); - let queue: Arc>> = Default::default(); +async fn process(env: Arc>, request: ThreadRequest) -> Result<()> { + match request { + ThreadRequest::Adapter(sender) => { + let _ = sender.send(list_adapters()); + } + ThreadRequest::Info(sender) => { + let env = env.read().await; + if let Environment::Loaded { info, .. } = &*env { + let _ = sender.send(info.clone()); + } + } + ThreadRequest::Generate { + request, + tokenizer, + sender, + } => { + let context = GenerateContext::new(*request, sender, &tokenizer).await?; + let env = env.read().await; + if let Environment::Loaded { sender, .. } = &*env { + let _ = sender.send(context); + } + } + ThreadRequest::Reload { request, sender } => { + let handle = tokio::spawn(async move { + let file = File::open(&request.model_path).await?; + let data = unsafe { Mmap::map(&file)? 
}; + let (info, load) = { + let st = SafeTensors::deserialize(&data); + let prefab = cbor4ii::serde::from_slice::(&data); + match (st, prefab) { + (Ok(model), _) => (Loader::info(&model)?, LoadType::SafeTensors), + (_, Ok(prefab)) => (prefab.info, LoadType::Prefab), + _ => bail!("failed to read model info"), + } + }; + log::info!("{:#?}", request); + log::info!("{:#?}", info); + log::info!("model type: {:?}", load); - let sender = { - let (sender, receiver) = flume::unbounded(); - let env = env.clone(); - tokio::spawn(crate::run::run(receiver, env)); - sender - }; + let context = create_context(request.adapter, &info).await?; + log::info!("{:#?}", context.adapter.get_info()); - let dequeue = { - let env = env.clone(); - let queue = queue.clone(); - let sender = sender.clone(); - - async move { - loop { - let mut queue = queue.lock().await; - let mut temp = vec![]; - for context in queue.drain(..) { - temp.append(&mut env.read().await.enqueue(context).await); - let _ = sender.send(()); - } - std::mem::swap(&mut *queue, &mut temp); - drop(queue); + let mut env = env.write().await; + let _ = std::mem::take(&mut *env); - tokio::time::sleep(Duration::from_secs(1)).await; - } - } - }; - tokio::spawn(dequeue); + let tokenizer = Arc::new(load_tokenizer(&request.tokenizer_path).await?); - loop { - let Ok(request) = receiver.recv_async().await else { - log::info!("core exit"); - break Ok(()); - }; + let (states, runtime, state, model) = + load_runtime(&context, &info, &request, load).await?; - let listen = async { - match request { - ThreadRequest::Adapter(sender) => { - tokio::spawn(async move { - let _ = sender.send(list_adapters()); - }); - } - ThreadRequest::Info(sender) => { - let env = env.clone(); - tokio::spawn(async move { - let env = &(*env.read().await); - if let Environment::Loaded(runtime) = env { - let reload = runtime.reload().clone(); - let model = runtime.info().clone(); - let states = runtime.states().await; - let tokenizer = runtime.tokenizer(); - let _ = sender.send(RuntimeInfo { - reload, - model, - states, - tokenizer, - }); - } - }); - } - ThreadRequest::Reload { - request, - sender: reload_sender, - } => { - let request = *request; - let sender = sender.clone(); - let env = env.clone(); - let reload = async move { - let sender = sender.clone(); - - let file = File::open(&request.model_path).await?; - let data = unsafe { Mmap::map(&file)? 
}; - - let (info, load) = { - let st = SafeTensors::deserialize(&data); - let prefab = cbor4ii::serde::from_slice::(&data); - match (st, prefab) { - (Ok(model), _) => (Loader::info(&model)?, LoadType::SafeTensors), - (_, Ok(prefab)) => (prefab.info, LoadType::Prefab), - _ => bail!("failed to read model info"), - } - }; - log::info!("{:#?}", request); - log::info!("{:#?}", info); - log::info!("model type: {:?}", load); - - let context = create_context(request.adapter, &info).await?; - log::info!("{:#?}", context.adapter.get_info()); - - let mut env = env.write().await; - // drop(mem::take(&mut *env)); - 'unload: { - let env = std::mem::take(&mut *env); - let _context = match env { - Environment::Loaded(runtime) => runtime.context().clone(), - Environment::None => break 'unload, - }; - } - - let runtime = load_runtime(&context, &request, info, load).await?; - *env = Environment::Loaded(runtime); - - let _ = sender.send(()); - anyhow::Ok(()) - }; - let callback = move |result: bool| { - if let Some(sender) = reload_sender { - let _ = sender.send(result); - } - }; - tokio::spawn(async move { - match reload.await { - Ok(_) => { - callback(true); - log::info!("model loaded") - } - Err(err) => { - callback(false); - log::error!("load runtime failed: {}", err); - } - }; - }); - } - ThreadRequest::Unload => { - let env = env.clone(); - tokio::spawn(async move { - let mut env = env.write().await; - let env = std::mem::take(&mut *env); - log::info!("runtime unloaded"); - - let _context = match env { - Environment::Loaded(runtime) => runtime.context().clone(), - Environment::None => return, - }; - }); - } - ThreadRequest::StateLoad { request, sender } => { - let env = env.clone(); - let load = async move { - let env = env.read().await; - let Environment::Loaded(runtime) = &*env else { - bail!("runtime not loaded") - }; - - let reload::State { - path, - name, - id, - default, - } = request; - let name = match name { - Some(name) => name, - None => match path.file_name() { - Some(name) => name.to_string_lossy().to_string(), - None => bail!("failed to parse state name"), - }, - }; - let file = File::open(&path).await?; - let data = unsafe { Mmap::map(&file)? 
}; - - let context = runtime.context(); - let info = runtime.info(); - let model = SafeTensors::deserialize(&data)?; - match load_init_state(context, info, model).await { - Ok(data) => { - let state = InitState { - name, - id, - data, - default, - }; - log::info!("{:#?}", state); - runtime.load_init_state(state).await; - } - Err(err) => log::warn!("initial state not loaded: {}", err), - }; - Ok(()) - }; - let callback = move |result: bool| { - if let Some(sender) = sender { - let _ = sender.send(result); - } - }; - tokio::spawn(async move { - match load.await { - Ok(_) => { - callback(true); - log::info!("state loaded") - } - Err(err) => { - callback(false); - log::error!("load state failed: {}", err); - } - }; - }); - } - ThreadRequest::StateUnload(id) => { - let env = env.clone(); - tokio::spawn(async move { - let env = env.read().await; - let Environment::Loaded(runtime) = &*env else { - return; - }; - runtime.unload_init_state(id).await; - }); - } - ThreadRequest::Generate { - request, + let reload = Arc::new(*request); + let info = RuntimeInfo { + reload, + info, + states, tokenizer, - sender: token_sender, - } => { - let request = *request; - let tokens = Tokens(tokenizer.encode(request.prompt.as_bytes())?); - let model_tokens = Tokens(tokenizer.encode(request.model_text.as_bytes())?); - // init sampler state here - request.sampler.write().await.init(&model_tokens); - - let choices = match &request.kind { - GenerateKind::Choose { choices, .. } => { - let choices: Vec<_> = choices - .iter() - .map(|prompt| tokenizer.encode(prompt.as_bytes())) - .try_collect()?; - choices.into_iter().map(Tokens).collect() - } - _ => vec![], - }; - - let context = GenerateContext { - prompt_tokens: tokens.to_vec(), - prompt_cached: Default::default(), - prefix: Default::default(), - suffix: tokens, - output: None, - choices, - model_text: vec![], - buffer: vec![], - model_tokens: vec![], - formatters: vec![], - instant: None, - request, - sender: token_sender, - }; - - let env = env.clone(); - let queue = queue.clone(); - let sender = sender.clone(); - tokio::spawn(async move { - let context = &mut env.read().await.enqueue(context).await; - let mut queue = queue.lock().await; - queue.append(context); - let _ = sender.send(()); - }); - } - ThreadRequest::Save { request, sender } => { - let env = env.clone(); - tokio::spawn(async move { - let env = &(*env.read().await); - if let Environment::Loaded(runtime) = env { - log::info!("serializing model into {:?}", &request.path); - let _ = match runtime.serialize_model(request.path).await { - Ok(()) => sender.send(true), - Err(err) => { - log::error!("{}", err); - sender.send(false) - } - }; - } - }); - } - }; - anyhow::Ok(()) - }; + }; + + let sender = { + let runtime = Arc::downgrade(&runtime); + let (sender, receiver) = flume::unbounded(); + tokio::spawn(crate::run::run( + context, + runtime, + state, + receiver, + info.clone(), + )); + sender + }; - if let Err(err) = listen.await { - log::error!("{err}"); + log::info!("model loaded"); + + let _ = std::mem::replace( + &mut *env, + Environment::Loaded { + info, + runtime, + model, + sender, + }, + ); + Ok(()) + }); + + if let Some(sender) = sender { + let _ = match handle.await? { + Ok(_) => sender.send(true), + Err(err) => { + log::error!("[reload] error: {err}"); + sender.send(false) + } + }; + } + } + ThreadRequest::Unload => { + let mut env = env.write().await; + let _ = std::mem::take(&mut *env); + log::info!("model unloaded"); + } + ThreadRequest::StateLoad { .. 
} => log::error!("[state] method unimplemented"), + ThreadRequest::StateUnload(_) => log::error!("[state] method unimplemented"), + ThreadRequest::Save { request, sender } => { + let env = env.read().await; + if let Environment::Loaded { model, .. } = &*env { + log::info!("serializing model into {:?}", &request.path); + let model = model.clone(); + let handle = tokio::task::spawn_blocking(move || { + let file = std::fs::File::create(request.path)?; + model.serialize(file) + }); + drop(env); + + let _ = match handle.await? { + Ok(_) => sender.send(true), + Err(err) => { + log::error!("[save] error: {err}"); + sender.send(false) + } + }; + } } + }; + Ok(()) +} + +pub async fn serve(receiver: Receiver) { + let env: Arc> = Default::default(); + while let Ok(request) = receiver.recv_async().await { + let future = process(env.clone(), request); + tokio::spawn(future); } } diff --git a/crates/ai00-core/src/reload.rs b/crates/ai00-core/src/reload.rs index 85d00571..65ec153e 100644 --- a/crates/ai00-core/src/reload.rs +++ b/crates/ai00-core/src/reload.rs @@ -5,7 +5,7 @@ use salvo::oapi::ToSchema; use serde::{Deserialize, Serialize}; use web_rwkv::runtime::model::{EmbedDevice, Quant}; -use crate::run::StateId; +use crate::StateId; #[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)] #[derivative(Default)] diff --git a/crates/ai00-core/src/run.rs b/crates/ai00-core/src/run.rs index 6e7ffcd7..a52e58c6 100644 --- a/crates/ai00-core/src/run.rs +++ b/crates/ai00-core/src/run.rs @@ -1,134 +1,42 @@ use std::{ cmp::Ordering, collections::HashMap, + error::Error, ops::Deref, - path::PathBuf, - sync::Arc, - time::{Duration, Instant}, + sync::{Arc, Weak}, + time::Duration, }; use anyhow::Result; use derivative::Derivative; -use flume::{Receiver, Sender}; +use flume::{Receiver, Sender, TryRecvError}; use itertools::Itertools; use qp_trie::Trie; -use salvo::oapi::ToSchema; -use serde::{Deserialize, Serialize}; -use tokio::sync::{Mutex, RwLock}; +use tokio::{ + sync::{Mutex, RwLock}, + task::JoinHandle, + time::Instant, +}; use web_rwkv::{ context::Context, runtime::{ - infer::{InferInfo, InferInput, InferInputBatch, InferOption, InferOutput}, - model::{Bundle, ModelInfo, State}, - softmax::softmax, - Dispatcher, Job, TokioRuntime, + infer::{InferInput, InferInputBatch, InferOption, InferOutputBatch}, + model::{ModelInfo, State}, + Runtime, }, - tensor::{TensorCpu, TensorInit}, + tensor::{TensorCpu, TensorError}, tokenizer::Tokenizer, }; use crate::{ - sampler::{bnf::BnfSampler, Formatter}, - Environment, FinishReason, GenerateKind, GenerateRequest, ReloadRequest, Token, TokenCounter, + sampler::{bnf::BnfSampler, Formatter, Sampler}, + FinishReason, GenerateKind, GenerateRequest, InitState, ReloadRequest, RuntimeInfo, StateId, + Token, TokenCounter, }; const MIN_PROMPT_CACHE_TOKENS: usize = 32; const MAX_CACHE_ITEMS: usize = 256; -#[derive(Debug)] -pub enum SlotResult { - /// There is an idle slot ready to be picked up. - Success(usize), - /// An idle slot is swapped. - Fault(usize), - /// There is no idle slot left. - Failure(Box), - /// An error occurred. - Error(String), -} - -#[derive(Debug)] -enum SlotState { - /// The slot might be either picked up or swapped. - Idle(Tokens, Instant), - /// The slot is locked and is waiting for processing. - Wait(Box), - /// The slot is currently under processing. 
- Busy, -} - -impl Default for SlotState { - fn default() -> Self { - Self::Idle(Default::default(), Instant::now()) - } -} - -#[derive(Debug, PartialEq, Eq)] -enum SlotChoice { - Continue(usize, usize), - Back(usize), - Empty(usize), -} - -impl std::cmp::Ord for SlotChoice { - fn cmp(&self, other: &Self) -> Ordering { - // priority: continue > empty > back - use SlotChoice::{Back, Continue, Empty}; - match (self, other) { - (Continue(_, x), Continue(_, y)) => x.cmp(y), - (Continue(_, _), _) => Ordering::Greater, - (_, Continue(_, _)) => Ordering::Less, - (Empty(_), Empty(_)) => Ordering::Equal, - (Empty(_), Back(_)) => Ordering::Greater, - (Back(_), Empty(_)) => Ordering::Less, - (Back(_), Back(_)) => Ordering::Equal, - } - } -} - -impl std::cmp::PartialOrd for SlotChoice { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -#[derive(Debug, Clone, Default)] -pub enum Payload { - #[default] - Empty, - Busy(GenerateContext), - Done(GenerateContext), -} - -impl Payload { - /// Takes out the value if `self` is [`Payload::Done`], and reset `self` to [`Payload::Empty`]. - pub fn take(&mut self) -> Option { - match std::mem::take(self) { - Payload::Done(context) => Some(context), - payload => { - *self = payload; - None - } - } - } - - /// Set `self` to [`Payload::Done`] if `self` is [`Payload::Busy`]. - pub fn finalize(&mut self) { - *self = match std::mem::take(self) { - Payload::Busy(context) => Payload::Done(context), - payload => payload, - } - } - - /// Returns `true` if the payload is [`Empty`]. - /// - /// [`Empty`]: Payload::Empty - #[must_use] - pub fn is_empty(&self) -> bool { - matches!(self, Self::Empty) - } -} - #[repr(transparent)] #[derive(Debug, Default, Clone)] pub struct Tokens(pub Vec); @@ -237,6 +145,46 @@ pub struct GenerateContext { pub sender: Sender, } +impl GenerateContext { + pub async fn new( + request: GenerateRequest, + sender: Sender, + tokenizer: &Tokenizer, + ) -> Result { + let tokens = Tokens(tokenizer.encode(request.prompt.as_bytes())?); + let model_tokens = Tokens(tokenizer.encode(request.model_text.as_bytes())?); + + // init sampler state here + request.sampler.write().await.init(&model_tokens); + + let choices = match &request.kind { + GenerateKind::Choose { choices, .. } => { + let choices: Vec<_> = choices + .iter() + .map(|prompt| tokenizer.encode(prompt.as_bytes())) + .try_collect()?; + choices.into_iter().map(Tokens).collect() + } + _ => Vec::new(), + }; + Ok(Self { + prompt_tokens: tokens.to_vec(), + prompt_cached: Default::default(), + prefix: Default::default(), + suffix: tokens, + output: None, + choices, + model_text: Vec::new(), + buffer: Vec::new(), + model_tokens: Vec::new(), + formatters: Vec::new(), + instant: None, + request, + sender, + }) + } +} + #[derive(Debug, Default, Clone)] pub enum CachedPrompt { #[default] @@ -321,201 +269,109 @@ impl CacheHub { } } -#[derive( - Derivative, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, ToSchema, -)] -#[derivative(Debug = "transparent")] -#[serde(transparent)] -pub struct StateId(uuid::Uuid); - -impl StateId { - pub fn new() -> Self { - Self(uuid::Uuid::new_v4()) - } +/// The result of trying to queuing a task. +#[derive(Debug)] +enum SlotResult { + /// There is an idle slot ready to be picked up. + Success(usize), + /// An idle slot is swapped. + Fault(usize), + /// There is no idle slot left. + Failure(Box), + /// An error occurred. 
+    Error(Box<dyn Error>),
 }
 
-#[derive(Derivative, Clone)]
-#[derivative(Debug)]
-pub struct InitState {
-    pub name: String,
-    pub id: StateId,
-    pub default: bool,
-    #[derivative(Debug = "ignore")]
-    pub data: TensorCpu<f32>,
+#[derive(Debug)]
+enum SlotState {
+    /// The slot might be either picked up or swapped.
+    Idle(Tokens, Instant),
+    /// The slot is currently under processing.
+    Busy(JoinHandle<Result<GenerateContext>>),
+    /// The slot is locked for updating.
+    Locked,
 }
 
-struct Model<M>(M);
-
-trait ModelSerialize {
-    fn serialize(&self, file: std::fs::File) -> Result<()>;
+impl Default for SlotState {
+    fn default() -> Self {
+        Self::Idle(Default::default(), Instant::now())
+    }
 }
 
-impl<M: Serialize> ModelSerialize for Model<M> {
-    fn serialize(&self, file: std::fs::File) -> Result<()> {
-        use cbor4ii::{core::enc::Write, serde::Serializer};
-        use std::{fs::File, io::Write as _};
+#[derive(Debug, PartialEq, Eq)]
+enum SlotChoice {
+    Continue(usize, usize),
+    Back(usize),
+    Empty(usize),
+}
 
-        struct FileWriter(File);
-        impl Write for FileWriter {
-            type Error = std::io::Error;
-            fn push(&mut self, input: &[u8]) -> Result<(), Self::Error> {
-                self.0.write_all(input)
-            }
+impl std::cmp::Ord for SlotChoice {
+    fn cmp(&self, other: &Self) -> Ordering {
+        // priority: continue > empty > back
+        use SlotChoice::{Back, Continue, Empty};
+        match (self, other) {
+            (Continue(_, x), Continue(_, y)) => x.cmp(y),
+            (Continue(_, _), _) => Ordering::Greater,
+            (_, Continue(_, _)) => Ordering::Less,
+            (Empty(_), Empty(_)) => Ordering::Equal,
+            (Empty(_), Back(_)) => Ordering::Greater,
+            (Back(_), Empty(_)) => Ordering::Less,
+            (Back(_), Back(_)) => Ordering::Equal,
         }
-
-        let file = FileWriter(file);
-        let mut serializer = Serializer::new(file);
-        self.0.serialize(&mut serializer)?;
-
-        Ok(())
     }
 }
 
-impl Environment {
-    pub async fn enqueue(&self, context: GenerateContext) -> Vec<GenerateContext> {
-        let mut queue = vec![];
-        match self {
-            Environment::Loaded(runtime) => {
-                match runtime.queue(context).await.expect("queue task error") {
-                    SlotResult::Success(batch) => log::info!("queued task at slot {batch}"),
-                    SlotResult::Fault(batch) => log::info!("swapped task at slot {batch}"),
-                    SlotResult::Failure(context) => queue.push(*context),
-                    SlotResult::Error(reason) => log::warn!("queue task failed: {}", reason),
-                }
-            }
-            Environment::None => queue.push(context),
-        };
-        queue
+impl std::cmp::PartialOrd for SlotChoice {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
     }
 }
 
-pub struct Runtime {
+#[derive(Debug, Clone)]
+enum InferBatch {
+    Run {
+        batch: usize,
+        tokens: Vec<u16>,
+        option: InferOption,
+        sender: Sender<TensorCpu<f32>>,
+    },
+    Load {
+        batch: usize,
+        tensor: TensorCpu<f32>,
+    },
+    Back {
+        batch: usize,
+        sender: Sender<TensorCpu<f32>>,
+    },
+}
+
+#[derive(Debug, Clone)]
+struct SoftmaxBatch {
+    input: TensorCpu<f32>,
+    sender: Sender<TensorCpu<f32>>,
+}
+
+#[derive(Debug, Clone)]
+struct RuntimeSender {
+    infer: Sender<InferBatch>,
+    softmax: Sender<SoftmaxBatch>,
+}
+
+#[derive(Derivative, Clone)]
+#[derivative(Debug)]
+struct CoreRuntime {
     context: Context,
-    reload: ReloadRequest,
     info: ModelInfo,
+    reload: Arc<ReloadRequest>,
+    #[derivative(Debug = "ignore")]
     state: Arc<dyn State + Send + Sync>,
-    model: Arc<dyn ModelSerialize + Send + Sync>,
-    runtime: TokioRuntime<InferInput, InferOutput>,
+    sender: RuntimeSender,
     tokenizer: Arc<Tokenizer>,
-    slots: Mutex<Vec<SlotState>>,
-    caches: Mutex<CacheHub>,
+    slots: Arc<Mutex<Vec<SlotState>>>,
+    caches: Arc<Mutex<CacheHub>>,
 }
 
-impl Runtime {
-    pub async fn new<J, B>(
-        context: Context,
-        bundle: B,
-        reload: ReloadRequest,
-        states: Vec<InitState>,
-        tokenizer: Tokenizer,
-    ) -> Self
-    where
-        J: Job + Send + 'static,
-        B: Dispatcher<J> + Bundle + Clone + Send + 'static,
-    {
-        let slots = (0..reload.max_batch)
-            .map(|_| SlotState::default())
-            .collect();
-
let mut caches = CacheHub::default(); - - // set up default initial state - if let Some(state) = states.iter().find(|state| state.default) { - caches.default.state = Some(state.clone()); - } - for state in states { - let id = state.id; - let item = Cache { - state: Some(state), - cache: Trie::new(), - }; - caches.backed.insert(id, item); - } - - let info = bundle.info(); - let state = Arc::new(bundle.state()); - let model = Arc::new(Model(bundle.model())); - let runtime = TokioRuntime::::new(bundle).await; - - Self { - context, - reload, - info, - state, - model, - runtime, - tokenizer: Arc::new(tokenizer), - slots: Mutex::new(slots), - caches: Mutex::new(caches), - } - } - - #[inline] - pub fn context(&self) -> &Context { - &self.context - } - - #[inline] - pub fn reload(&self) -> &ReloadRequest { - &self.reload - } - - #[inline] - pub fn info(&self) -> &ModelInfo { - &self.info - } - - #[inline] - pub fn num_batch(&self) -> usize { - self.state.num_batch() - } - - #[inline] - pub fn tokenizer(&self) -> Arc { - self.tokenizer.clone() - } - - pub async fn states(&self) -> Vec<(StateId, InitState)> { - let caches = self.caches.lock().await; - let mut states = vec![]; - - if let Some(state) = &caches.default.state { - states.push((state.id, state.clone())); - } - for item in caches.backed.values() { - if let Some(state) = &item.state { - states.push((state.id, state.clone())); - } - } - - states - } - - pub async fn load_init_state(&self, state: InitState) { - let mut caches = self.caches.lock().await; - caches.backed.insert( - state.id, - Cache { - state: Some(state), - cache: Trie::new(), - }, - ); - } - - pub async fn unload_init_state(&self, id: StateId) { - let mut caches = self.caches.lock().await; - caches.backed.remove(&id); - } - - pub async fn serialize_model(&self, path: PathBuf) -> Result<()> { - let model = self.model.clone(); - let handle = tokio::task::spawn_blocking(move || { - let file = std::fs::File::create(path)?; - model.serialize(file) - }); - handle.await? - } - +impl CoreRuntime { /// Search for the longest common prefix in the memory cache and checkout the state from that point. /// Should there be a cache miss, an initial state is returned. async fn checkout(&self, id: StateId, tokens: &[u16]) -> CacheCheckout { @@ -562,26 +418,19 @@ impl Runtime { } } - /// Compile and cache the given schema into a BNF sampler. - async fn compile_bnf_schema(&self, schema: String) -> Result { - BnfSampler::new(&self.tokenizer, &schema) - } - /// Queue an inference task. - pub async fn queue(&self, context: GenerateContext) -> Result { - let mut slots = self.slots.lock().await; - - let mut tokens = [context.prefix, context.suffix].concat(); - if tokens.is_empty() { - tokens.push(0); - } + async fn queue(&self, context: GenerateContext) -> Result { + let tokens = match [context.prefix, context.suffix].concat() { + tokens if tokens.is_empty() => vec![0u16], + tokens => tokens, + }; // compile the BNF schema. let mut formatters = Vec::>>::new(); if let Some(schema) = context.request.bnf_schema.clone() { - match self.compile_bnf_schema(schema).await { + match BnfSampler::new(&self.tokenizer, &schema) { Ok(bnf) => formatters.push(Arc::new(RwLock::new(bnf))), - Err(err) => return Ok(SlotResult::Error(err.to_string())), + Err(err) => return Ok(SlotResult::Error(err.into())), } } @@ -589,22 +438,32 @@ impl Runtime { // 1. find the slot that matches the context (continue) // 2. find an empty slot // 3. 
find the oldest non-empty slot
-        let choice = slots
-            .iter()
-            .enumerate()
-            .filter_map(|(batch, slot)| match slot {
-                SlotState::Idle(content, time) => {
-                    let delta = time.elapsed().as_millis();
-                    match (content.is_empty(), tokens.starts_with(content)) {
-                        (true, _) => Some((SlotChoice::Empty(batch), delta)),
-                        (false, true) => Some((SlotChoice::Continue(batch, content.len()), delta)),
-                        (false, false) => Some((SlotChoice::Back(batch), delta)),
+        let choice = {
+            let mut slots = self.slots.lock().await;
+            let choice = slots
+                .iter()
+                .enumerate()
+                .filter_map(|(batch, slot)| match slot {
+                    SlotState::Idle(content, instant) => {
+                        let delta = instant.elapsed();
+                        match (content.is_empty(), tokens.starts_with(content)) {
+                            (true, _) => Some((SlotChoice::Empty(batch), delta)),
+                            (_, true) => Some((SlotChoice::Continue(batch, content.len()), delta)),
+                            (_, false) => Some((SlotChoice::Back(batch), delta)),
+                        }
                     }
-                }
-                _ => None,
-            })
-            .max_by(|lhs, rhs| lhs.0.cmp(&rhs.0).then(lhs.1.cmp(&rhs.1)))
-            .map(|(x, _)| x);
+                    _ => None,
+                })
+                .max_by(|lhs, rhs| lhs.0.cmp(&rhs.0).then(lhs.1.cmp(&rhs.1)))
+                .map(|(x, _)| x);
+            match choice {
+                None => (),
+                Some(SlotChoice::Empty(batch))
+                | Some(SlotChoice::Back(batch))
+                | Some(SlotChoice::Continue(batch, _)) => slots[batch] = SlotState::Locked,
+            }
+            choice
+        };
 
         match choice {
             // we cannot find a slot because all slots are occupied
@@ -620,352 +479,301 @@
             )),
             // back a non-relative and non-empty slot and use it for our new context
             Some(SlotChoice::Back(batch)) => {
-                log::info!("start at non-empty slot {}", batch);
+                log::info!("[queue][back][slot: {batch}]");
                 let checkout = self.checkout(context.request.state, &tokens).await;
-                self.state.load(checkout.state, batch)?;
+                self.load(batch, checkout.state);
+                // self.state.load(checkout.state, batch)?;
 
                 let len = checkout.prefix.len();
                 assert!(len == 0 || (len > 0 && checkout.output.is_some()));
-                log::info!("slot {} checks out cache of length {}", batch, len);
-
-                let mut state = SlotState::Wait(
-                    GenerateContext {
-                        prefix: Tokens(tokens[..len].to_vec()),
-                        suffix: Tokens(tokens[len..].to_vec()),
-                        output: checkout.output,
-                        formatters,
-                        ..context
-                    }
-                    .into(),
-                );
+                log::info!("[cache][checkout][slot: {batch}][len: {len}]");
 
-                std::mem::swap(&mut state, &mut slots[batch]);
+                let context = GenerateContext {
+                    prefix: Tokens(tokens[..len].to_vec()),
+                    suffix: Tokens(tokens[len..].to_vec()),
+                    output: checkout.output,
+                    formatters,
+                    ..context
+                };
+                let handle = tokio::spawn(self.clone().process(batch, context));
+                let mut slots = self.slots.lock().await;
+                slots[batch] = SlotState::Busy(handle);
                 Ok(SlotResult::Fault(batch))
             }
             // directly occupy an empty slot so no need backing
             Some(SlotChoice::Empty(batch)) => {
-                log::info!("start at empty slot {}", batch);
+                log::info!("[queue][empty][slot: {batch}]");
                 let checkout = self.checkout(context.request.state, &tokens).await;
-                self.state.load(checkout.state, batch)?;
+                self.load(batch, checkout.state);
+                // self.state.load(checkout.state, batch)?;
 
                 let len = checkout.prefix.len();
                 assert!(len == 0 || (len > 0 && checkout.output.is_some()));
-                log::info!("slot {} checks out cache of length {}", batch, len);
-
-                let state = SlotState::Wait(
-                    GenerateContext {
-                        prefix: Tokens(tokens[..len].to_vec()),
-                        suffix: Tokens(tokens[len..].to_vec()),
-                        output: checkout.output,
-                        formatters,
-                        ..context
-                    }
-                    .into(),
-                );
-                slots[batch] = state;
+                log::info!("[cache][checkout][slot: {batch}][len: {len}]");
+
+                let context = GenerateContext {
+                    prefix: Tokens(tokens[..len].to_vec()),
+                    suffix: Tokens(tokens[len..].to_vec()),
+                    output: checkout.output,
+                    formatters,
+                    ..context
+                };
+                let handle = tokio::spawn(self.clone().process(batch, context));
+                let mut slots = self.slots.lock().await;
+                slots[batch] = SlotState::Busy(handle);
                 Ok(SlotResult::Success(batch))
             }
-            // continue from an existing slot
             Some(SlotChoice::Continue(batch, ..)) => {
-                log::info!("continue at slot {}", batch);
+                log::info!("[queue][continue][slot: {batch}]");
                 let checkout = self.checkout(context.request.state, &tokens).await;
-                self.state.load(checkout.state, batch)?;
+                self.load(batch, checkout.state);
+                // self.state.load(checkout.state, batch)?;
 
                 let len = checkout.prefix.len();
                 assert!(len == 0 || (len > 0 && checkout.output.is_some()));
-                log::info!("slot {} checks out cache of length {}", batch, len);
-
-                let state = SlotState::Wait(
-                    GenerateContext {
-                        prefix: Tokens(tokens[..len].to_vec()),
-                        suffix: Tokens(tokens[len..].to_vec()),
-                        output: checkout.output,
-                        formatters,
-                        ..context
-                    }
-                    .into(),
-                );
-                slots[batch] = state;
+                log::info!("[cache][checkout][slot: {batch}][len: {len}]");
+
+                let context = GenerateContext {
+                    prefix: Tokens(tokens[..len].to_vec()),
+                    suffix: Tokens(tokens[len..].to_vec()),
+                    output: checkout.output,
+                    formatters,
+                    ..context
+                };
+                let handle = tokio::spawn(self.clone().process(batch, context));
+                let mut slots = self.slots.lock().await;
+                slots[batch] = SlotState::Busy(handle);
                 Ok(SlotResult::Success(batch))
             }
         }
     }
 
-    /// This critical section synchronizes `slots` and fills `payloads`.
-    async fn synchronize(&self, payloads: &mut [Payload]) -> Result<()> {
-        let mut slots = self.slots.lock().await;
-
-        // synchronize payloads and slots: kill dead payloads
-        for (slot, payload) in slots.iter().zip(payloads.iter_mut()) {
-            if !(payload.is_empty() || matches!(slot, SlotState::Busy)) {
-                log::warn!("payload should either be empty or slot should be busy");
-                *payload = Payload::Empty;
+    /// Reset finished slots to `idle`. Cache current states of finished slots.
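+    /// A busy slot is taken out as `Locked` while its task handle is polled;
+    /// once the task finishes, the recovered `GenerateContext` turns the slot
+    /// back into `Idle`, keyed by the prefix it has accumulated so far.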
+ async fn update(&self) { + let update = |handle: JoinHandle<_>| async move { + if !handle.is_finished() { + return Ok(SlotState::Busy(handle)); } - } - // reset all finished slots to idle - for (batch, payload) in payloads.iter_mut().enumerate() { - let Some(context) = payload.take() else { - continue; - }; + let context = handle.await??; + Ok::<_, Box>(SlotState::Idle(context.prefix, Instant::now())) + }; - let backed = self.state.back(batch).await?; - if let GenerateKind::Embed { layer } = context.request.kind { - let backed = backed.clone(); - let embed = match layer { - x if x < self.info.num_layer => self.state.embed(layer, backed)?.to_vec(), - _ => backed.to_vec(), + for batch in 0..self.reload.max_batch { + let handle = { + let mut slots = self.slots.lock().await; + let slot = std::mem::replace(&mut slots[batch], SlotState::Locked); + let SlotState::Busy(handle) = slot else { + slots[batch] = slot; + continue; }; - let _ = context.sender.send(Token::Embed(embed)); - } - - if let Some(output) = context.output { - let mut caches = self.caches.lock().await; - let cache = &mut caches.fetch(context.request.state).cache; - let item = CachedItem::new(backed, output); - let (item, _) = tokio::sync::watch::channel(Some(item)); - cache.insert(context.prefix.clone(), item); - log::info!( - "backed completed slot {} of length {}", - batch, - context.prefix.len() - ); - } - - assert!(matches!(slots[batch], SlotState::Busy)); - slots[batch] = SlotState::Idle(context.prefix, Instant::now()); - } - - // take data from some pending slots - let occupancy = payloads - .iter() - .filter(|x| matches!(x, Payload::Busy(_))) - .count(); - let remain = self.reload.max_batch - self.reload.max_batch.min(occupancy); - let batches = slots - .iter() - .enumerate() - .filter(|(_, slot)| matches!(slot, SlotState::Wait(_))) - .take(remain) - .map(|(batch, _)| batch) - .collect_vec(); - - for batch in batches { - let mut slot = SlotState::Busy; - std::mem::swap(&mut slots[batch], &mut slot); - - let SlotState::Wait(context) = slot else { - unreachable!() + handle }; - let mut context = *context; - // allocate a future cache slot - let mut caches = self.caches.lock().await; - let cache = &mut caches.fetch(context.request.state).cache; - - let enable = context.prompt_tokens.len() > MIN_PROMPT_CACHE_TOKENS; - let enable = enable && !cache.contains_key(context.prompt_tokens.as_token_slice()); - if enable { - let (sender, _) = tokio::sync::watch::channel(None); - context.prompt_cached = CachedPrompt::Future(sender.clone()); - cache.insert(Tokens(context.prompt_tokens.clone()), sender); - - log::info!( - "slot {} schedules future back of length {}", - batch, - context.prompt_tokens.len() - ); - } + let updated = match update(handle).await { + Ok(updated) => updated, + Err(err) => { + log::error!("[update][error][slot: {batch}] {err}"); + let mut slots = self.slots.lock().await; + slots[batch] = Default::default(); + continue; + } + }; - let _ = context.sender.send(Token::Start); - assert!(matches!(payloads[batch], Payload::Empty)); - payloads[batch] = Payload::Busy(context); + let mut slots = self.slots.lock().await; + slots[batch] = updated; } - - Ok(()) } - async fn sample(&self, payloads: &mut [Payload]) -> Result)>> { - // update raw outputs - let mut set = tokio::task::JoinSet::new(); - for (batch, payload) in payloads.iter().enumerate() { - let Payload::Busy(context) = payload else { - continue; - }; - - // in case that we have not yet read the whole prompt but still gets the output (from the cache) - if 
!context.suffix.is_empty() { - continue; + async fn sample( + &self, + output: TensorCpu, + sampler: Arc>, + formatters: Vec>>, + bias: Arc>, + ) -> Result<(u16, TensorCpu)> { + // process raw model outputs + let num_vocab = self.info.num_vocab; + let input = { + let mut data = output.to_vec(); + assert_eq!(data.len(), num_vocab); + + sampler.read().await.transform(&mut data); + for formatter in formatters { + formatter.read().await.transform(&mut data); + } + for (token, bias) in bias.iter() { + data[*token as usize] += *bias; } - let Some(output) = context.output.clone() else { - continue; - }; - - let num_vocab = self.info.num_vocab; - let formatters = context.formatters.clone(); - let sampler = context.request.sampler.clone(); - let bias = context.request.bias.clone(); - set.spawn(async move { - let mut data = output.to_vec(); - assert_eq!(data.len(), num_vocab); - - sampler.read().await.transform(&mut data); - for (token, bias) in bias.iter() { - data[*token as usize] += *bias; - } - for formatter in formatters { - formatter.read().await.transform(&mut data); - } - - (batch, data) - }); - } - let mut outputs = HashMap::new(); - while let Some(Ok((batch, data))) = set.join_next().await { - outputs.insert(batch, data); - } - let outputs = (0..payloads.len()) - .map(|batch| outputs.remove(&batch)) - .map(|data| match data { - Some(data) => TensorCpu::from_data([self.info.num_vocab, 1, 1, 1], data), - None => TensorCpu::from_data([self.info.num_vocab, 0, 1, 1], vec![]), - }) - .try_collect()?; + self.context.tensor_from_data([num_vocab, 1, 1, 1], data)? + }; // compute probabilities - let outputs = softmax(&self.context, outputs).await?; + let (sender, receiver) = flume::unbounded(); + let _ = self.sender.softmax.send(SoftmaxBatch { input, sender }); + let output = receiver.recv_async().await?; // sample tokens - let mut set = tokio::task::JoinSet::new(); - for (batch, (payload, output)) in payloads.iter_mut().zip(outputs.into_iter()).enumerate() { - let Payload::Busy(context) = payload else { - continue; - }; - - if output.is_empty() { - continue; - } - - let num_vocab = self.info.num_vocab; - let sampler = context.request.sampler.clone(); - set.spawn(async move { - let data = output.to_vec(); - assert_eq!(data.len(), num_vocab); - let token = sampler.write().await.sample(&data); - (batch, token, data) - }); - } - let mut tokens = HashMap::new(); - while let Some(Ok((batch, token, data))) = set.join_next().await { - tokens.insert(batch, (token, data)); - } - - Ok(tokens) + assert_eq!(output.len(), num_vocab); + let token = sampler.write().await.sample(&output); + Ok((token, output)) } - async fn compute_perplexities( - &self, - tokens: &Tokens, - batch: usize, - head: Option, - ) -> Result { - let mut probabilities = Vec::with_capacity(tokens.len()); + async fn perplexity(&self, batch: usize, tokens: &[u16], head: Option) -> Result { + let mut p = Vec::with_capacity(tokens.len().max(1)); + let len = tokens.len(); let tokens = match head { Some(head) => { - probabilities.push(head); - tokens.0.clone() + p.push(head); + tokens.to_vec() } - None => [vec![0], tokens.0.clone()].concat(), - }; - - // construct an inference session with only one batch - let mut batches = vec![InferInputBatch::default(); self.num_batch()]; - batches[batch] = InferInputBatch { - tokens: tokens.clone(), - option: InferOption::Full, + None => [&[0], tokens].concat(), }; - let inference = InferInput::new(batches, self.reload.token_chunk_size); - let mut inference = Some(inference); - let mut index = 1; - loop { - 
let input = inference.take().unwrap(); - if input.batches[batch].tokens.is_empty() { - break; - } - let (input, InferOutput(output)) = self.runtime.infer(input).await?; - inference.replace(input); - - let output = output[batch].0.clone().split(1)?; - for data in output { - if index < tokens.len() { - let data = data.map(|x| x.exp()).to_vec(); - let sum: f32 = data.iter().sum(); - let token = tokens[index] as usize; - probabilities.push(data[token] / sum); + let (sender, receiver) = flume::unbounded(); + let _ = self + .sender + .infer + .send_async({ + let tokens = tokens.clone(); + let option = InferOption::Full; + InferBatch::Run { + batch, + tokens, + option, + sender, } - index += 1; - } + }) + .await; + + let index = Arc::new(Mutex::new(1)); + while p.len() < len { + let tokens = tokens.clone(); + let output = receiver.recv_async().await?; + let output = output.split(1)?; + let f = { + let index = index.clone(); + move || { + let mut index = index.blocking_lock(); + let mut p = Vec::with_capacity(output.len()); + for data in output { + if *index < tokens.len() { + let data = data.map(|x| x.exp()).to_vec(); + let sum: f32 = data.iter().sum(); + let token = tokens[*index] as usize; + p.push(data[token] / sum); + } + *index += 1; + } + p + } + }; + let mut q = tokio::task::spawn_blocking(f).await?; + p.append(&mut q); } - let perplexity: f32 = probabilities.into_iter().map(|x| x.ln()).sum::(); - let perplexity = -perplexity / tokens.len() as f32; - Ok(perplexity) + let ppl: f32 = p.into_iter().map(|x| x.ln()).sum(); + let ppl = -ppl / tokens.len() as f32; + Ok(ppl) } - async fn finalize( - &self, - payloads: &mut [Payload], - tokens: HashMap)>, - ) -> Result<()> { - for (batch, payload) in payloads.iter_mut().enumerate() { - let Payload::Busy(context) = payload else { - continue; - }; + fn load(&self, batch: usize, tensor: TensorCpu) { + let _ = self.sender.infer.send(InferBatch::Load { batch, tensor }); + } - // in case that we have not yet read the whole prompt but still gets the output (from the cache) - if !context.suffix.is_empty() { - continue; - } + async fn back(&self, batch: usize) -> Result> { + let (sender, receiver) = flume::unbounded(); + let _ = self.sender.infer.send(InferBatch::Back { batch, sender }); + let backed = receiver.recv_async().await?; + Ok(backed) + } - let Some((token, data)) = tokens.get(&batch) else { - continue; - }; + /// Read in the prompt of a batch and continuously sample it until it is done. + async fn process(self, batch: usize, mut context: GenerateContext) -> Result { + // schedule a future cache slot for the prompt + { + let mut caches = self.caches.lock().await; + let cache = &mut caches.fetch(context.request.state).cache; - // cache the prompt if it is too long. 
- if let (CachedPrompt::Future(sender), Some(output)) = - (context.prompt_cached.clone(), context.output.clone()) - { - assert_eq!(context.prefix.len(), context.prompt_tokens.len()); - let backed = self.state.back(batch).await?; - sender.send_replace(Some(CachedItem::new(backed, output))); - context.prompt_cached = CachedPrompt::Done; + let enable = context.prompt_tokens.len() > MIN_PROMPT_CACHE_TOKENS; + let enable = enable && !cache.contains_key(context.prompt_tokens.as_token_slice()); + if enable { + let (sender, _) = tokio::sync::watch::channel(None); + context.prompt_cached = CachedPrompt::Future(sender.clone()); + cache.insert(Tokens(context.prompt_tokens.clone()), sender); - log::info!( - "backed prompt of slot {} of length {}", - batch, - context.prefix.len() - ); + let len = context.prompt_tokens.len(); + log::info!("[cache][future][slot: {batch}][len: {len}]"); } + } - let token = *token; - let mut stop_token = token == 0; + let _ = context.sender.send(Token::Start); - assert_eq!(context.suffix.len(), 0); - context.suffix.0.push(token); + loop { + let output = match (context.suffix.len(), context.output.clone()) { + (0, Some(output)) => output, + _ => { + let (sender, receiver) = flume::unbounded(); + let _ = self + .sender + .infer + .send_async(InferBatch::Run { + batch, + tokens: context.suffix.to_vec(), + option: InferOption::Last, + sender, + }) + .await; + + let prefix = std::mem::take(&mut context.prefix); + let suffix = std::mem::take(&mut context.suffix); + + context.prefix = Tokens([prefix.0, suffix.0].concat()); + context.suffix = Tokens(vec![]); + + let output = receiver.recv_async().await?; + + // cache the prompt if being asked + if let CachedPrompt::Future(sender) = context.prompt_cached.clone() { + assert_eq!(context.prefix.len(), context.prompt_tokens.len()); + + let backed = self.back(batch).await?; + let output = output.clone(); + sender.send_replace(Some(CachedItem::new(backed, output))); + context.prompt_cached = CachedPrompt::Done; + + let len = context.prefix.len(); + log::info!("[cache][insert][slot: {batch}][len: {len}]"); + } + + output + } + }; + let (token, output) = { + let output = output.clone(); + let sampler = context.request.sampler.clone(); + let formatters = context.formatters.clone(); + let bias = context.request.bias.clone(); + self.sample(output, sampler, formatters, bias).await? + }; + + let mut stop_token = token == 0; let mut word = match self.tokenizer.decode(&[token]) { Ok(word) => word, Err(err) => { - log::warn!("{err}"); + log::warn!("[process][error] {err}"); stop_token = true; - Default::default() + Vec::new() } }; - context.model_text.append(&mut word.clone()); - context.buffer.append(&mut word); + + context.output = Some(output.clone()); + context.suffix.0.push(token); context.model_tokens.push(token); + context.model_text.extend(&word); + context.buffer.append(&mut word); let instant = context.instant.get_or_insert(Instant::now()); let mut done = false; @@ -1033,9 +841,8 @@ impl Runtime { if context.sender.is_disconnected() { done = true; } else if let GenerateKind::Choose { calibrate, .. 
} = context.request.kind { - // calculate perplexities for choose request - let backed = self.state.read(batch)?; - let mut perplexities = vec![f32::INFINITY; context.choices.len()]; + let backed = self.back(batch).await?; + let mut ppl = vec![f32::INFINITY; context.choices.len()]; if calibrate { // compute perplexities of the choices themselves and calibrate their effects @@ -1046,12 +853,11 @@ impl Runtime { .enumerate() .filter(|(_, choice)| !choice.is_empty()) { - self.state.load(init.clone(), batch)?; - let perplexity = -self.compute_perplexities(choice, batch, None).await?; - perplexities[index] = perplexity; + self.load(batch, init.clone()); + ppl[index] = -self.perplexity(batch, choice, None).await?; } // recover the state - self.state.write(backed.clone(), batch)?; + self.load(batch, backed.clone()); } for (index, choice) in context @@ -1060,23 +866,40 @@ impl Runtime { .enumerate() .filter(|(_, choice)| !choice.is_empty()) { - let perplexity = self - .compute_perplexities(choice, batch, Some(data[choice[0] as usize])) - .await?; - perplexities[index] = match calibrate { - true => perplexities[index] + perplexity, - false => perplexity, + let output = output.clone().to_vec(); + let head = Some(output[choice[0] as usize]); + let p = self.perplexity(batch, choice, head).await?; + ppl[index] = match calibrate { + true => ppl[index] + p, + false => p, }; - // recover the state - self.state.write(backed.clone(), batch)?; + self.load(batch, backed.clone()); } - let _ = context.sender.send(Token::Choose(perplexities)); + + let _ = context.sender.send(Token::Choose(ppl)); + done = true; + } else if let GenerateKind::Embed { .. } = context.request.kind { + let backed = self.back(batch).await?; + let embed = backed.to_vec(); + let _ = context.sender.send(Token::Embed(embed)); done = true; } else if halt || stop_matched || stop_token { let output = String::from_utf8_lossy(head); let _ = context.sender.send(Token::Content(output.into())); stop(FinishReason::Stop); + + if let Some(output) = context.output.clone() { + let backed = self.back(batch).await?; + let mut caches = self.caches.lock().await; + let cache = &mut caches.fetch(context.request.state).cache; + let item = CachedItem::new(backed, output); + let (item, _) = tokio::sync::watch::channel(Some(item)); + cache.insert(context.prefix.clone(), item); + + let len = context.prefix.len(); + log::info!("[cache][insert][slot: {batch}][len: {len}]"); + } } else if context.model_tokens.len() >= context.request.max_tokens { stop(FinishReason::Length); } else if let Ok(word) = String::from_utf8(head.to_vec()) { @@ -1084,114 +907,240 @@ impl Runtime { context.buffer = tail.to_vec(); } - done.then(|| payload.finalize()); + if done { + log::info!("[process][done][slot: {batch}]"); + break; + } } - Ok(()) + Ok(context) + } + + /// Keep the items in the cache less then [`MAX_CACHE_ITEMS`]. 
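+    /// Trims both the default cache and every state-keyed backed cache.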
+ async fn maintain_cache(&self) { + let mut caches = self.caches.lock().await; + caches.default.maintain(); + caches.backed.iter_mut().for_each(|(_, x)| x.maintain()); } +} - async fn process(&self, payloads: &mut [Payload]) -> Result<()> { - let tokens = self.sample(payloads).await?; - self.finalize(payloads, tokens).await?; - self.synchronize(payloads).await?; +async fn enqueue(runtime: CoreRuntime, receiver: Receiver, timer: Duration) { + let mut queue = Vec::::new(); - let option = InferOption::Last; - let batches = payloads - .iter() - .map(|payload| match payload { - Payload::Busy(context) => context.suffix.0.clone(), - _ => vec![], - }) - .map(|tokens| InferInputBatch { tokens, option }) - .collect(); - let inference = InferInput::new(batches, self.reload.token_chunk_size); - if inference.num_token() == 0 { - return Ok(()); - } - let mut inference = Some(inference); + 'outer: while let Ok(context) = receiver.recv_async().await { + queue.push(context); - // run the model until there is at least one slot finished - let outputs = loop { - let input = inference.take().unwrap(); - let (input, output) = self.runtime.infer(input).await?; - inference.replace(input); + 'inner: loop { + runtime.maintain_cache().await; + runtime.update().await; - if output.iter().any(|batch| batch.len() > 0) { - break output; + let mut temp = Vec::new(); + for context in queue.drain(..) { + match runtime.queue(context).await { + Ok(SlotResult::Failure(context)) => temp.push(*context), + Ok(SlotResult::Success(batch)) => log::info!("[enqueue][ok][slot: {batch}]"), + Ok(SlotResult::Fault(batch)) => log::info!("[enqueue][fault][slot: {batch}]"), + Ok(SlotResult::Error(err)) => log::error!("[enqueue][error] {err}"), + Err(err) => log::error!("[enqueue][error] {err}"), + } + } + std::mem::swap(&mut queue, &mut temp); + + if queue.is_empty() { + break 'inner; } + + match receiver.try_recv() { + Ok(context) => queue.push(context), + Err(TryRecvError::Empty) => tokio::time::sleep(timer).await, + Err(TryRecvError::Disconnected) => break 'outer, + } + } + } +} + +async fn finalize(runtime: CoreRuntime, receiver: Receiver, timer: Duration) { + while !receiver.is_disconnected() { + runtime.maintain_cache().await; + runtime.update().await; + tokio::time::sleep(timer).await; + } +} + +async fn infer( + reload: Arc, + runtime: Weak, + state: Arc, + receiver: Receiver, +) -> Result<()> { + let mut senders = vec![None; reload.max_batch]; + let batches = vec![ + InferInputBatch { + tokens: vec![], + option: Default::default(), }; + reload.max_batch + ]; + let inference = InferInput::new(batches, reload.token_chunk_size); + let mut inference = Some(inference); + + async fn schedule( + inference: &mut Option, + senders: &mut [Option>>], + state: Arc, + batch: InferBatch, + ) -> Result<()> { + match batch { + InferBatch::Run { + batch, + tokens, + option, + sender, + } => { + let mut input = inference.take().expect("inference must not be `None`"); + input.batches[batch] = InferInputBatch { tokens, option }; + senders[batch] = Some(sender); + inference.replace(input); + } + InferBatch::Load { batch, tensor } => state.load(tensor, batch)?, + InferBatch::Back { batch, sender } => { + let backed = state.back(batch).await?; + let _ = sender.send_async(backed).await; + } + } + Ok(()) + } - for (payload, output) in payloads.iter_mut().zip(outputs.iter()) { - let Payload::Busy(context) = payload else { - continue; - }; + 'outer: while let Ok(batch) = receiver.recv_async().await { + schedule(&mut inference, &mut senders, state.clone(), 
-        for (payload, output) in payloads.iter_mut().zip(outputs.iter()) {
-            let Payload::Busy(context) = payload else {
-                continue;
-            };
+    'outer: while let Ok(batch) = receiver.recv_async().await {
+        schedule(&mut inference, &mut senders, state.clone(), batch).await?;
 
-            // if the suffix is empty, the output is read from the cache, and we don't want to override it.
-            if context.suffix.is_empty() {
-                continue;
+        while inference
+            .as_ref()
+            .map(|input| input.num_token() > 0)
+            .expect("inference must not be `None`")
+        {
+            'inner: loop {
+                let state = state.clone();
+                match receiver.try_recv() {
+                    Ok(batch) => schedule(&mut inference, &mut senders, state, batch).await?,
+                    Err(TryRecvError::Empty) => break 'inner,
+                    Err(TryRecvError::Disconnected) => break 'outer,
+                }
             }
 
-            context.output = match output.len() {
-                0 => None,
-                x if x == self.info.num_vocab => Some(output.0.clone()),
-                x => unreachable!("output size should not be {x}"),
+            let Some(runtime) = runtime.upgrade() else {
+                break 'outer;
             };
+            let input = inference.take().expect("inference must not be `None`");
+            let (input, output) = runtime.infer(input).await?;
+            inference.replace(input);
+
+            for (InferOutputBatch(output), sender) in output
+                .iter()
+                .zip_eq(senders.clone().into_iter())
+                .filter(|(output, _)| !output.is_empty())
+                .filter_map(|(output, sender)| sender.map(|sender| (output, sender)))
+            {
+                let _ = sender.send_async(output.clone()).await;
+            }
         }
 
-        let inference = inference.unwrap();
-        for (payload, input) in payloads.iter_mut().zip(inference.batches.into_iter()) {
-            let Payload::Busy(context) = payload else {
-                continue;
-            };
+    }
+
+    log::info!("[infer] exit");
+    Ok(())
+}
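The `softmax` task that follows coalesces work the same way `infer` drains its channel: block on the first request, then take everything already queued so that a single GPU dispatch serves every waiting slot. A generic sketch of that drain-then-process shape (illustrative only; `recv_async` and `drain` are the only `flume` APIs assumed, both of which appear in the hunks here):

```rust
// Illustrative pattern only: batch up all currently queued requests and
// hand them to one expensive call, instead of processing them one by one.
async fn drain_and_process<T>(receiver: &flume::Receiver<T>, mut process: impl FnMut(Vec<T>)) {
    while let Ok(first) = receiver.recv_async().await {
        let mut batch = vec![first];
        batch.extend(receiver.drain()); // grab everything else without blocking
        process(batch);
    }
}
```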
+
+async fn softmax(
+    reload: Arc<ReloadRequest>,
+    context: Context,
+    receiver: Receiver<SoftmaxBatch>,
+) -> Result<()> {
+    let mut batches = Vec::with_capacity(reload.max_batch);
 
-        let prefix = std::mem::take(&mut context.prefix);
-        let suffix = std::mem::take(&mut context.suffix);
-        let model_tokens = [prefix.0, suffix.0].concat();
+    while let Ok(batch) = receiver.recv_async().await {
+        batches.push(batch);
 
-        // compute new prefix and suffix using the current remaining tokens
-        assert!(model_tokens.len() >= input.tokens.len());
-        let len = model_tokens.len() - input.tokens.len();
-        context.prefix = Tokens(model_tokens[..len].to_vec());
-        context.suffix = Tokens(model_tokens[len..].to_vec());
+        for batch in receiver.drain() {
+            batches.push(batch);
         }
 
-        Ok(())
-    }
+        let input = batches.iter().map(|batch| batch.input.clone()).collect();
+        let output = web_rwkv::runtime::softmax::softmax(&context, input).await?;
 
-    /// Keep the items in the cache less then [`MAX_CACHE_ITEMS`].
-    async fn maintain_cache(&self) {
-        let mut caches = self.caches.lock().await;
-        caches.default.maintain();
-        caches.backed.iter_mut().for_each(|(_, x)| x.maintain());
-    }
-}
+        for (batch, tensor) in batches.iter().zip_eq(output.into_iter()) {
+            let _ = batch.sender.send_async(tensor).await;
+        }
 
-pub async fn run(receiver: Receiver<()>, env: Arc<RwLock<Environment>>) {
-    {
-        // this task constantly runs, cleaning up state cache
-        let env = env.clone();
-        tokio::spawn(async move {
-            loop {
-                if let Environment::Loaded(runtime) = &*env.read().await {
-                    runtime.maintain_cache().await;
-                }
-                tokio::time::sleep(Duration::from_secs(1)).await;
-            }
-        });
+        batches.clear();
     }
 
-    while let Ok(()) = receiver.recv_async().await {
-        if let Environment::Loaded(runtime) = &*env.read().await {
-            let mut payloads = vec![Payload::default(); runtime.num_batch()];
-            'run: loop {
-                if let Err(err) = runtime.process(&mut payloads).await {
-                    log::error!("{}", err);
-                    break 'run;
-                }
-                if payloads.iter().all(Payload::is_empty) {
-                    break 'run;
-                }
-            }
+    log::info!("[softmax] exit");
+    Ok(())
+}
+
+pub async fn run(
+    context: Context,
+    runtime: Weak<dyn Runtime + Send + Sync>,
+    state: Arc<dyn State + Send + Sync>,
+    receiver: Receiver<GenerateContext>,
+    RuntimeInfo {
+        reload,
+        info,
+        states,
+        tokenizer,
+        ..
+    }: RuntimeInfo,
+) {
+    let slots = std::iter::repeat_with(Default::default)
+        .take(reload.max_batch)
+        .collect();
+    let slots = Arc::new(Mutex::new(slots));
+
+    let caches = {
+        let mut caches = CacheHub::default();
+        // set up default initial state
+        if let Some(state) = states.iter().find(|state| state.default) {
+            caches.default.state = Some(state.clone());
+        }
+        // set up other initial states with ids
+        for state in states {
+            let id = state.id;
+            let item = Cache {
+                state: Some(state),
+                cache: Trie::new(),
+            };
+            caches.backed.insert(id, item);
+        }
+        Arc::new(Mutex::new(caches))
+    };
+
+    let max_batch = reload.max_batch;
+    let runtime = {
+        let infer = {
+            let (sender, receiver) = flume::unbounded();
+            tokio::spawn(infer(reload.clone(), runtime, state.clone(), receiver));
+            sender
+        };
+        let softmax = {
+            let (sender, receiver) = flume::unbounded();
+            tokio::spawn(softmax(reload.clone(), context.clone(), receiver));
+            sender
+        };
+        let sender = RuntimeSender { infer, softmax };
+        CoreRuntime {
+            context,
+            info,
+            reload,
+            state,
+            sender,
+            tokenizer,
+            slots,
+            caches,
+        }
+    };
+    let timer = Duration::from_secs_f32(1.0);
+    for _ in 0..max_batch {
+        tokio::spawn(enqueue(runtime.clone(), receiver.clone(), timer));
+    }
+    tokio::spawn(finalize(runtime, receiver, timer));
 }
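The new `run` wires the pieces together: one `infer` task and one `softmax` task own the GPU-facing channels, while `max_batch` `enqueue` workers all pull from the same request receiver. This relies on `flume` channels being MPMC: cloning a `Receiver` creates competing consumers, not broadcast copies. A self-contained sketch of that fan-out shape (illustrative only, with a placeholder job type):

```rust
// Minimal worker-pool sketch: every clone of a flume receiver competes for
// messages, so N spawned tasks drain one queue cooperatively.
fn spawn_workers(n: usize) -> flume::Sender<String> {
    let (sender, receiver) = flume::unbounded::<String>();
    for id in 0..n {
        let receiver = receiver.clone();
        tokio::spawn(async move {
            while let Ok(job) = receiver.recv_async().await {
                log::info!("[worker {id}] {job}");
            }
        });
    }
    sender
}
```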
diff --git a/crates/ai00-server/src/api/model.rs b/crates/ai00-server/src/api/model.rs
index 68ef7aec..4909d950 100644
--- a/crates/ai00-server/src/api/model.rs
+++ b/crates/ai00-server/src/api/model.rs
@@ -1,7 +1,7 @@
+use std::sync::Arc;
+
 use ai00_core::{
-    reload::State,
-    run::{InitState, StateId},
-    ReloadRequest, RuntimeInfo, SaveRequest, ThreadRequest,
+    reload::State, InitState, ReloadRequest, RuntimeInfo, SaveRequest, StateId, ThreadRequest,
 };
 use futures_util::StreamExt;
 use salvo::{oapi::extract::JsonBody, prelude::*};
@@ -13,7 +13,7 @@ use crate::{build_path, types::ThreadSender, SLEEP};
 
 #[derive(Debug, Clone, Serialize)]
 pub struct InfoResponse {
-    reload: ReloadRequest,
+    reload: Arc<ReloadRequest>,
     model: ModelInfo,
     states: Vec<InitStateInfo>,
 }
@@ -30,13 +30,13 @@ pub async fn info(depot: &mut Depot) -> Json<InfoResponse> {
     let sender = depot.obtain::<ThreadSender>().unwrap();
     let RuntimeInfo {
         reload,
-        model,
+        info: model,
         states,
         ..
     } = request_info(sender.to_owned(), SLEEP).await;
     let states = states
         .into_iter()
-        .map(|(id, InitState { name, .. })| InitStateInfo { id, name })
+        .map(|InitState { name, id, .. }| InitStateInfo { id, name })
         .collect();
     Json(InfoResponse {
         reload,
@@ -58,13 +58,13 @@ pub async fn state(depot: &mut Depot, res: &mut Response) {
     let stream = info_receiver.into_stream().map(
         |RuntimeInfo {
              reload,
-             model,
+             info: model,
              states,
              ..
          }| {
             let states = states
                 .into_iter()
-                .map(|(id, InitState { name, .. })| InitStateInfo { id, name })
+                .map(|InitState { name, id, .. }| InitStateInfo { id, name })
                 .collect();
             match serde_json::to_string(&InfoResponse {
                 reload,
diff --git a/crates/ai00-server/src/api/oai/chat.rs b/crates/ai00-server/src/api/oai/chat.rs
index 530d9e95..2b06ab6c 100644
--- a/crates/ai00-server/src/api/oai/chat.rs
+++ b/crates/ai00-server/src/api/oai/chat.rs
@@ -1,7 +1,7 @@
 use std::{collections::HashMap, sync::Arc};
 
 use ai00_core::{
-    run::StateId, FinishReason, GenerateRequest, ThreadRequest, Token, TokenCounter, MAX_TOKENS,
+    FinishReason, GenerateRequest, StateId, ThreadRequest, Token, TokenCounter, MAX_TOKENS,
 };
 use derivative::Derivative;
 use futures_util::StreamExt;
diff --git a/crates/ai00-server/src/api/oai/choose.rs b/crates/ai00-server/src/api/oai/choose.rs
index b44c5c60..4d5b3957 100644
--- a/crates/ai00-server/src/api/oai/choose.rs
+++ b/crates/ai00-server/src/api/oai/choose.rs
@@ -1,4 +1,4 @@
-use ai00_core::{run::StateId, GenerateKind, GenerateRequest, ThreadRequest, Token};
+use ai00_core::{GenerateKind, GenerateRequest, StateId, ThreadRequest, Token};
 use futures_util::StreamExt;
 use itertools::Itertools;
 use salvo::{
diff --git a/crates/ai00-server/src/api/oai/completion.rs b/crates/ai00-server/src/api/oai/completion.rs
index 46115066..f6de9d4a 100644
--- a/crates/ai00-server/src/api/oai/completion.rs
+++ b/crates/ai00-server/src/api/oai/completion.rs
@@ -1,7 +1,7 @@
 use std::{collections::HashMap, sync::Arc};
 
 use ai00_core::{
-    run::StateId, FinishReason, GenerateRequest, ThreadRequest, Token, TokenCounter, MAX_TOKENS,
+    FinishReason, GenerateRequest, StateId, ThreadRequest, Token, TokenCounter, MAX_TOKENS,
 };
 use derivative::Derivative;
 use futures_util::StreamExt;
diff --git a/crates/ai00-server/src/api/oai/embedding.rs b/crates/ai00-server/src/api/oai/embedding.rs
index f2959ad9..1383d20d 100644
--- a/crates/ai00-server/src/api/oai/embedding.rs
+++ b/crates/ai00-server/src/api/oai/embedding.rs
@@ -1,4 +1,4 @@
-use ai00_core::{run::StateId, GenerateKind, GenerateRequest, ThreadRequest, Token, TokenCounter};
+use ai00_core::{GenerateKind, GenerateRequest, StateId, ThreadRequest, Token, TokenCounter};
 use futures_util::StreamExt;
 use salvo::{
     oapi::{extract::JsonBody, ToParameters, ToResponse, ToSchema},
diff --git a/crates/ai00-server/src/main.rs b/crates/ai00-server/src/main.rs
index 37dbbb60..190308e3 100644
--- a/crates/ai00-server/src/main.rs
+++ b/crates/ai00-server/src/main.rs
@@ -5,7 +5,7 @@ use std::{
     time::Duration,
 };
 
-use ai00_core::{model_route, ThreadRequest};
+use ai00_core::ThreadRequest;
 use anyhow::{anyhow, bail, Result};
 use clap::{command, CommandFactory, Parser};
 use memmap2::Mmap;
@@ -153,7 +153,7 @@ async fn main() {
     log::info!("{}\tversion: {}", bin_name, version);
 
     let (sender, receiver) = flume::unbounded::<ThreadRequest>();
-    tokio::spawn(model_route(receiver));
+    tokio::spawn(ai00_core::serve(receiver));
 
     let config = {
         let path = args