Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 45 additions & 20 deletions LLama.Examples/Examples/BatchedExecutorSimple.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Diagnostics.CodeAnalysis;
using System.Text;
using LLama.Batched;
using LLama.Common;
Expand Down Expand Up @@ -34,6 +33,7 @@ public static async Task Run()
var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// A set of questions to evaluate all at once
var messages = new[]
{
"What's 2+2?",
Expand All @@ -46,8 +46,10 @@ public static async Task Run()
"I have two sons, Bert and Ernie. What should I name my daughter?",
"What day comes after Friday?",
"What color shoes should I wear with dark blue pants?",
"Wy ae cts btr tn dgs?"
};

// Create a "Conversation" for each question
var conversations = new List<ConversationData>();
foreach (var message in messages)
{
Expand All @@ -57,11 +59,14 @@ public static async Task Run()
template.Add("user", message);
template.AddAssistant = true;
var templatedMessage = Encoding.UTF8.GetString(template.Apply());

// Create a new conversation and prompt it. Include special tokens and BOS because we are using the template.
// - BOS is the "Beginning of Sequence" token and should be included at the start of any prompt
// - Special tokens are special non-text tokens which an LLM is trained to understand (e.g. BOS). The templated text may contain special tokens.
var conversation = executor.Create();
conversation.Prompt(executor.Context.Tokenize(templatedMessage, addBos: true, special: true));

// Store everything we need to process this conversation
conversations.Add(new ConversationData {
Prompt = message,
Conversation = conversation,
Expand All @@ -73,47 +78,64 @@ public static async Task Run()
var table = BuildTable(conversations);
await AnsiConsole.Live(table).StartAsync(async ctx =>
{
// Enter a loop generating tokens
for (var i = 0; i < TokenCount; i++)
{
// Run inference for all conversations in the batch which have pending tokens.
var decodeResult = await executor.Infer();

// Inference can fail, always check the return value!
// NoKvSlot is not a fatal error, it just means that there's not enough memory available in the KV cache to process everything. You can force
// this to happen by setting a small value for ContextSize in the ModelParams at the top of this file (e.g. 512).
// In this case it's handled by ending a conversation (which will free up some space) and trying again. You could also handle this by
// saving the conversation to disk and loading it up again later once some other conversations have finished.
if (decodeResult == DecodeResult.NoKvSlot)
throw new Exception("Could not find a KV slot for the batch. Try reducing the size of the batch or increase the context.");
{
conversations.FirstOrDefault(a => !a.IsComplete)?.MarkComplete(failed:true);
continue;
}

// A generic error, this is fatal and the batch can no longer be used. This should never occur and generally indicates
// a bug in LLamaSharp, llama.cpp or a hardware error.
if (decodeResult == DecodeResult.Error)
throw new Exception("Unknown error occurred while inferring.");

foreach (var conversationData in conversations.Where(c => c.IsComplete == false))
// After inference all of the conversations must be sampled before running inference again.
foreach (var conversationData in conversations)
{
if (conversationData.Conversation.RequiresSampling == false) continue;

// sample a single token for the executor, passing the sample index of the conversation
var token = conversationData.Sampler.Sample(
executor.Context.NativeHandle,
conversationData.Conversation.GetSampleIndex());

// Completed conversations don't need sampling.
if (conversationData.IsComplete)
continue;

// If the conversation wasn't prompted before the last call to Infer then it won't need sampling.
if (!conversationData.Conversation.RequiresSampling)
continue;

// Use the sampling pipeline to choose a single token for this conversation.
var token = conversationData.Conversation.Sample(conversationData.Sampler);

// Some special tokens indicate that this sequence has ended. Check if that's what has been chosen by the sampling pipeline.
if (modelTokens.IsEndOfGeneration(token))
{
conversationData.MarkComplete();
}
else
{
// it isn't the end of generation, so add this token to the decoder and then add that to our tracked data
// It isn't the end of generation, so add this token to the decoder and then add that to our tracked data
conversationData.Decoder.Add(token);
conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));

// add the token to the conversation
// Prompt the conversation with this token, ready for the next round of inference to generate another token
conversationData.Conversation.Prompt(token);
}
}

// render the current state
// Render the current state
table = BuildTable(conversations);
ctx.UpdateTarget(table);

if (conversations.All(c => c.IsComplete))
{
break;
}
}

// If we ran out of tokens before completing, just mark the unfinished conversations as complete for rendering purposes.
Expand Down Expand Up @@ -152,20 +174,23 @@ public class ConversationData
public required BaseSamplingPipeline Sampler { get; init; }
public required StreamingTokenDecoder Decoder { get; init; }

public string AnswerMarkdown => IsComplete
? $"[green]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"
: $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]";
public string AnswerMarkdown =>
IsComplete
? $"[{(IsFailed ? "red" : "green")}]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"
: $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]";

public bool IsComplete { get; private set; }
public bool IsFailed { get; private set; }

// we are only keeping track of the answer in two parts to render them differently.
private (string Message, string LatestToken) _inProgressAnswer = (string.Empty, string.Empty);

public void AppendAnswer(string newText) => _inProgressAnswer = (_inProgressAnswer.Message + _inProgressAnswer.LatestToken, newText);

public void MarkComplete()
public void MarkComplete(bool failed = false)
{
IsComplete = true;
IsFailed = failed;
if (Conversation.IsDisposed == false)
{
// clean up the conversation and sampler to release more memory for inference.
Expand Down
13 changes: 13 additions & 0 deletions LLama/Batched/ConversationExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Batched;

Expand All @@ -20,6 +21,18 @@ public static LLamaToken Sample(this Conversation conversation, SafeLLamaSampler
return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset));
}

/// <summary>
/// Choose the next token for this conversation by running the supplied sampling pipeline
/// </summary>
/// <param name="conversation"><see cref="Conversation"/> to sample from</param>
/// <param name="sampler">Sampling pipeline whose <c>Sample</c> method is invoked to pick the token</param>
/// <param name="offset">Offset from the end of the conversation to the logits to sample, see <see cref="Conversation.GetSampleIndex"/> for more details</param>
/// <returns>The token returned by the sampling pipeline</returns>
public static LLamaToken Sample(this Conversation conversation, ISamplingPipeline sampler, int offset = 0)
{
    // Native context handle of the executor which owns this conversation
    var contextHandle = conversation.Executor.Context.NativeHandle;

    // Index of the logits to sample from, resolved from the requested offset
    var sampleIndex = conversation.GetSampleIndex(offset);

    return sampler.Sample(contextHandle, sampleIndex);
}

/// <summary>
/// Rewind a <see cref="Conversation"/> back to an earlier state by removing tokens from the end
/// </summary>
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaTemplate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ public void Clear()
#endregion

/// <summary>
/// Apply the template to the messages and write it into the output buffer
/// Apply the template to the messages and return a span containing the results
/// </summary>
/// <returns>A span over the buffer that holds the applied template</returns>
public ReadOnlySpan<byte> Apply()
Expand Down
55 changes: 23 additions & 32 deletions LLama/Native/LLamaBatch.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace LLama.Native;

Expand All @@ -17,11 +19,6 @@ public class LLamaBatch
private LLamaSeqId[][] _sequenceIds;
private IntPtr[] _sequenceIdsPtrs;

/// <summary>
/// Keep track of the index of existing token/position combos in the batch
/// </summary>
private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new();

/// <summary>
/// Keep a list of where logits can be sampled from
/// </summary>
Expand Down Expand Up @@ -105,6 +102,25 @@ private void GrowMaxSequences(int atLeast)

internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
{
// Sanity checking
#if DEBUG
// Check every output logit position is generating logits for exactly one sequence
foreach (var (seq, idx) in _logitPositions)
{
Debug.Assert(_logits[idx] != 0);
Debug.Assert(_sequenceIdCount[idx] == 1);
Debug.Assert(_sequenceIds[idx][0] == seq);
}

// Check every index: if it's generating logits it must be in the _logitPositions list, otherwise it must not be.
for (var i = 0; i < _logits.Length; i++)
{
var actual = _logitPositions.Any(x => x.Item2 == i);
var expected = _logits[i] != 0;
Debug.Assert(actual == expected);
}
#endif

// This group holds all of the memory pins
var group = new GroupDisposable();

Expand Down Expand Up @@ -146,36 +162,12 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
/// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
// Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
// sequence ID to the list.
// Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
if (!logits && _index.TryGetValue((token, pos), out var existingIndex))
{
if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);

foreach (var sequence in sequences)
{
_sequenceIds[existingIndex][_sequenceIdCount[existingIndex]] = sequence;
_sequenceIdCount[existingIndex]++;
}

return existingIndex;
}

// Couldn't find this token/position combo anywhere in the batch. Add a new item.

// Grow capacity as necessary
if (TokenCount == TokenCapacity)
GrowTokenCapacity();
if (sequences.Length > SequenceCapacity)
GrowMaxSequences(sequences.Length);

// Store the position in the index, so it can be found later.
// We need to check that it's not already there in case we skipped the check above (because logits is true).
if (!_index.ContainsKey((token, pos)))
_index.Add((token, pos), TokenCount);

// Add the items to the arrays
_tokens[TokenCount] = token;
_positions[TokenCount] = pos;
Expand Down Expand Up @@ -213,15 +205,15 @@ public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

var rented = System.Buffers.ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
var rented = ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
try
{
sequences.CopyTo(rented, 0);
return Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
}
finally
{
System.Buffers.ArrayPool<LLamaSeqId>.Shared.Return(rented);
ArrayPool<LLamaSeqId>.Shared.Return(rented);
}
#endif
}
Expand Down Expand Up @@ -273,7 +265,6 @@ public void Clear()
{
TokenCount = 0;

_index.Clear();
_logitPositions.Clear();
}

Expand Down