forked from SciSharp/LLamaSharp
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathLlamaSharpTextGenerator.cs
More file actions
138 lines (123 loc) · 5.66 KB
/
LlamaSharpTextGenerator.cs
File metadata and controls
138 lines (123 loc) · 5.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
using LLama;
using LLama.Common;
using LLama.Sampling;
using Microsoft.KernelMemory.AI;
namespace LLamaSharp.KernelMemory
{
    /// <summary>
    /// Provides text generation for LLamaSharp.
    /// </summary>
    public sealed class LlamaSharpTextGenerator
        : ITextGenerator, IDisposable
    {
        // Runs inference for GenerateTextAsync.
        private readonly StatelessExecutor _executor;

        private readonly LLamaWeights _weights;
        // True when this instance loaded _weights itself and is therefore responsible for disposing them.
        private readonly bool _ownsWeights;

        // Used for tokenization (CountTokens / GetTokens) and to report MaxTokenTotal.
        private readonly LLamaContext _context;
        // True when this instance created _context itself and is therefore responsible for disposing it.
        private readonly bool _ownsContext;

        // Optional defaults that per-request TextGenerationOptions are layered over (see OptionsToParams).
        private readonly InferenceParams? _defaultInferenceParams;

        /// <inheritdoc/>
        public int MaxTokenTotal { get; }

        /// <summary>
        /// Initializes a new instance of the <see cref="LlamaSharpTextGenerator"/> class.
        /// Loads the model weights and creates a context and a <see cref="StatelessExecutor"/>;
        /// the weights and context are owned, and disposed, by this instance.
        /// </summary>
        /// <param name="config">The configuration for LLamaSharp.</param>
        /// <remarks>
        /// When not set in <paramref name="config"/>, the context size defaults to 2048 and the GPU layer count to 20.
        /// </remarks>
        public LlamaSharpTextGenerator(LLamaSharpConfig config)
        {
            var parameters = new ModelParams(config.ModelPath)
            {
                ContextSize = config.ContextSize ?? 2048,
                GpuLayerCount = config.GpuLayerCount ?? 20,
                MainGpu = config.MainGpu,
                SplitMode = config.SplitMode
            };

            _weights = LLamaWeights.LoadFromFile(parameters);
            try
            {
                _context = _weights.CreateContext(parameters);
                _executor = new StatelessExecutor(_weights, parameters);
            }
            catch
            {
                // A failed constructor means Dispose will never be called on this instance,
                // so release whatever has been allocated so far before propagating the error.
                _context?.Dispose();
                _weights.Dispose();
                throw;
            }

            _defaultInferenceParams = config.DefaultInferenceParams;
            _ownsWeights = _ownsContext = true;

            // Report the size of the context actually created, consistent with the other constructor.
            MaxTokenTotal = (int)_context.ContextSize;
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="LlamaSharpTextGenerator"/> class from reused weights, context and executor.
        /// If executor is not specified, then a StatelessExecutor will be created with `context.Params`. So far only `StatelessExecutor` is expected.
        /// The weights and context are NOT disposed by this instance; the caller keeps ownership of them.
        /// </summary>
        /// <param name="weights">A LLamaWeights object.</param>
        /// <param name="context">A LLamaContext object.</param>
        /// <param name="executor">An executor. Currently only StatelessExecutor is expected.</param>
        /// <param name="inferenceParams">Inference parameters to use by default</param>
        /// <exception cref="ArgumentNullException"><paramref name="weights"/> or <paramref name="context"/> is null.</exception>
        public LlamaSharpTextGenerator(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null, InferenceParams? inferenceParams = null)
        {
            _weights = weights ?? throw new ArgumentNullException(nameof(weights));
            _context = context ?? throw new ArgumentNullException(nameof(context));
            _executor = executor ?? new StatelessExecutor(_weights, _context.Params);
            _defaultInferenceParams = inferenceParams;
            MaxTokenTotal = (int)_context.ContextSize;
            // _ownsWeights / _ownsContext intentionally stay false: these objects belong to the caller.
        }

        /// <inheritdoc/>
        public void Dispose()
        {
            // The context is created from the weights, so release it first and the weights last.
            if (_ownsContext)
            {
                _context.Dispose();
            }

            if (_ownsWeights)
            {
                _weights.Dispose();
            }
        }

        /// <inheritdoc/>
        public IAsyncEnumerable<string> GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default)
        {
            var inferenceParams = OptionsToParams(options, _defaultInferenceParams);
            return _executor.InferAsync(prompt, inferenceParams, cancellationToken: cancellationToken);
        }

        /// <summary>
        /// Translates Kernel Memory <see cref="TextGenerationOptions"/> into LLamaSharp <see cref="InferenceParams"/>.
        /// </summary>
        /// <param name="options">Per-request generation options.</param>
        /// <param name="defaultParams">
        /// Optional defaults. When present, a copy is returned in which the stop sequences are appended to the default
        /// anti-prompts, MaxTokens falls back to the default value, and the sampling pipeline is replaced by one built
        /// from <paramref name="options"/>. When absent, MaxTokens falls back to 1024.
        /// </param>
        /// <returns>A new <see cref="InferenceParams"/>; <paramref name="defaultParams"/> itself is not modified.</returns>
        private static InferenceParams OptionsToParams(TextGenerationOptions options, InferenceParams? defaultParams)
        {
            if (defaultParams is null)
            {
                return new InferenceParams
                {
                    AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
                    MaxTokens = options.MaxTokens ?? 1024,
                    SamplingPipeline = CreateSamplingPipeline(options),
                };
            }

            // NOTE(review): any SamplingPipeline configured in defaultParams is discarded here — confirm this is intended.
            return defaultParams with
            {
                AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(),
                MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens,
                SamplingPipeline = CreateSamplingPipeline(options),
            };
        }

        /// <summary>
        /// Builds a <see cref="DefaultSamplingPipeline"/> from the sampling-related fields of <paramref name="options"/>.
        /// </summary>
        /// <param name="options">Per-request generation options supplying temperature, penalties and top-p.</param>
        /// <returns>A new sampling pipeline.</returns>
        private static DefaultSamplingPipeline CreateSamplingPipeline(TextGenerationOptions options)
        {
            return new DefaultSamplingPipeline
            {
                Temperature = (float)options.Temperature,
                AlphaFrequency = (float)options.FrequencyPenalty,
                AlphaPresence = (float)options.PresencePenalty,
                TopP = (float)options.NucleusSampling,
            };
        }

        /// <inheritdoc/>
        public int CountTokens(string text) => _context.Tokenize(text, special: true).Length;

        /// <summary>
        /// Get the list of tokens for the input text
        /// </summary>
        /// <param name="text">Input string to be tokenized</param>
        /// <returns>Read-only list of tokens for the input text</returns>
        /// <remarks>
        /// Throws if <paramref name="text"/> is null. Tokenization uses the same call as <see cref="CountTokens(string)"/>
        /// (addBos left at its default of true, special tokens enabled), so the result has one entry per counted token,
        /// including an entry for the BOS token (which may decode to an empty string).
        /// </remarks>
        /// <see cref="CountTokens(string)"/>
        public IReadOnlyList<string> GetTokens(string text)
        {
            /* see relevant unit tests for important implementation notes regarding unicode */
            var numericTokens = _context.Tokenize(text, special: true);
            var decoder = new StreamingTokenDecoder(_context);
            return numericTokens
                .Select(x => { decoder.Add(x); return decoder.Read(); })
                .ToList();
        }
    }
}