Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions LLama/Batched/LLamaContextExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System;
using System.Buffers.Binary;
using System.Diagnostics;
using System.IO;
using System.IO.MemoryMappedFiles;
using LLama.Native;
Expand All @@ -24,20 +25,20 @@ internal static void SaveState(this LLamaContext context, string filename, LLama
if (File.Exists(filename))
File.Delete(filename);

// Estimate size of state to write to disk, this is always equal to or greater than the actual size
var estimatedStateSize = checked((long)context.NativeHandle.GetStateSize(sequence));
// Get the exact size of the state
var stateSize = context.NativeHandle.GetStateSize(sequence);

// Space for "extra" byte plus an 8 byte header
var prefixSize = header.Length + 8;

// Add enough space for the "extra" data and an 8 byte header
var totalFileSize = prefixSize + estimatedStateSize;
var totalFileSize = (nuint)prefixSize + stateSize;

// Map the file and write the bytes directly to it.
long writtenBytes = 0;
using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, totalFileSize))
nuint writtenBytes = 0;
using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, (long)totalFileSize))
{
using (var view = file.CreateViewAccessor(0, totalFileSize))
using (var view = file.CreateViewAccessor(0, (long)totalFileSize))
{
unsafe
{
Expand All @@ -51,10 +52,10 @@ internal static void SaveState(this LLamaContext context, string filename, LLama
BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), (uint)header.Length);
writtenBytes += 4;
header.CopyTo(new Span<byte>(ptr + writtenBytes, header.Length));
writtenBytes += header.Length;
writtenBytes += (nuint)header.Length;

// Write state data
writtenBytes += (long)context.NativeHandle.GetState(ptr + writtenBytes, (ulong)estimatedStateSize, sequence);
writtenBytes += context.NativeHandle.GetState(ptr + writtenBytes, stateSize, sequence);
}
finally
{
Expand All @@ -64,9 +65,7 @@ internal static void SaveState(this LLamaContext context, string filename, LLama
}
}

// Truncate the file to the actual size of data that was written
using (var fileStream = new FileStream(filename, FileMode.Open))
fileStream.SetLength(writtenBytes);
Debug.Assert(totalFileSize == writtenBytes, $"Expected to write {totalFileSize} bytes, but actually wrote {writtenBytes}");
}

/// <summary>
Expand Down Expand Up @@ -105,7 +104,7 @@ internal static void LoadState(this LLamaContext context, string filename, LLama
new Span<byte>(ptr + readBytes, headerLength).CopyTo(header);
readBytes += headerLength;

context.NativeHandle.SetState(ptr + readBytes, sequence);
context.NativeHandle.SetState(ptr + readBytes, (nuint)((long)view.SafeMemoryMappedViewHandle.ByteLength - readBytes), sequence);
}
finally
{
Expand Down
73 changes: 34 additions & 39 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System.IO;
using System.IO.MemoryMappedFiles;
Expand Down Expand Up @@ -150,21 +150,21 @@ public void SaveState(string filename)
if (File.Exists(filename))
File.Delete(filename);

// Estimate size of state to write to disk, this is always equal to or greater than the actual size
var estimatedStateSize = checked((long)NativeHandle.GetStateSize());
// Get the exact size of the state
var stateSize = NativeHandle.GetStateSize();

// Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
long writtenBytes;
using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize))
using (var view = file.CreateViewAccessor(0, estimatedStateSize))
nuint writtenBytes;
using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, checked((long)stateSize)))
using (var view = file.CreateViewAccessor(0, checked((long)stateSize)))
{
unsafe
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
try
{
writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize);
writtenBytes = NativeHandle.GetState(ptr, stateSize);
}
finally
{
Expand All @@ -173,9 +173,7 @@ public void SaveState(string filename)
}
}

// Truncate the file to the actual size of data that was written
using (var fileStream = new FileStream(filename, FileMode.Open))
fileStream.SetLength(writtenBytes);
Debug.Assert(stateSize == writtenBytes, $"Expected to write {stateSize} bytes, but actually wrote {writtenBytes}");
}

