-
Notifications
You must be signed in to change notification settings - Fork 490
Add Microsoft.Extensions.AI support for IChatClient / IEmbeddingGenerator #964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.Linq; | ||
| using System.Runtime.CompilerServices; | ||
| using System.Text; | ||
| using System.Threading; | ||
| using System.Threading.Tasks; | ||
| using LLama.Common; | ||
| using LLama.Sampling; | ||
| using Microsoft.Extensions.AI; | ||
|
|
||
| namespace LLama.Abstractions; | ||
|
|
||
/// <summary>
/// Extension methods for the <see cref="ILLamaExecutor"/> interface, providing
/// Microsoft.Extensions.AI <see cref="IChatClient"/> support.
/// </summary>
public static class LLamaExecutorExtensions
{
    /// <summary>Gets an <see cref="IChatClient"/> instance for the specified <see cref="ILLamaExecutor"/>.</summary>
    /// <param name="executor">The executor.</param>
    /// <param name="historyTransform">The <see cref="IHistoryTransform"/> to use to transform an input list of messages into a prompt.</param>
    /// <param name="outputTransform">The <see cref="ITextStreamTransform"/> to use to transform the output into text.</param>
    /// <returns>An <see cref="IChatClient"/> instance for the provided <see cref="ILLamaExecutor"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="executor"/> is null.</exception>
    public static IChatClient AsChatClient(
        this ILLamaExecutor executor,
        IHistoryTransform? historyTransform = null,
        ITextStreamTransform? outputTransform = null) =>
        new LLamaExecutorChatClient(
            executor ?? throw new ArgumentNullException(nameof(executor)),
            historyTransform,
            outputTransform);

    /// <summary>An <see cref="IChatClient"/> implementation that wraps an <see cref="ILLamaExecutor"/>.</summary>
    private sealed class LLamaExecutorChatClient(
        ILLamaExecutor executor,
        IHistoryTransform? historyTransform = null,
        ITextStreamTransform? outputTransform = null) : IChatClient
    {
        // Shared sources of default values for inference/sampling parameters.
        private static readonly InferenceParams s_defaultParams = new();
        private static readonly DefaultSamplingPipeline s_defaultPipeline = new();
        private static readonly string[] s_antiPrompts = ["User:", "Assistant:", "System:"];
        // Per-thread Random, used only to pick a sampling seed when the caller doesn't supply one.
        [ThreadStatic]
        private static Random? t_random;

        private readonly ILLamaExecutor _executor = executor;
        private readonly IHistoryTransform _historyTransform = historyTransform ?? new AppendAssistantHistoryTransform();
        private readonly ITextStreamTransform _outputTransform = outputTransform ??
            new LLamaTransforms.KeywordTextOutputStreamTransform(s_antiPrompts);

        /// <inheritdoc/>
        public ChatClientMetadata Metadata { get; } = new(nameof(LLamaExecutorChatClient));

        /// <inheritdoc/>
        public void Dispose() { }

        /// <inheritdoc/>
        public TService? GetService<TService>(object? key = null) where TService : class =>
            typeof(TService) == typeof(ILLamaExecutor) ? (TService)_executor :
            this as TService;

        /// <inheritdoc/>
        public async Task<ChatCompletion> CompleteAsync(
            IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
        {
            var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);

            // Accumulate the whole (transformed) response before returning a single completion.
            StringBuilder text = new();
            await foreach (var token in _outputTransform.TransformAsync(result).WithCancellation(cancellationToken))
            {
                text.Append(token);
            }

            return new(new ChatMessage(ChatRole.Assistant, text.ToString()))
            {
                CreatedAt = DateTime.UtcNow,
            };
        }

        /// <inheritdoc/>
        public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
            IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);

