crates/ai00-core/src/lib.rs (5 additions, 2 deletions)
@@ -139,7 +139,10 @@ pub enum GenerateKind {
     /// The (reversed) index of the layer whose output is used as the embedding.
     Embed { layer: usize },
     /// Choose options by perplexity.
-    Choose { choices: Vec<String> },
+    Choose {
+        choices: Vec<String>,
+        calibrate: bool,
+    },
 }
 
 #[derive(Clone, Derivative)]
@@ -670,7 +673,7 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
     request.sampler.write().await.init(&model_tokens);
 
     let choices = match &request.kind {
-        GenerateKind::Choose { choices } => {
+        GenerateKind::Choose { choices, .. } => {
             let choices: Vec<_> = choices
                 .iter()
                 .map(|prompt| tokenizer.encode(prompt.as_bytes()))
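
Together these hunks add a `calibrate` flag to the public `GenerateKind::Choose` variant; `model_route` only needs the choice texts for tokenization, so it ignores the new field with `..`. A minimal caller-side sketch of the extended variant (the values and the surrounding request are illustrative only, not part of this diff):

    // Hypothetical construction of the new variant; other request
    // fields are omitted and not shown in this diff.
    let kind = GenerateKind::Choose {
        choices: vec![" Paris".into(), " Seattle".into()],
        calibrate: true,
    };
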
crates/ai00-core/src/run.rs (80 additions, 41 deletions)
@@ -864,6 +864,56 @@ impl Runtime {
         Ok(tokens)
     }
 
+    async fn compute_perplexities(
+        &self,
+        tokens: &Tokens,
+        batch: usize,
+        head: Option<f32>,
+    ) -> Result<f32> {
+        let mut probabilities = Vec::with_capacity(tokens.len());
+        let tokens = match head {
+            Some(head) => {
+                probabilities.push(head);
+                tokens.0.clone()
+            }
+            None => [vec![0], tokens.0.clone()].concat(),
+        };
+
+        // construct an inference session with only one batch
+        let mut batches = vec![InferInputBatch::default(); self.num_batch()];
+        batches[batch] = InferInputBatch {
+            tokens: tokens.clone(),
+            option: InferOption::Full,
+        };
+        let inference = InferInput::new(batches, self.reload.token_chunk_size);
+        let mut inference = Some(inference);
+
+        let mut index = 1;
+        loop {
+            let input = inference.take().unwrap();
+            if input.batches[batch].tokens.is_empty() {
+                break;
+            }
+            let (input, InferOutput(output)) = self.runtime.infer(input).await?;
+            inference.replace(input);
+
+            let output = output[batch].0.clone().split(1)?;
+            for data in output {
+                if index < tokens.len() {
+                    let data = data.map(|x| x.exp()).to_vec();
+                    let sum: f32 = data.iter().sum();
+                    let token = tokens[index] as usize;
+                    probabilities.push(data[token] / sum);
+                }
+                index += 1;
+            }
+        }
+
+        let perplexity: f32 = probabilities.into_iter().map(|x| x.ln()).sum::<f32>();
+        let perplexity = -perplexity / tokens.len() as f32;
+        Ok(perplexity)
+    }
+
     async fn finalize(
         &self,
         payloads: &mut [Payload],
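
For a token sequence t_1, ..., t_n, the helper accumulates the softmax probability of each realized next token (exponentiating the logits and normalizing via `data[token] / sum`) and returns the negative mean log-probability, i.e. the log-perplexity. As a reading of the code, not notation taken from the PR:

    \mathrm{ppl}(t_{1:n}) = -\frac{1}{n} \sum_{i=1}^{n} \ln p(t_i \mid t_{<i})

When `head` is `Some`, the probability of the first token has already been computed by the caller and is pushed directly; with `None`, a `0` token is prepended so the model scores the sequence from scratch. Note that in the `None` branch the prepended token also counts toward the averaging length, so the two paths normalize over slightly different lengths.
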
@@ -982,52 +1032,41 @@ impl Runtime {
 
             if context.sender.is_disconnected() {
                 done = true;
-            } else if matches!(context.request.kind, GenerateKind::Choose { .. }) {
+            } else if let GenerateKind::Choose { calibrate, .. } = context.request.kind {
                 // calculate perplexities for choose request
                 let backed = self.state.read(batch)?;
-                let mut perplexities = Vec::with_capacity(context.choices.len());
-                for choice in &context.choices {
-                    if choice.is_empty() {
-                        perplexities.push(f32::INFINITY);
-                        continue;
+                let mut perplexities = vec![f32::INFINITY; context.choices.len()];
+
+                if calibrate {
+                    // compute perplexities of the choices themselves and calibrate their effects
+                    let init = self.state.init();
+                    for (index, choice) in context
+                        .choices
+                        .iter()
+                        .enumerate()
+                        .filter(|(_, choice)| !choice.is_empty())
+                    {
+                        self.state.load(init.clone(), batch)?;
+                        let perplexity = -self.compute_perplexities(choice, batch, None).await?;
+                        perplexities[index] = perplexity;
                     }
+                    // recover the state
+                    self.state.write(backed.clone(), batch)?;
+                }
 
-                    let mut probabilities = Vec::with_capacity(choice.len());
-                    probabilities.push(data[choice[0] as usize]);
-
-                    // construct an inference session with only one batch
-                    let mut batches = vec![InferInputBatch::default(); self.num_batch()];
-                    batches[batch] = InferInputBatch {
-                        tokens: choice.0.clone(),
-                        option: InferOption::Full,
+                for (index, choice) in context
+                    .choices
+                    .iter()
+                    .enumerate()
+                    .filter(|(_, choice)| !choice.is_empty())
+                {
+                    let perplexity = self
+                        .compute_perplexities(choice, batch, Some(data[choice[0] as usize]))
+                        .await?;
+                    perplexities[index] = match calibrate {
+                        true => perplexities[index] + perplexity,
+                        false => perplexity,
                     };
-                    let inference = InferInput::new(batches, self.reload.token_chunk_size);
-                    let mut inference = Some(inference);
-
-                    let mut index = 1;
-                    loop {
-                        let input = inference.take().unwrap();
-                        if input.batches[batch].tokens.is_empty() {
-                            break;
-                        }
-                        let (input, InferOutput(output)) = self.runtime.infer(input).await?;
-                        inference.replace(input);
-
-                        let output = output[batch].0.clone().split(1)?;
-                        for data in output {
-                            if index < choice.len() {
-                                let data = data.map(|x| x.exp()).to_vec();
-                                let sum: f32 = data.iter().sum();
-                                let token = choice[index] as usize;
-                                probabilities.push(data[token] / sum);
-                            }
-                            index += 1;
-                        }
-                    }
-
-                    let perplexity: f32 = probabilities.into_iter().map(|x| x.ln()).sum::<f32>();
-                    let perplexity = -perplexity / choice.len() as f32;
-                    perplexities.push(perplexity);
 
                     // recover the state
                     self.state.write(backed.clone(), batch)?;
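
Read together with the helper above, the calibrate branch scores each non-empty choice twice: once from a fresh `init` state (negated), and once from the prompt-conditioned state, summing the two. In the notation of the previous note, the resulting score for a choice c is

    \mathrm{score}(c) = \mathrm{ppl}(c \mid \mathrm{prompt}) - \mathrm{ppl}(c)

with lower still better, so a choice that is intrinsically likely under the model no longer wins on that account alone; this mirrors PMI-style calibration for multiple-choice scoring. The state is reloaded before each calibration pass and restored from `backed` afterwards, so neither pass leaks state into the other.
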
crates/ai00-server/src/api/oai/choose.rs (4 additions, 1 deletion)
@@ -26,12 +26,14 @@ use crate::{
         " San Francisco",
         " Shanghai"
     ],
+    "calibrate": false,
     "state": "00000000-0000-0000-0000-000000000000"
 })
 ))]
 pub struct ChooseRequest {
     input: Array<String>,
     choices: Vec<String>,
+    calibrate: bool,
     state: StateId,
 }
@@ -40,12 +42,13 @@ impl From<ChooseRequest> for GenerateRequest {
         let ChooseRequest {
             input,
             choices,
+            calibrate,
             state,
         } = value;
         Self {
            prompt: Vec::from(input).join(""),
            max_tokens: 1,
-           kind: GenerateKind::Choose { choices },
+           kind: GenerateKind::Choose { choices, calibrate },
            state,
            ..Default::default()
        }
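
With the server-side schema updated, a request to the choose endpoint can opt into calibration per call. A sketch of a request body following the example schema above (the route path is not shown in this diff, and whether `calibrate` may be omitted depends on serde defaults not visible here, so treat both as assumptions):

    {
        "input": ["The Eiffel Tower is located in the city of"],
        "choices": [" Paris", " Seattle", " San Francisco", " Shanghai"],
        "calibrate": true,
        "state": "00000000-0000-0000-0000-000000000000"
    }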