4 changes: 3 additions & 1 deletion .github/_typos.toml
@@ -17,4 +17,6 @@ extend-exclude = [

[default.extend-words]
# Used in a comment in SafeLLamaSamplerHandle.cs, as a prefix of "hello"
teh = "hel"
# ot is the shorthand version of llama.cpp's override-tensor parameter
ot = "ot"
5 changes: 5 additions & 0 deletions LLama.Unittest/ModelsParamsTests.cs
@@ -41,6 +41,11 @@ public void SerializeRoundTripSystemTextJson()
actual.MetadataOverrides = null!;
expected.MetadataOverrides = null!;

// Same deal for TensorBufferOverrides: compare them directly, then exclude them from the record equality check
Assert.True(expected.TensorBufferOverrides.SequenceEqual(actual.TensorBufferOverrides));
actual.TensorBufferOverrides = null!;
expected.TensorBufferOverrides = null!;

// Check encoding is the same
var b1 = expected.Encoding.GetBytes("Hello");
var b2 = actual.Encoding.GetBytes("Hello");
3 changes: 3 additions & 0 deletions LLama.Web/Common/ModelOptions.cs
@@ -26,6 +26,9 @@ public class ModelOptions
/// <inheritdoc />
public GPUSplitMode? SplitMode { get; set; }

/// <inheritdoc />
public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();

/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;

6 changes: 6 additions & 0 deletions LLama/Abstractions/IModelParams.cs
@@ -38,6 +38,12 @@ public interface IModelParams
/// </summary>
GPUSplitMode? SplitMode { get; }

/// <summary>
/// Buffer type overrides for specific tensor patterns, allowing you to specify hardware devices to use for individual tensors or sets of tensors.
/// Equivalent to --override-tensor (-ot) on the llama.cpp command line, or tensor_buft_overrides internally.
/// </summary>
List<TensorBufferOverride> TensorBufferOverrides { get; }

/// <summary>
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
/// </summary>
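To illustrate the new property, a minimal usage sketch (not part of this diff; the model path, GPU layer count, and regex pattern are illustrative assumptions):

using LLama.Abstractions;
using LLama.Common;

// Hypothetical example: offload all layers, but force every block's FFN
// tensors to stay on the CPU, mirroring llama.cpp's -ot "blk\..*\.ffn_.*=CPU".
var parameters = new ModelParams("model.gguf")
{
    GpuLayerCount = 99,
    TensorBufferOverrides =
    {
        new TensorBufferOverride(@"blk\..*\.ffn_.*", "CPU"),
    },
};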
36 changes: 36 additions & 0 deletions LLama/Abstractions/TensorBufferOverride.cs
@@ -0,0 +1,36 @@
using System;

namespace LLama.Abstractions
{
/// <summary>
/// Represents a mapping between a tensor name pattern and a specific buffer type
/// </summary>
public class TensorBufferOverride
{
/// <summary>
/// Pattern to match tensor names. This is a regular expression. You can check the tensor names via model.Metadata.
/// </summary>
public string Pattern { get; set; }

/// <summary>
/// Buffer type to use for matching tensors. Examples: CPU, GPU0, GPU1
/// </summary>
public string BufferType { get; set; }

/// <summary>
/// Creates a new tensor buffer override
/// </summary>
/// <param name="pattern">Pattern to match tensor names</param>
/// <param name="bufferType">Buffer type to use for matching tensors</param>
public TensorBufferOverride(string pattern, string bufferType)
{
if (string.IsNullOrEmpty(pattern))
throw new ArgumentException("Pattern cannot be null or empty", nameof(pattern));
if (string.IsNullOrEmpty(bufferType))
throw new ArgumentException("Buffer type cannot be null or empty", nameof(bufferType));

Pattern = pattern;
BufferType = bufferType;
}
}
}
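A quick sketch of the constructor contract (illustrative values, not from this diff):

// Valid: pin matching attention tensors to the first GPU's buffer type.
var ok = new TensorBufferOverride(@"blk\.\d+\.attn_output\.weight", "GPU0");

// Both of the following throw ArgumentException:
// new TensorBufferOverride("", "GPU0");     // empty pattern
// new TensorBufferOverride(@"blk\..*", ""); // empty buffer type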
3 changes: 3 additions & 0 deletions LLama/Common/ModelParams.cs
@@ -21,6 +21,9 @@ public record ModelParams
/// <inheritdoc />
public GPUSplitMode? SplitMode { get; set; }

/// <inheritdoc />
public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();

/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;

14 changes: 14 additions & 0 deletions LLama/Extensions/IModelParamsExtensions.cs
@@ -45,6 +45,20 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
}

// Add tensor buffer overrides, if any
if (@params.TensorBufferOverrides.Count > 0)
{
var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper();
disposer.Add(bufferOverrideHelper);

foreach (var tensorOverride in @params.TensorBufferOverrides)
{
bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType);
}

bufferOverrideHelper.ApplyToModelParams(ref result);
}

if (@params.MetadataOverrides.Count == 0)
{
unsafe
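For context, a hedged sketch of how a caller drives this conversion; the IDisposable returned by ToLlamaModelParams keeps the pinned and unmanaged memory, including the override array, alive until disposed:

// `managedParams` is any IModelParams implementation, e.g. ModelParams.
using (var disposer = managedParams.ToLlamaModelParams(out LLamaModelParams native))
{
    // native.tensor_buft_overrides now points to a NULL-terminated array of
    // llama_model_tensor_buft_override entries (or is IntPtr.Zero if none were set).
    // ... pass `native` to the native model-loading call here ...
}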
11 changes: 6 additions & 5 deletions LLama/Native/LLamaModelParams.cs
@@ -13,11 +13,12 @@ public unsafe struct LLamaModelParams
/// todo: add support for llama_model_params.devices
/// </summary>
private IntPtr devices;

// NULL-terminated list of buffer types to use for tensors that match a pattern
// actual type: llama_model_tensor_buft_override*
// todo: add support for tensor_buft_overrides
private IntPtr tensor_buft_overrides;

/// <summary>
/// NULL-terminated list of buffer types to use for tensors that match a pattern
/// actual type: llama_model_tensor_buft_override*
/// </summary>
public IntPtr tensor_buft_overrides;

/// <summary>
/// Number of layers to store in VRAM
22 changes: 22 additions & 0 deletions LLama/Native/LLamaModelTensorBufferOverride.cs
@@ -0,0 +1,22 @@
using System;
using System.Runtime.InteropServices;

namespace LLama.Native
{
/// <summary>
/// Represents a mapping between a tensor name pattern and a backend buffer type<br/>
/// Original type: llama_model_tensor_buft_override
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaModelTensorBufferOverride
{
/// <summary>
/// Tensor name pattern to match
/// </summary>
public IntPtr Pattern;

/// <summary>
/// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type
/// </summary>
public IntPtr BufferType;
}
}
135 changes: 135 additions & 0 deletions LLama/Native/LLamaTensorBufferOverrideHelper.cs
@@ -0,0 +1,135 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace LLama.Native
{
/// <summary>
/// Helper for creating and managing tensor buffer overrides
/// </summary>
internal class LLamaTensorBufferOverrideHelper : IDisposable
{
private readonly List<IntPtr> _allocatedMemory = new();
private readonly List<LLamaModelTensorBufferOverride> _overrides = new();
private IntPtr _overrideArray = IntPtr.Zero;
private readonly Dictionary<string, IntPtr> _bufferTypeCache = new();

/// <summary>
/// Get all available buffer types
/// </summary>
/// <returns>Dictionary mapping buffer type names to their handles</returns>
public Dictionary<string, IntPtr> GetAvailableBufferTypes()
{
var result = new Dictionary<string, IntPtr>();

nuint count = NativeApi.ggml_backend_dev_count();
for (nuint i = 0; i < count; i++)
{
IntPtr dev = NativeApi.ggml_backend_dev_get(i);
IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev);

if (buft != IntPtr.Zero)
{
IntPtr namePtr = NativeApi.ggml_backend_buft_name(buft);
string name = Marshal.PtrToStringAnsi(namePtr) ?? string.Empty;

if (!string.IsNullOrEmpty(name))
{
result[name] = buft;
_bufferTypeCache[name] = buft;
}
}
}

return result;
}

/// <summary>
/// Add a tensor buffer override
/// </summary>
/// <param name="pattern">Tensor name pattern to match</param>
/// <param name="bufferTypeName">Name of the buffer type to use</param>
/// <returns>True if the override was added successfully</returns>
public bool AddOverride(string pattern, string bufferTypeName)
{
if (string.IsNullOrEmpty(pattern) || string.IsNullOrEmpty(bufferTypeName))
return false;

// Get all buffer types if cache is empty
if (_bufferTypeCache.Count == 0)
{
GetAvailableBufferTypes();
}

// Check if we have this buffer type
if (!_bufferTypeCache.TryGetValue(bufferTypeName, out IntPtr bufferType))
return false;

// Allocate memory for the pattern string and keep track of it
byte[] patternBytes = Encoding.UTF8.GetBytes(pattern + "\0");
IntPtr patternPtr = Marshal.AllocHGlobal(patternBytes.Length);
Marshal.Copy(patternBytes, 0, patternPtr, patternBytes.Length);
_allocatedMemory.Add(patternPtr);

// Create the override
var @override = new LLamaModelTensorBufferOverride
{
Pattern = patternPtr,
BufferType = bufferType
};

_overrides.Add(@override);
return true;
}

/// <summary>
/// Apply the overrides to model parameters
/// </summary>
/// <param name="modelParams">Model parameters to update</param>
public unsafe void ApplyToModelParams(ref LLamaModelParams modelParams)
{
if (_overrides.Count == 0)
{
modelParams.tensor_buft_overrides = IntPtr.Zero;
return;
}

// Free the previous array if it exists, and remove it from the tracked
// allocations so Dispose doesn't free the same pointer twice
if (_overrideArray != IntPtr.Zero)
{
_allocatedMemory.Remove(_overrideArray);
Marshal.FreeHGlobal(_overrideArray);
}

// Allocate memory for the array + null terminator
int size = Marshal.SizeOf<LLamaModelTensorBufferOverride>() * (_overrides.Count + 1);
_overrideArray = Marshal.AllocHGlobal(size);
_allocatedMemory.Add(_overrideArray);

// Copy overrides to array
for (int i = 0; i < _overrides.Count; i++)
{
IntPtr elemPtr = IntPtr.Add(_overrideArray, i * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
Marshal.StructureToPtr(_overrides[i], elemPtr, false);
}

// Add null terminator
IntPtr nullTermPtr = IntPtr.Add(_overrideArray, _overrides.Count * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
Marshal.StructureToPtr(new LLamaModelTensorBufferOverride { Pattern = IntPtr.Zero, BufferType = IntPtr.Zero }, nullTermPtr, false);

// Update model params
modelParams.tensor_buft_overrides = _overrideArray;
}

/// <inheritdoc />
public void Dispose()
{
foreach (IntPtr ptr in _allocatedMemory)
{
Marshal.FreeHGlobal(ptr);
}
_allocatedMemory.Clear();
_overrides.Clear();
_overrideArray = IntPtr.Zero;
}
}
}
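A sketch of the helper's intended lifecycle (assumes `nativeParams` is an existing LLamaModelParams and that ggml reports a buffer type named "CPU"):

var helper = new LLamaTensorBufferOverrideHelper();
try
{
    // Discover what the loaded backends expose, e.g. "CPU", "CUDA0", ...
    foreach (var name in helper.GetAvailableBufferTypes().Keys)
        Console.WriteLine(name);

    // AddOverride returns false if the buffer type name is unknown.
    if (helper.AddOverride(@"blk\.\d+\.ffn_.*", "CPU"))
        helper.ApplyToModelParams(ref nativeParams);
}
finally
{
    helper.Dispose(); // frees the pattern strings and the override array
}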
2 changes: 2 additions & 0 deletions LLama/Native/NativeApi.Load.cs
@@ -107,6 +107,8 @@ private static void SetDllImportResolver()

internal const string libraryName = "llama";
internal const string llavaLibraryName = "llava_shared";
internal const string ggmlLibraryName = "ggml";
internal const string ggmlBaseLibraryName = "ggml-base";

private static INativeLibrary? _loadedLLamaLibrary = null;
private static INativeLibrary? _loadedLLavaLibrary = null;
31 changes: 31 additions & 0 deletions LLama/Native/NativeApi.cs
Expand Up @@ -291,7 +291,7 @@
internal static extern void llama_kv_self_clear(SafeLLamaContextHandle ctx);

[Obsolete("Use `llama_kv_self_clear` instead")]
/// <summary>
/// Clear the KV cache. Both cell info is erased and KV data is zeroed
/// </summary>
/// <param name="ctx"></param>

Check warning on line 294 in LLama/Native/NativeApi.cs (GitHub Actions / Test: osx-release, windows-release, linux-release): XML comment is not placed on a valid language element
@@ -447,5 +447,36 @@
// it would expose the raw pointer to the model, without properly wrapping it in a SafeLLamaModelHandle.
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//public static void llama_model* llama_get_model(SafeLLamaContextHandle ctx);

/// <summary>
/// Get the number of available backend devices
/// </summary>
/// <returns>Count of available backend devices</returns>
[DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern nuint ggml_backend_dev_count();

/// <summary>
/// Get a backend device by index
/// </summary>
/// <param name="i">Device index</param>
/// <returns>Pointer to the backend device</returns>
[DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_dev_get(nuint i);

/// <summary>
/// Get the buffer type for a backend device
/// </summary>
/// <param name="dev">Backend device pointer</param>
/// <returns>Pointer to the buffer type</returns>
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_dev_buffer_type(IntPtr dev);

/// <summary>
/// Get the name of a buffer type
/// </summary>
/// <param name="buft">Buffer type pointer</param>
/// <returns>Name of the buffer type</returns>
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_buft_name(IntPtr buft);
}
}
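The new bindings can also be exercised directly; a small sketch that lists the buffer type name reported for each backend device:

using System;
using System.Runtime.InteropServices;
using LLama.Native;

nuint count = NativeApi.ggml_backend_dev_count();
for (nuint i = 0; i < count; i++)
{
    IntPtr dev = NativeApi.ggml_backend_dev_get(i);
    IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev);
    if (buft != IntPtr.Zero)
    {
        string? name = Marshal.PtrToStringAnsi(NativeApi.ggml_backend_buft_name(buft));
        Console.WriteLine($"device {i}: {name}");
    }
}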