Skip to content

Commit eb49dd8

Browse files
committed
Added sampling tests
1 parent 4436f7e commit eb49dd8

1 file changed

Lines changed: 196 additions & 0 deletions

File tree

LLama.Unittest/SamplingTests.cs

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
using LLama.Common;
2+
using LLama.Native;
3+
4+
using System.Numerics.Tensors;
5+
using System.Runtime.InteropServices;
6+
using System.Text;
7+
8+
using Xunit.Abstractions;
9+
10+
namespace LLama.Unittest
11+
{
12+
public class SamplingTests : IDisposable
13+
{
14+
private readonly ITestOutputHelper _testOutputHelper;
15+
private readonly LLamaWeights _model;
16+
private readonly ModelParams _params;
17+
18+
private readonly LLamaBatch _batch;
19+
private readonly StreamingTokenDecoder _decoder;
20+
21+
public void Dispose() => _model.Dispose();
22+
23+
private unsafe Span<float> GetLogits(LLamaContext context, int totalSequences) => new(llama_get_logits(context.NativeHandle), totalSequences * _model.VocabCount);
24+
[DllImport("llama", CallingConvention = CallingConvention.Cdecl)] public unsafe static extern float* llama_get_logits(SafeLLamaContextHandle ctx);
25+
26+
public SamplingTests(ITestOutputHelper testOutputHelper)
27+
{
28+
_testOutputHelper = testOutputHelper;
29+
_params = new ModelParams(Constants.GenerativeModelPath) {
30+
ContextSize = 200,
31+
BatchSize = 2,
32+
GpuLayerCount = Constants.CIGpuLayerCount,
33+
};
34+
_model = LLamaWeights.LoadFromFile(_params);
35+
_batch = new LLamaBatch();
36+
_decoder = new(Encoding.UTF8, _model);
37+
}
38+
39+
40+
[Fact]
41+
public void Sampling()
42+
{
43+
using var context = new LLamaContext(_model, _params);
44+
var tokens = _model.NativeHandle.Tokenize("I will repeat this phrase forever.\n", false, false, Encoding.UTF8);
45+
var logitBias = tokens.Select(x => new LLamaLogitBias() { Token = x, Bias = -1000 }).ToArray();
46+
47+
// Add "I will repeat this phrase forever.\nI will", without requesting any logits.
48+
for (int i = 0; i < tokens.Length; i++) { _batch.Add(token: tokens[i], pos: i, sequence: LLamaSeqId.Zero, logits: false); }
49+
for (int i = 0; i < 2; i++) { _batch.Add(token: tokens[i], pos: tokens.Length + i, sequence: LLamaSeqId.Zero, logits: false); }
50+
51+
// Add " repeat" and test whether next tokens will be "this phrase forever.".
52+
for (int i = 0; i < 4; i++)
53+
{
54+
_batch.Add(token: tokens[i + 2], pos: tokens.Length + i + 2, sequence: LLamaSeqId.Zero, logits: true);
55+
DecodeAndClear(context);
56+
57+
var expected = tokens[i + 3];
58+
var logits = GetLogits(context, totalSequences: 1);
59+
60+
// Test raw sampling
61+
Assert.Equal(expected, TensorPrimitives.IndexOfMax(logits));
62+
63+
// Test native sampling with `LLamaTokenDataArrayNative`.
64+
var array = LLamaTokenDataArray.Create(logits);
65+
{
66+
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
67+
var rawLogits = new float[_model.VocabCount];
68+
for (int j = 0; j < cur_p.Data.Length; j++)
69+
{
70+
rawLogits[(int) cur_p.Data[j].ID] = cur_p.Data[j].Logit;
71+
}
72+
Assert.Equal(expected, TensorPrimitives.IndexOfMax(rawLogits));
73+
}
74+
75+
// Test sampling chain
76+
{
77+
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
78+
using var chain = CreateChain(context.NativeHandle);
79+
chain.Apply(ref cur_p);
80+
Assert.Equal(expected, cur_p.Data[(int) cur_p.Selected].ID);
81+
}
82+
83+
// Test logit bias
84+
{
85+
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
86+
using var chain = CreateChain(context.NativeHandle, logitBias);
87+
chain.Apply(ref cur_p);
88+
Assert.NotEqual(expected, cur_p.Data[(int) cur_p.Selected].ID);
89+
}
90+
}
91+
}
92+
93+
94+
[Fact]
95+
public void BatchedSampling()
96+
{
97+
const int batch_count = 4;
98+
using var context = new LLamaContext(_model, _params);
99+
var tokens = _model.NativeHandle.Tokenize("I will repeat this phrase forever.\n", false, false, Encoding.UTF8);
100+
var logitBias = tokens.Select(x => new LLamaLogitBias() { Token = x, Bias = -1000 }).ToArray();
101+
102+
// Add "I will repeat this phrase forever.\nI will", without requesting any logits.
103+
for (int i = 0; i < tokens.Length; i++)
104+
{
105+
for (int b = 0; b < batch_count; b++)
106+
{
107+
_batch.Add(token: tokens[i], pos: i, sequence: (LLamaSeqId) b, logits: false);
108+
}
109+
}
110+
for (int i = 0; i < 2; i++)
111+
{
112+
for (int b = 0; b < batch_count; b++)
113+
{
114+
_batch.Add(token: tokens[i], pos: tokens.Length + i, sequence: (LLamaSeqId) b, logits: false);
115+
}
116+
}
117+
118+
// Add " repeat" and test whether next tokens will be "this phrase forever.".
119+
for (int i = 0; i < 4; i++)
120+
{
121+
for (int b = 0; b < batch_count; b++)
122+
{
123+
_batch.Add(token: tokens[i + 2], pos: tokens.Length + i + 2, sequence: (LLamaSeqId) b, logits: true);
124+
}
125+
DecodeAndClear(context);
126+
127+
var expected = tokens[i + 3];
128+
var all_logits = GetLogits(context, totalSequences: batch_count);
129+
130+
for (int b = 0; b < batch_count; b++)
131+
{
132+
var logits = all_logits.Slice(b * _model.VocabCount, _model.VocabCount);
133+
134+
// Test raw sampling
135+
Assert.Equal(expected, TensorPrimitives.IndexOfMax(logits));
136+
137+
// Test native sampling with `LLamaTokenDataArrayNative`.
138+
var array = LLamaTokenDataArray.Create(logits);
139+
{
140+
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
141+
var rawLogits = new float[_model.VocabCount];
142+
for (int j = 0; j < cur_p.Data.Length; j++)
143+
{
144+
rawLogits[(int) cur_p.Data[j].ID] = cur_p.Data[j].Logit;
145+
}
146+
Assert.Equal(expected, TensorPrimitives.IndexOfMax(rawLogits));
147+
}
148+
149+
// Test sampling chain
150+
{
151+
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
152+
using var chain = CreateChain(context.NativeHandle);
153+
chain.Apply(ref cur_p);
154+
Assert.Equal(expected, cur_p.Data[(int) cur_p.Selected].ID);
155+
}
156+
157+
// Test logit bias
158+
{
159+
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
160+
using var chain = CreateChain(context.NativeHandle, logitBias);
161+
chain.Apply(ref cur_p);
162+
Assert.NotEqual(expected, cur_p.Data[(int) cur_p.Selected].ID);
163+
}
164+
}
165+
}
166+
}
167+
168+
169+
private void DecodeAndClear(LLamaContext context)
170+
{
171+
context.Decode(_batch);
172+
_batch.Clear();
173+
}
174+
175+
private static SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context, LLamaLogitBias[]? logit_bias = null)
176+
{
177+
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
178+
179+
chain.AddPenalties(
180+
vocabSize: context.VocabCount,
181+
eos: context.ModelHandle.Tokens.EOS,
182+
newline: context.ModelHandle.Tokens.Newline ?? 0,
183+
penaltyCount: 60, repeat: 1, freq: 0, presence: 0,
184+
penalizeNewline: false, ignoreEOS: false
185+
);
186+
187+
if (logit_bias != null) { chain.AddLogitBias(context.VocabCount, logit_bias); }
188+
189+
chain.AddTopK(10);
190+
chain.AddTemperature(0.1f);
191+
chain.AddDistributionSampler(seed: 42);
192+
193+
return chain;
194+
}
195+
}
196+
}

0 commit comments

Comments
 (0)