Skip to content

Commit b24581d

Browse files
committed
Added update and defrag methods for KV cache in SafeLLamaContextHandle
1 parent c933206 commit b24581d

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

LLama/Native/NativeApi.cs

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -374,23 +374,6 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
374374
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
375375
public static extern LLamaPos llama_kv_cache_seq_pos_max(SafeLLamaContextHandle ctx, LLamaSeqId seq);
376376

377-
/// <summary>
378-
/// Defragment the KV cache. This will be applied:
379-
/// - lazily on next llama_decode()
380-
/// - explicitly with llama_kv_cache_update()
381-
/// </summary>
382-
/// <param name="ctx"></param>
383-
/// <returns></returns>
384-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
385-
public static extern LLamaPos llama_kv_cache_defrag(SafeLLamaContextHandle ctx);
386-
387-
/// <summary>
388-
/// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
389-
/// </summary>
390-
/// <param name="ctx"></param>
391-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
392-
public static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx);
393-
394377
/// <summary>
395378
/// Allocates a batch of tokens on the heap
396379
/// Each token can be assigned up to n_seq_max sequence ids

LLama/Native/SafeLLamaContextHandle.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,23 @@ static SafeLLamaContextHandle()
264264
/// </returns>
265265
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
266266
private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, LLamaSeqId dest_seq_id);
267+
268+
/// <summary>
/// Defragment the KV cache. This will be applied:
///  - lazily on next llama_decode()
///  - explicitly with llama_kv_cache_update()
/// </summary>
/// <param name="ctx">Context whose KV cache should be defragmented</param>
// Library name is qualified with NativeApi, matching the other native declarations in this class.
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_kv_cache_defrag(SafeLLamaContextHandle ctx);
277+
278+
/// <summary>
/// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
/// </summary>
/// <param name="ctx">Context whose KV cache updates should be applied</param>
// Library name is qualified with NativeApi, matching the other native declarations in this class.
// NOTE(review): this declaration is public while the neighbouring native declarations are private;
// visibility is left unchanged here to avoid breaking callers - confirm whether it should be private.
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx);
267284
#endregion
268285

269286
/// <summary>
@@ -487,6 +504,25 @@ public void SetThreads(uint threads, uint threadsBatch)
487504
}
488505

489506
#region KV Cache Management
507+
/// <summary>
/// Apply KV cache updates (such as K-shifts, defragmentation, etc.) to this context
/// by forwarding to the native llama_kv_cache_update function.
/// </summary>
public void KvCacheUpdate() => llama_kv_cache_update(this);
514+
515+
/// <summary>
/// Defragment the KV cache of this context. This will be applied:
///  - lazily on next llama_decode()
///  - explicitly with llama_kv_cache_update() (see <see cref="KvCacheUpdate"/>)
/// </summary>
public void KvCacheDefrag() => llama_kv_cache_defrag(this);
525+
490526
/// <summary>
491527
/// Get a new KV cache view that can be used to debug the KV cache
492528
/// </summary>

0 commit comments

Comments
 (0)