            // Surface each transformed token as its own streaming update.
            await foreach (var token in _outputTransform.TransformAsync(result).WithCancellation(cancellationToken))
            {
                yield return new()
                {
                    CreatedAt = DateTime.UtcNow,
                    Role = ChatRole.Assistant,
                    Text = token,
                };
            }
        }

        /// <summary>Format the chat messages into a string prompt.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="messages"/> is null.</exception>
        private string CreatePrompt(IList<ChatMessage> messages)
        {
            if (messages is null)
            {
                throw new ArgumentNullException(nameof(messages));
            }

            ChatHistory history = new();

            if (_executor is not StatefulExecutorBase seb ||
                seb.GetStateData() is InteractiveExecutor.InteractiveExecutorState { IsPromptRun: true })
            {
                // Either the executor keeps no state, or this is the first (prompt) run:
                // send the entire conversation so far.
                foreach (var message in messages)
                {
                    history.AddMessage(
                        message.Role == ChatRole.System ? AuthorRole.System :
                        message.Role == ChatRole.Assistant ? AuthorRole.Assistant :
                        AuthorRole.User,
                        string.Concat(message.Contents.OfType<TextContent>()));
                }
            }
            else
            {
                // Stateful executor that has already been primed with the prior conversation:
                // only the most recent message needs to be sent.
                history.AddMessage(AuthorRole.User, string.Concat(messages.LastOrDefault()?.Contents.OfType<TextContent>() ?? []));
            }

            return _historyTransform.HistoryToText(history);
        }

        /// <summary>Convert the chat options to inference parameters.</summary>
        /// <remarks>
        /// Values not expressible on <see cref="ChatOptions"/> directly may be passed via
        /// <see cref="ChatOptions.AdditionalProperties"/>, keyed by the target property name.
        /// </remarks>
        private static InferenceParams? CreateInferenceParams(ChatOptions? options)
        {
            // Always include the default anti-prompts; caller-supplied ones are appended.
            List<string> antiPrompts = new(s_antiPrompts);
            if (options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.AntiPrompts), out IReadOnlyList<string>? anti) is true)
            {
                antiPrompts.AddRange(anti);
            }

            return new()
            {
                AntiPrompts = antiPrompts,
                TokensKeep = options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.TokensKeep), out int tk) is true ? tk : s_defaultParams.TokensKeep,
                MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
                SamplingPipeline = new DefaultSamplingPipeline()
                {
                    AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
                    AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
                    PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool pe) is true ? pe : s_defaultPipeline.PenalizeEOS,
                    PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pn) is true ? pn : s_defaultPipeline.PenalizeNewline,
                    RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
                    RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
                    Grammar = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Grammar), out Grammar? g) is true ? g : s_defaultPipeline.Grammar,
                    MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
                    MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
                    // No caller-provided seed: draw a random one per call (thread-local Random).
                    Seed = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Seed), out uint seed) is true ? seed : (uint)(t_random ??= new()).Next(),
                    TailFreeZ = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TailFreeZ), out float tfz) is true ? tfz : s_defaultPipeline.TailFreeZ,
                    // NOTE(review): unlike TopK, these fall back to 0 rather than the pipeline
                    // defaults when unspecified — confirm this (greedy-leaning) default is intended.
                    Temperature = options?.Temperature ?? 0,
                    TopP = options?.TopP ?? 0,
                    TopK = options?.TopK ?? s_defaultPipeline.TopK,
                    TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
                },
            };
        }

        /// <summary>A default transform that appends "Assistant: " to the end.</summary>
        private sealed class AppendAssistantHistoryTransform : LLamaTransforms.DefaultHistoryTransform
        {
            public override string HistoryToText(ChatHistory history) =>
                $"{base.HistoryToText(history)}{AuthorRole.Assistant}: ";
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.Diagnostics; | ||
| using System.Threading; | ||
| using System.Threading.Tasks; | ||
| using LLama.Native; | ||
| using Microsoft.Extensions.AI; | ||
|
|
||
| namespace LLama; | ||
|
|
||
/// <summary>
/// Microsoft.Extensions.AI <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> support
/// for <see cref="LLamaEmbedder"/>.
/// </summary>
public partial class LLamaEmbedder
    : IEmbeddingGenerator<string, Embedding<float>>
{
    // Metadata is built lazily on first access and cached for the lifetime of the embedder.
    private EmbeddingGeneratorMetadata? _metadata;

    /// <inheritdoc />
    EmbeddingGeneratorMetadata IEmbeddingGenerator<string, Embedding<float>>.Metadata
    {
        get
        {
            if (_metadata is null)
            {
                // Pull the model's display name out of the GGUF metadata if present.
                var metadata = Context.NativeHandle.ModelHandle.ReadMetadata();
                _metadata = new(
                    nameof(LLamaEmbedder),
                    modelId: metadata.TryGetValue("general.name", out var modelName) ? modelName : null,
                    dimensions: EmbeddingSize);
            }

            return _metadata;
        }
    }

    /// <inheritdoc />
    TService? IEmbeddingGenerator<string, Embedding<float>>.GetService<TService>(object? key) where TService : class
    {
        if (typeof(TService) == typeof(LLamaContext))
        {
            return (TService)(object)Context;
        }

        return this as TService;
    }

    /// <inheritdoc />
    async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
    {
        // Pooling must be enabled to produce a single vector per input.
        if (Context.NativeHandle.PoolingType == LLamaPoolingType.None)
        {
            throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}.");
        }

        GeneratedEmbeddings<Embedding<float>> generated = new()
        {
            Usage = new() { InputTokenCount = 0 },
        };

        // Embed each input sequentially, accumulating the token count as we go.
        foreach (var input in values)
        {
            var (vectors, tokenCount) = await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false);
            Debug.Assert(vectors.Count == 1, "Should be one and only one embedding when pooling is enabled.");

            generated.Usage.InputTokenCount += tokenCount;
            generated.Add(new Embedding<float>(vectors[0]) { CreatedAt = DateTime.UtcNow });
        }

        generated.Usage.TotalTokenCount = generated.Usage.InputTokenCount;

        return generated;
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.