/// <summary>
Expand All @@ -189,21 +187,21 @@ public void SaveState(string filename, LLamaSeqId sequence)
if (File.Exists(filename))
File.Delete(filename);

// Estimate size of state to write to disk, this is always equal to or greater than the actual size
var estimatedStateSize = checked((long)NativeHandle.GetStateSize(sequence));
// Get the exact size of the state
var stateSize = NativeHandle.GetStateSize(sequence);

// Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
long writtenBytes;
using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize))
using (var view = file.CreateViewAccessor(0, estimatedStateSize))
nuint writtenBytes;
using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, checked((long)stateSize)))
using (var view = file.CreateViewAccessor(0, checked((long)stateSize)))
{
unsafe
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
try
{
writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize, sequence);
writtenBytes = NativeHandle.GetState(ptr, stateSize, sequence);
}
finally
{
Expand All @@ -212,9 +210,7 @@ public void SaveState(string filename, LLamaSeqId sequence)
}
}

// Truncate the file to the actual size of data that was written
using (var fileStream = new FileStream(filename, FileMode.Open))
fileStream.SetLength(writtenBytes);
Debug.Assert(stateSize == writtenBytes, $"Expected to write {stateSize} bytes, but actually wrote {writtenBytes}");
}

/// <summary>
Expand All @@ -230,15 +226,14 @@ public State GetState()
var memory = Marshal.AllocHGlobal((nint)stateSize);
try
{
// Copy the state data into memory, discover the actual size required
ulong actualSize;
// Copy the state data into memory
nuint actualSize;
unsafe
{
actualSize = NativeHandle.GetState((byte*)memory, stateSize);
}

// Shrink to size
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
Debug.Assert(actualSize == stateSize);

// Wrap memory in a "state"
var state = new State(memory, actualSize);
Expand Down Expand Up @@ -269,14 +264,13 @@ public SequenceState GetState(LLamaSeqId sequence)
try
{
// Copy the state data into memory, discover the actual size required
ulong actualSize;
nuint actualSize;
unsafe
{
actualSize = NativeHandle.GetState((byte*)memory, stateSize, sequence);
}

// Shrink to size
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
Debug.Assert(actualSize == stateSize);

// Wrap memory in a "state"
var state = new SequenceState(memory, actualSize);
Expand Down Expand Up @@ -309,7 +303,7 @@ public void LoadState(string filename)
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
try
{
NativeHandle.SetState(ptr);
NativeHandle.SetState(ptr, (nuint)view.SafeMemoryMappedViewHandle.ByteLength);
}
finally
{
Expand All @@ -336,7 +330,7 @@ public void LoadState(string filename, LLamaSeqId sequence)
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
try
{
NativeHandle.SetState(ptr, sequence);
NativeHandle.SetState(ptr, (nuint)view.SafeMemoryMappedViewHandle.ByteLength, sequence);
}
finally
{
Expand All @@ -354,7 +348,7 @@ public void LoadState(State state)
{
unsafe
{
NativeHandle.SetState((byte*)state.DangerousGetHandle());
NativeHandle.SetState((byte*)state.DangerousGetHandle(), state.Size);
}
}

Expand All @@ -367,7 +361,7 @@ public void LoadState(SequenceState state, LLamaSeqId sequence)
{
unsafe
{
NativeHandle.SetState((byte*)state.DangerousGetHandle(), sequence);
NativeHandle.SetState((byte*)state.DangerousGetHandle(), state.Size, sequence);
}
}
#endregion
Expand All @@ -380,7 +374,8 @@ public void LoadState(SequenceState state, LLamaSeqId sequence)
public bool ShouldAddBosToken()
{
var addBos = NativeApi.llama_add_bos_token(NativeHandle.ModelHandle);
return addBos != -1 ? Convert.ToBoolean(addBos) : NativeHandle.LLamaVocabType == LLamaVocabType.SentencePiece;
//return addBos != -1 ? Convert.ToBoolean(addBos) : NativeHandle.LLamaVocabType == LLamaVocabType.SentencePiece;
return addBos;
}

#region eval overloads
Expand Down Expand Up @@ -458,13 +453,13 @@ public void Dispose()
public class State
: SafeLLamaHandleBase
{
private readonly ulong _size;
private readonly nuint _size;
/// <summary>
/// Get the size in bytes of this state object
/// </summary>
public ulong Size => _size;
public nuint Size => _size;

internal State(IntPtr memory, ulong size)
internal State(IntPtr memory, nuint size)
: base(memory, true)
{
_size = size;
Expand Down Expand Up @@ -513,7 +508,7 @@ public void Save(Stream stream)
public static async Task<State> LoadAsync(Stream stream)
{
var memory = Marshal.AllocHGlobal((nint)stream.Length);
var state = new State(memory, checked((ulong)stream.Length));
var state = new State(memory, (nuint)stream.Length);

UnmanagedMemoryStream dest;
unsafe
Expand All @@ -533,7 +528,7 @@ public static async Task<State> LoadAsync(Stream stream)
public static State Load(Stream stream)
{
var memory = Marshal.AllocHGlobal((nint)stream.Length);
var state = new State(memory, checked((ulong)stream.Length));
var state = new State(memory, (nuint)stream.Length);

unsafe
{
Expand All @@ -551,13 +546,13 @@ public static State Load(Stream stream)
public class SequenceState
: SafeLLamaHandleBase
{
private readonly ulong _size;
private readonly nuint _size;
/// <summary>
/// Get the size in bytes of this state object
/// </summary>
public ulong Size => _size;
public nuint Size => _size;

internal SequenceState(IntPtr memory, ulong size)
internal SequenceState(IntPtr memory, nuint size)
: base(memory, true)
{
_size = size;
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
</ItemGroup>

<PropertyGroup>
<BinaryReleaseId>345c8c0c87a97c1595f9c8b</BinaryReleaseId>
<BinaryReleaseId>11b84eb4578864827afcf</BinaryReleaseId>
</PropertyGroup>

<PropertyGroup>
Expand Down
2 changes: 1 addition & 1 deletion LLama/Native/LLamaModelQuantizeParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public struct LLamaModelQuantizeParams
public GGMLType output_tensor_type;

/// <summary>
/// itoken embeddings tensor type
/// token embeddings tensor type
/// </summary>
public GGMLType token_embedding_type;

Expand Down
9 changes: 6 additions & 3 deletions LLama/Native/LLamaRopeType.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
namespace LLama.Native;
namespace LLama.Native;

/// <summary>
///
/// </summary>
/// <remarks>llama_rope_type</remarks>
public enum LLamaRopeType
{
None = -1,
Norm = 0,
NEOX = 2,
GLM = 4,
NEOX = 2,//GGML_ROPE_TYPE_NEOX,
}
3 changes: 3 additions & 0 deletions LLama/Native/LLamaVocabPreType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,7 @@ internal enum LLamaVocabPreType
TEKKEN = 20,
SMOLLM = 21,
CODESHELL = 22,
BLOOM = 23,
GPT3_FINNISH = 24,
EXAONE = 25,
}
2 changes: 1 addition & 1 deletion LLama/Native/NativeApi.LLava.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

namespace LLama.Native;

public static unsafe partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Sanity check for clip &lt;-&gt; llava embed size match
Expand Down
14 changes: 4 additions & 10 deletions LLama/Native/NativeApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,19 +188,13 @@ public static unsafe int llama_chat_apply_template(SafeLlamaModelHandle? model,
static extern int internal_llama_chat_apply_template(IntPtr model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
}

/// <summary>
/// Returns -1 if unknown, 1 for true or 0 for false.
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_add_bos_token(SafeLlamaModelHandle model);
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_add_bos_token(SafeLlamaModelHandle model);

/// <summary>
/// Returns -1 if unknown, 1 for true or 0 for false.
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_add_eos_token(SafeLlamaModelHandle model);
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_add_eos_token(SafeLlamaModelHandle model);

/// <summary>
/// Print out timing information for this context
Expand Down
Loading