diff --git a/examples/210-KM-without-builder/Program.cs b/examples/210-KM-without-builder/Program.cs index c5fa22324..888446b9b 100644 --- a/examples/210-KM-without-builder/Program.cs +++ b/examples/210-KM-without-builder/Program.cs @@ -6,6 +6,7 @@ using Microsoft.KernelMemory.AI.AzureOpenAI; using Microsoft.KernelMemory.AI.OpenAI; using Microsoft.KernelMemory.Configuration; +using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.DataFormats; using Microsoft.KernelMemory.DataFormats.AzureAIDocIntel; using Microsoft.KernelMemory.DataFormats.Image; @@ -73,6 +74,7 @@ public static async Task Main() LoggerFactory? loggerFactory = null; // Alternative: app.Services.GetService(); // Generic dependencies + var requestContextProvider = new RequestContextProvider(); var mimeTypeDetection = new MimeTypesDetection(); var promptProvider = new EmbeddedPromptProvider(); @@ -121,7 +123,7 @@ public static async Task Main() // Create memory instance var searchClient = new SearchClient(memoryDb, textGenerator, searchClientConfig, promptProvider, contentModeration, loggerFactory); - var memory = new MemoryServerless(orchestrator, searchClient, kernelMemoryConfig); + var memory = new MemoryServerless(orchestrator, searchClient, requestContextProvider, kernelMemoryConfig); // End-to-end test await memory.ImportTextAsync("I'm waiting for Godot", documentId: "tg01"); diff --git a/examples/212-dotnet-ollama/Program.cs b/examples/212-dotnet-ollama/Program.cs index 49ce8d6b2..c7492564c 100644 --- a/examples/212-dotnet-ollama/Program.cs +++ b/examples/212-dotnet-ollama/Program.cs @@ -3,6 +3,7 @@ using Microsoft.KernelMemory; using Microsoft.KernelMemory.AI.Ollama; using Microsoft.KernelMemory.AI.OpenAI; +using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.Diagnostics; /* This example shows how to use KM with Ollama @@ -49,19 +50,46 @@ public static async Task Main() // Generate an answer - This uses OpenAI for embeddings and finding relevant data, and LM Studio to generate an answer var answer = await memory.AskAsync("What's the current date (don't check for validity)?"); + Console.WriteLine("-------------------"); Console.WriteLine(answer.Question); Console.WriteLine(answer.Result); + Console.WriteLine("-------------------"); /* -- Output using phi3:medium-128k: What's the current date (don't check for validity)? + The given fact states that "Today is October 32nd, 2476." However, it appears to be an incorrect statement as there are never more than 31 days in any month. If we consider this date without checking its validity and accept the stated day of October as being 32, then the current date would be "October 32nd, 2476." However, it is important to note that this date does not align with our calendar system. */ + + // How to override config with Request Context + var context = new RequestContext(); + context.SetArg("custom_text_generation_model_name", "llama2:70b"); + // context.SetArg("custom_embedding_generation_model_name", "..."); + + answer = await memory.AskAsync("What's the current date (don't check for validity)?", context: context); + Console.WriteLine("-------------------"); + Console.WriteLine(answer.Question); + Console.WriteLine(answer.Result); + Console.WriteLine("-------------------"); + + /* + + -- Output using llama2:70b: + + What's the current date (don't check for validity)? + + The provided facts state that "Today is October 32nd, 2476." However, considering the Gregorian calendar system + commonly used today, this information appears to be incorrect as there are no such dates. This could + potentially refer to a different calendar or timekeeping system in use in your fictional world, but based on our + current understanding of calendars and dates, an "October 32nd" does not exist. Therefore, the answer is + 'INFO NOT FOUND'. + */ } } diff --git a/extensions/Anthropic/AnthropicTextGeneration.cs b/extensions/Anthropic/AnthropicTextGeneration.cs index 4d8597b95..2aa60eb26 100644 --- a/extensions/Anthropic/AnthropicTextGeneration.cs +++ b/extensions/Anthropic/AnthropicTextGeneration.cs @@ -8,6 +8,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI.Anthropic.Client; using Microsoft.KernelMemory.AI.OpenAI; +using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.Diagnostics; namespace Microsoft.KernelMemory.AI.Anthropic; @@ -23,6 +24,7 @@ public sealed class AnthropicTextGeneration : ITextGenerator, IDisposable private readonly RawAnthropicClient _client; private readonly ITextTokenizer _textTokenizer; + private readonly IContextProvider _contextProvider; private readonly HttpClient _httpClient; private readonly ILogger _log; private readonly string _modelName; @@ -34,11 +36,13 @@ public sealed class AnthropicTextGeneration : ITextGenerator, IDisposable /// Client configuration, including credentials and model details /// Tokenizer used to count tokens /// Optional factory used to inject a pre-configured HTTP client for requests to Anthropic API + /// Request context provider with runtime configuration overrides /// Optional factory used to inject configured loggers public AnthropicTextGeneration( AnthropicConfig config, ITextTokenizer? textTokenizer = null, IHttpClientFactory? httpClientFactory = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) { this._modelName = config.TextModelName; @@ -48,6 +52,7 @@ public AnthropicTextGeneration( this.MaxTokenTotal = config.MaxTokenOut; this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger(); + this._contextProvider = contextProvider ?? new RequestContextProvider(); if (httpClientFactory == null) { @@ -96,9 +101,11 @@ public async IAsyncEnumerable GenerateTextAsync( TextGenerationOptions options, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - this._log.LogTrace("Sending text generation request, model '{0}'", this._modelName); + string modelName = this._contextProvider.GetContext().GetCustomTextGenerationModelNameOrDefault(this._modelName); - CallClaudeStreamingParams parameters = new(this._modelName, prompt) + this._log.LogTrace("Sending text generation request, model '{0}'", modelName); + + CallClaudeStreamingParams parameters = new(modelName, prompt) { System = this._defaultSystemPrompt, Temperature = options.Temperature, diff --git a/extensions/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs b/extensions/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs index d588c9342..4a096c4b7 100644 --- a/extensions/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs @@ -17,6 +17,12 @@ namespace Microsoft.KernelMemory.AI.AzureOpenAI; +/// +/// Azure OpenAI connector +/// +/// Note: does not support model name override via request context +/// see https://github.com/microsoft/semantic-kernel/issues/9337 +/// [Experimental("KMEXP01")] public sealed class AzureOpenAITextEmbeddingGenerator : ITextEmbeddingGenerator, ITextEmbeddingBatchGenerator { diff --git a/extensions/AzureOpenAI/AzureOpenAITextGenerator.cs b/extensions/AzureOpenAI/AzureOpenAITextGenerator.cs index 5179804fb..cb6126901 100644 --- a/extensions/AzureOpenAI/AzureOpenAITextGenerator.cs +++ b/extensions/AzureOpenAI/AzureOpenAITextGenerator.cs @@ -15,6 +15,12 @@ namespace Microsoft.KernelMemory.AI.AzureOpenAI; +/// +/// Azure OpenAI connector +/// +/// Note: does not support model name override via request context +/// see https://github.com/microsoft/semantic-kernel/issues/9337 +/// [Experimental("KMEXP01")] public sealed class AzureOpenAITextGenerator : ITextGenerator { diff --git a/extensions/ONNX/Onnx/OnnxTextGenerator.cs b/extensions/ONNX/Onnx/OnnxTextGenerator.cs index f0dbab49c..1a51a5108 100644 --- a/extensions/ONNX/Onnx/OnnxTextGenerator.cs +++ b/extensions/ONNX/Onnx/OnnxTextGenerator.cs @@ -19,6 +19,8 @@ namespace Microsoft.KernelMemory.AI.Onnx; /// /// Text generator based on ONNX models, via OnnxRuntimeGenAi /// See https://github.com/microsoft/onnxruntime-genai +/// +/// Note: does not support model name override via request context /// [Experimental("KMEXP01")] public sealed class OnnxTextGenerator : ITextGenerator, IDisposable diff --git a/extensions/Ollama/Ollama/DependencyInjection.cs b/extensions/Ollama/Ollama/DependencyInjection.cs index 618fc3497..93c9d5667 100644 --- a/extensions/Ollama/Ollama/DependencyInjection.cs +++ b/extensions/Ollama/Ollama/DependencyInjection.cs @@ -5,6 +5,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI; using Microsoft.KernelMemory.AI.Ollama; +using Microsoft.KernelMemory.Context; using OllamaSharp; #pragma warning disable IDE0130 // reduce number of "using" statements @@ -72,6 +73,7 @@ public static IServiceCollection AddOllamaTextGeneration( new OllamaApiClient(new Uri(endpoint), modelName), new OllamaModelConfig { ModelName = modelName }, textTokenizer, + serviceProvider.GetService(), serviceProvider.GetService())); } @@ -86,6 +88,7 @@ public static IServiceCollection AddOllamaTextGeneration( new OllamaApiClient(new Uri(config.Endpoint), config.TextModel.ModelName), config.TextModel, textTokenizer, + serviceProvider.GetService(), serviceProvider.GetService())); } @@ -101,6 +104,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( new OllamaApiClient(new Uri(endpoint), modelName), new OllamaModelConfig { ModelName = modelName }, textTokenizer, + serviceProvider.GetService(), serviceProvider.GetService())); } @@ -115,6 +119,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( new OllamaApiClient(new Uri(config.Endpoint), config.EmbeddingModel.ModelName), config.EmbeddingModel, textTokenizer, + serviceProvider.GetService(), serviceProvider.GetService())); } } diff --git a/extensions/Ollama/Ollama/OllamaTextEmbeddingGenerator.cs b/extensions/Ollama/Ollama/OllamaTextEmbeddingGenerator.cs index 9719648bc..ccb708f9e 100644 --- a/extensions/Ollama/Ollama/OllamaTextEmbeddingGenerator.cs +++ b/extensions/Ollama/Ollama/OllamaTextEmbeddingGenerator.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI.OpenAI; +using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.Diagnostics; using OllamaSharp; using OllamaSharp.Models; @@ -20,8 +21,9 @@ public class OllamaTextEmbeddingGenerator : ITextEmbeddingGenerator, ITextEmbedd private readonly IOllamaApiClient _client; private readonly OllamaModelConfig _modelConfig; - private readonly ILogger _log; private readonly ITextTokenizer _textTokenizer; + private readonly IContextProvider _contextProvider; + private readonly ILogger _log; public int MaxTokens { get; } @@ -31,6 +33,7 @@ public OllamaTextEmbeddingGenerator( IOllamaApiClient ollamaClient, OllamaModelConfig modelConfig, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) { this._client = ollamaClient; @@ -47,6 +50,7 @@ public OllamaTextEmbeddingGenerator( } this._textTokenizer = textTokenizer; + this._contextProvider = contextProvider ?? new RequestContextProvider(); this.MaxTokens = modelConfig.MaxTokenTotal ?? MaxTokensIfUndefined; } @@ -54,11 +58,13 @@ public OllamaTextEmbeddingGenerator( public OllamaTextEmbeddingGenerator( OllamaConfig config, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) : this( new OllamaApiClient(new Uri(config.Endpoint), config.EmbeddingModel.ModelName), config.EmbeddingModel, textTokenizer, + contextProvider, loggerFactory) { } @@ -67,11 +73,13 @@ public OllamaTextEmbeddingGenerator( HttpClient httpClient, OllamaConfig config, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) : this( new OllamaApiClient(httpClient, config.EmbeddingModel.ModelName), config.EmbeddingModel, textTokenizer, + contextProvider, loggerFactory) { } @@ -104,11 +112,13 @@ public async Task GenerateEmbeddingBatchAsync( CancellationToken cancellationToken = default) { var list = textList.ToList(); - this._log.LogTrace("Generating embeddings batch, size {0} texts", list.Count); + + string modelName = this._contextProvider.GetContext().GetCustomEmbeddingGenerationModelNameOrDefault(this._client.SelectedModel); + this._log.LogTrace("Generating embeddings batch, size {0} texts, with model {1}", list.Count, modelName); var request = new EmbedRequest { - Model = this._client.SelectedModel, + Model = modelName, Input = list, Options = new RequestOptions { diff --git a/extensions/Ollama/Ollama/OllamaTextGenerator.cs b/extensions/Ollama/Ollama/OllamaTextGenerator.cs index c5bf02eb7..e213a5c95 100644 --- a/extensions/Ollama/Ollama/OllamaTextGenerator.cs +++ b/extensions/Ollama/Ollama/OllamaTextGenerator.cs @@ -7,6 +7,7 @@ using System.Threading; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI.OpenAI; +using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.Diagnostics; using OllamaSharp; using OllamaSharp.Models; @@ -19,8 +20,9 @@ public class OllamaTextGenerator : ITextGenerator private readonly IOllamaApiClient _client; private readonly OllamaModelConfig _modelConfig; - private readonly ILogger _log; private readonly ITextTokenizer _textTokenizer; + private readonly IContextProvider _contextProvider; + private readonly ILogger _log; public int MaxTokenTotal { get; } @@ -28,6 +30,7 @@ public OllamaTextGenerator( IOllamaApiClient ollamaClient, OllamaModelConfig modelConfig, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) { this._client = ollamaClient; @@ -43,6 +46,7 @@ public OllamaTextGenerator( } this._textTokenizer = textTokenizer; + this._contextProvider = contextProvider ?? new RequestContextProvider(); this.MaxTokenTotal = modelConfig.MaxTokenTotal ?? MaxTokensIfUndefined; } @@ -50,11 +54,13 @@ public OllamaTextGenerator( public OllamaTextGenerator( OllamaConfig config, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) : this( new OllamaApiClient(new Uri(config.Endpoint), config.TextModel.ModelName), config.TextModel, textTokenizer, + contextProvider, loggerFactory) { } @@ -63,11 +69,13 @@ public OllamaTextGenerator( HttpClient httpClient, OllamaConfig config, ITextTokenizer? textTokenizer = null, + IContextProvider? contextProvider = null, ILoggerFactory? loggerFactory = null) : this( new OllamaApiClient(httpClient, config.TextModel.ModelName), config.TextModel, textTokenizer, + contextProvider, loggerFactory) { } @@ -87,9 +95,12 @@ public async IAsyncEnumerable GenerateTextAsync( TextGenerationOptions options, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + string modelName = this._contextProvider.GetContext().GetCustomTextGenerationModelNameOrDefault(this._client.SelectedModel); + this._log.LogTrace("Generating text with model {0}", modelName); + var request = new GenerateRequest { - Model = this._client.SelectedModel, + Model = modelName, Prompt = prompt, Stream = true, Options = new RequestOptions diff --git a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs index ff6147177..9efe22f42 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs @@ -18,6 +18,9 @@ namespace Microsoft.KernelMemory.AI.OpenAI; /// /// Text embedding generator. The class can be used with any service /// supporting OpenAI HTTP schema. +/// +/// Note: does not support model name override via request context +/// see https://github.com/microsoft/semantic-kernel/issues/9337 /// [Experimental("KMEXP01")] public sealed class OpenAITextEmbeddingGenerator : ITextEmbeddingGenerator, ITextEmbeddingBatchGenerator diff --git a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs index 665f0c0f5..b51aa98b1 100644 --- a/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs +++ b/extensions/OpenAI/OpenAI/OpenAITextGenerator.cs @@ -17,6 +17,9 @@ namespace Microsoft.KernelMemory.AI.OpenAI; /// /// Text generator, supporting OpenAI text and chat completion. The class can be used with any service /// supporting OpenAI HTTP schema, such as LM Studio HTTP API. +/// +/// Note: does not support model name override via request context +/// see https://github.com/microsoft/semantic-kernel/issues/9337 /// [Experimental("KMEXP01")] public sealed class OpenAITextGenerator : ITextGenerator diff --git a/service/Abstractions/Constants.cs b/service/Abstractions/Constants.cs index 77ea23e34..d60a0f438 100644 --- a/service/Abstractions/Constants.cs +++ b/service/Abstractions/Constants.cs @@ -43,6 +43,19 @@ public static class EmbeddingGeneration { // Used to override MaxBatchSize embedding generators config public const string BatchSize = "custom_embedding_generation_batch_size_int"; + + // Used to override the name of the model used to generate embeddings + // Supported only by Ollama and Anthropic connectors + // See https://github.com/microsoft/semantic-kernel/issues/9337 + public const string ModelName = "custom_embedding_generation_model_name"; + } + + public static class TextGeneration + { + // Used to override the name of the model used to generate text + // Supported only by Ollama and Anthropic connectors + // See https://github.com/microsoft/semantic-kernel/issues/9337 + public const string ModelName = "custom_text_generation_model_name"; } public static class Rag diff --git a/service/Abstractions/Context/IContext.cs b/service/Abstractions/Context/IContext.cs index 4835fd794..8809b0e8e 100644 --- a/service/Abstractions/Context/IContext.cs +++ b/service/Abstractions/Context/IContext.cs @@ -219,4 +219,42 @@ public static int GetCustomEmbeddingGenerationBatchSizeOrDefault(this IContext? return defaultValue; } + + /// + /// Extensions supported: + /// - Ollama + /// - Anthropic + /// Extensions not supported: + /// - Azure OpenAI + /// - ONNX + /// - OpenAI + /// + public static string GetCustomTextGenerationModelNameOrDefault(this IContext? context, string defaultValue) + { + if (context.TryGetArg(Constants.CustomContext.TextGeneration.ModelName, out var customValue)) + { + return customValue; + } + + return defaultValue; + } + + /// + /// Extensions supported: + /// - Ollama + /// - Anthropic + /// Extensions not supported: + /// - Azure OpenAI + /// - ONNX + /// - OpenAI + /// + public static string GetCustomEmbeddingGenerationModelNameOrDefault(this IContext? context, string defaultValue) + { + if (context.TryGetArg(Constants.CustomContext.EmbeddingGeneration.ModelName, out var customValue)) + { + return customValue; + } + + return defaultValue; + } } diff --git a/service/Abstractions/Context/IContextProvider.cs b/service/Abstractions/Context/IContextProvider.cs index 22986fa68..1f9b38dfa 100644 --- a/service/Abstractions/Context/IContextProvider.cs +++ b/service/Abstractions/Context/IContextProvider.cs @@ -23,6 +23,14 @@ public static class ContextProviderExtensions return provider; } + public static IContextProvider? InitContext(this IContextProvider? provider, IContext? context) + { + if (provider == null) { return null; } + + provider.GetContext().InitArgs(context?.Arguments ?? new Dictionary()); + return provider; + } + public static IContextProvider? SetContextArgs(this IContextProvider? provider, IDictionary args) { if (provider == null) { return null; } diff --git a/service/Core/MemoryServerless.cs b/service/Core/MemoryServerless.cs index 1b0be3981..91a6b741b 100644 --- a/service/Core/MemoryServerless.cs +++ b/service/Core/MemoryServerless.cs @@ -26,6 +26,7 @@ public sealed class MemoryServerless : IKernelMemory { private readonly InProcessPipelineOrchestrator _orchestrator; private readonly ISearchClient _searchClient; + private readonly IContextProvider _contextProvider; private readonly string? _defaultIndexName; /// @@ -41,10 +42,12 @@ public sealed class MemoryServerless : IKernelMemory public MemoryServerless( InProcessPipelineOrchestrator orchestrator, ISearchClient searchClient, + IContextProvider? contextProvider = null, KernelMemoryConfig? config = null) { this._orchestrator = orchestrator ?? throw new ConfigurationException("The orchestrator is NULL"); this._searchClient = searchClient ?? throw new ConfigurationException("The search client is NULL"); + this._contextProvider = contextProvider ?? new RequestContextProvider(); // A non-null config object is required in order to get a non-empty default index name config ??= new KernelMemoryConfig(); @@ -59,6 +62,7 @@ public Task ImportDocumentAsync( IContext? context = null, CancellationToken cancellationToken = default) { + this._contextProvider.InitContext(context); DocumentUploadRequest uploadRequest = new(document, index, steps); return this.ImportDocumentAsync(uploadRequest, context, cancellationToken); } @@ -73,6 +77,7 @@ public Task ImportDocumentAsync( IContext? context = null, CancellationToken cancellationToken = default) { + this._contextProvider.InitContext(context); var document = new Document(documentId, tags: tags).AddFile(filePath); DocumentUploadRequest uploadRequest = new(document, index, steps); return this.ImportDocumentAsync(uploadRequest, context, cancellationToken); @@ -84,6 +89,7 @@ public Task ImportDocumentAsync( IContext? context = null, CancellationToken cancellationToken = default) { + this._contextProvider.InitContext(context); var index = IndexName.CleanName(uploadRequest.Index, this._defaultIndexName); return this._orchestrator.ImportDocumentAsync(index, uploadRequest, context, cancellationToken); } @@ -99,6 +105,7 @@ public Task ImportDocumentAsync( IContext? context = null, CancellationToken cancellationToken = default) { + this._contextProvider.InitContext(context); var document = new Document(documentId, tags: tags).AddStream(fileName, content); DocumentUploadRequest uploadRequest = new(document, index, steps); return this.ImportDocumentAsync(uploadRequest, context, cancellationToken); @@ -114,6 +121,7 @@ public async Task ImportTextAsync( IContext? context = null, CancellationToken cancellationToken = default) { + this._contextProvider.InitContext(context); var content = new MemoryStream(Encoding.UTF8.GetBytes(text)); await using (content.ConfigureAwait(false)) { @@ -139,6 +147,7 @@ public async Task ImportWebPageAsync( IContext? context = null, CancellationToken cancellationToken = default) { + this._contextProvider.InitContext(context); var uri = new Uri(url); Verify.ValidateUrl(uri.AbsoluteUri, requireHttps: false, allowReservedIp: false, allowQuery: true); @@ -233,6 +242,7 @@ public Task SearchAsync( IContext? context = null, CancellationToken cancellationToken = default) { + this._contextProvider.InitContext(context); if (filter != null) { if (filters == null) { filters = new List(); } @@ -261,6 +271,7 @@ public Task AskAsync( IContext? context = null, CancellationToken cancellationToken = default) { + this._contextProvider.InitContext(context); if (filter != null) { if (filters == null) { filters = new List(); }