diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/Program.cs index d98689f895..f5f567dbd2 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/Program.cs @@ -2,8 +2,9 @@ // This sample shows multiple middleware layers working together with Azure OpenAI: // chat client (global/per-request), agent run (PII filtering and guardrails), -// function invocation (logging and result overrides), and human-in-the-loop -// approval workflows for sensitive function calls. +// function invocation (logging and result overrides), human-in-the-loop +// approval workflows for sensitive function calls, and MessageAIContextProvider +// middleware for injecting additional context messages into the agent pipeline. using System.ComponentModel; using System.Text.RegularExpressions; @@ -96,6 +97,20 @@ static string GetDateTime() Console.WriteLine($"Per-request middleware response: {response}"); +// MessageAIContextProvider middleware that injects additional messages into the agent request. +// This allows any AIAgent (not just ChatClientAgent) to benefit from MessageAIContextProvider-based +// context enrichment. Multiple providers can be passed to Use and they are called in sequence, +// each receiving the output of the previous one. +Console.WriteLine("\n\n=== Example 5: MessageAIContextProvider middleware ==="); + +var contextProviderAgent = originalAgent + .AsBuilder() + .Use([new DateTimeContextProvider()]) + .Build(); + +var contextResponse = await contextProviderAgent.RunAsync("Is it almost time for lunch?"); +Console.WriteLine($"Context-enriched response: {contextResponse}"); + // Function invocation middleware that logs before and after function calls. async ValueTask FunctionCallMiddleware(AIAgent agent, FunctionInvocationContext context, Func> next, CancellationToken cancellationToken) { @@ -259,3 +274,23 @@ async Task PerRequestChatClientMiddleware(IEnumerable return response; } + +/// +/// A that injects the current date and time into the agent's context. +/// This is a simple example of how to use a MessageAIContextProvider to enrich agent messages +/// via the extension method. +/// +internal sealed class DateTimeContextProvider : MessageAIContextProvider +{ + protected override ValueTask> ProvideMessagesAsync( + InvokingContext context, + CancellationToken cancellationToken = default) + { + Console.WriteLine("DateTimeContextProvider - Injecting current date/time context"); + + return new ValueTask>( + [ + new ChatMessage(ChatRole.User, $"For reference, the current date and time is: {DateTimeOffset.Now}") + ]); + } +} diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/README.md b/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/README.md index d9433d6230..d7193b6982 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/README.md +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/README.md @@ -14,6 +14,7 @@ This sample demonstrates how to add middleware to intercept: 5. Per‑request chat client middleware 6. Per‑request function pipeline with approval 7. Combining agent‑level and per‑request middleware +8. MessageAIContextProvider middleware via `AIAgentBuilder.Use(...)` for injecting additional context messages ## Function Invocation Middleware diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index 5ff7a51c38..a341abe8cd 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -146,11 +146,11 @@ private static void AddTodoItem(AgentSession? session, string item) } /// - /// An which searches for upcoming calendar events and adds them to the AI context. + /// A which searches for upcoming calendar events and adds them to the AI context. /// - internal sealed class CalendarSearchAIContextProvider(Func> loadNextThreeCalendarEvents) : AIContextProvider + internal sealed class CalendarSearchAIContextProvider(Func> loadNextThreeCalendarEvents) : MessageAIContextProvider { - protected override async ValueTask ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default) { var events = await loadNextThreeCalendarEvents(); @@ -161,10 +161,7 @@ protected override async ValueTask ProvideAIContextAsync(InvokingCont outputMessageBuilder.AppendLine($" - {calendarEvent}"); } - return new AIContext - { - Messages = [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())] - }; + return [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())]; } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index 023f6c8e5f..7ac4eed18c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -34,9 +34,6 @@ public abstract class AIContextProvider private static IEnumerable DefaultExternalOnlyFilter(IEnumerable messages) => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External); - private readonly Func, IEnumerable> _provideInputMessageFilter; - private readonly Func, IEnumerable> _storeInputMessageFilter; - /// /// Initializes a new instance of the class. /// @@ -46,10 +43,20 @@ protected AIContextProvider( Func, IEnumerable>? provideInputMessageFilter = null, Func, IEnumerable>? storeInputMessageFilter = null) { - this._provideInputMessageFilter = provideInputMessageFilter ?? DefaultExternalOnlyFilter; - this._storeInputMessageFilter = storeInputMessageFilter ?? DefaultExternalOnlyFilter; + this.ProvideInputMessageFilter = provideInputMessageFilter ?? DefaultExternalOnlyFilter; + this.StoreInputMessageFilter = storeInputMessageFilter ?? DefaultExternalOnlyFilter; } + /// + /// Gets the filter function to apply to input messages before providing context via . + /// + protected Func, IEnumerable> ProvideInputMessageFilter { get; } + + /// + /// Gets the filter function to apply to request messages before storing context via . + /// + protected Func, IEnumerable> StoreInputMessageFilter { get; } + /// /// Gets the key used to store the provider state in the . /// @@ -120,7 +127,7 @@ protected virtual async ValueTask InvokingCoreAsync(InvokingContext c new AIContext { Instructions = inputContext.Instructions, - Messages = inputContext.Messages is not null ? this._provideInputMessageFilter(inputContext.Messages) : null, + Messages = inputContext.Messages is not null ? this.ProvideInputMessageFilter(inputContext.Messages) : null, Tools = inputContext.Tools }); @@ -254,7 +261,7 @@ protected virtual ValueTask InvokedCoreAsync(InvokedContext context, Cancellatio return default; } - var subContext = new InvokedContext(context.Agent, context.Session, this._storeInputMessageFilter(context.RequestMessages), context.ResponseMessages!); + var subContext = new InvokedContext(context.Agent, context.Session, this.StoreInputMessageFilter(context.RequestMessages), context.ResponseMessages!); return this.StoreAIContextAsync(subContext, cancellationToken); } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/MessageAIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/MessageAIContextProvider.cs new file mode 100644 index 0000000000..24264e0e47 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/MessageAIContextProvider.cs @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// Provides an abstract base class for components that enhance AI context during agent invocations by supplying additional chat messages. +/// +/// +/// +/// A message AI context provider is a component that participates in the agent invocation lifecycle by: +/// +/// Listening to changes in conversations +/// Providing additional messages to agents during invocation +/// Processing invocation results for state management or learning +/// +/// +/// +/// Context providers operate through a two-phase lifecycle: they are called at the start of invocation via +/// to provide context, and optionally called at the end of invocation via +/// to process results. +/// +/// +public abstract class MessageAIContextProvider : AIContextProvider +{ + /// + /// Initializes a new instance of the class. + /// + /// An optional filter function to apply to input messages before providing messages via . If not set, defaults to including only messages. + /// An optional filter function to apply to request messages before storing messages via . If not set, defaults to including only messages. + protected MessageAIContextProvider( + Func, IEnumerable>? provideInputMessageFilter = null, + Func, IEnumerable>? storeInputMessageFilter = null) + : base(provideInputMessageFilter, storeInputMessageFilter) + { + } + + /// + protected override async ValueTask ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default) + { + // Call ProvideMessagesAsync directly to return only additional messages. + // The base AIContextProvider.InvokingCoreAsync handles merging with the original input and stamping. + return new AIContext + { + Messages = await this.ProvideMessagesAsync( + new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []), + cancellationToken).ConfigureAwait(false) + }; + } + + /// + /// Called at the start of agent invocation to provide additional messages. + /// + /// Contains the request context including the caller provided messages that will be used by the agent for this invocation. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the to be used by the agent during this invocation. + /// + /// + /// Implementers can load any additional messages required at this time, such as: + /// + /// Retrieving relevant information from knowledge bases + /// Adding system instructions or prompts + /// Injecting contextual messages from conversation history + /// + /// + /// + public ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + => this.InvokingCoreAsync(Throw.IfNull(context), cancellationToken); + + /// + /// Called at the start of agent invocation to provide additional messages. + /// + /// Contains the request context including the caller provided messages that will be used by the agent for this invocation. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the to be used by the agent during this invocation. + /// + /// + /// Implementers can load any additional messages required at this time, such as: + /// + /// Retrieving relevant information from knowledge bases + /// Adding system instructions or prompts + /// Injecting contextual messages from conversation history + /// + /// + /// + /// The default implementation of this method filters the input messages using the configured provide-input message filter + /// (which defaults to including only messages), + /// then calls to get additional messages, + /// stamps any messages with source attribution, + /// and merges the returned messages with the original (unfiltered) input messages. + /// For most scenarios, overriding is sufficient to provide additional messages, + /// while still benefiting from the default filtering, merging and source stamping behavior. + /// However, for scenarios that require more control over message filtering, merging or source stamping, overriding this method + /// allows you to directly control the full returned for the invocation. + /// + /// + protected virtual async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var inputMessages = context.RequestMessages; + + // Create a filtered context for ProvideMessagesAsync, filtering input messages + // to exclude non-external messages (e.g. chat history, other AI context provider messages). + var filteredContext = new InvokingContext( + context.Agent, + context.Session, + this.ProvideInputMessageFilter(inputMessages)); + + var providedMessages = await this.ProvideMessagesAsync(filteredContext, cancellationToken).ConfigureAwait(false); + + // Stamp and merge provided messages. + providedMessages = providedMessages.Select(m => m.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!)); + return inputMessages.Concat(providedMessages); + } + + /// + /// When overridden in a derived class, provides additional messages to be merged with the input messages for the current invocation. + /// + /// + /// + /// This method is called from . + /// Note that can be overridden to directly control messages merging and source stamping, in which case + /// it is up to the implementer to call this method as needed to retrieve the additional messages. + /// + /// + /// In contrast with , this method only returns additional messages to be merged with the input, + /// while is responsible for returning the full merged for the invocation. + /// + /// + /// Contains the request context including the caller provided messages that will be used by the agent for this invocation. + /// The to monitor for cancellation requests. The default is . + /// + /// A task that represents the asynchronous operation. The task result contains an + /// with additional messages to be merged with the input messages. + /// + protected virtual ValueTask> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + return new ValueTask>([]); + } + + /// + /// Contains the context information provided to . + /// + /// + /// This class provides context about the invocation before the underlying AI model is invoked, including the messages + /// that will be used. Message AI Context providers can use this information to determine what additional messages + /// should be provided for the invocation. + /// + public new sealed class InvokingContext + { + /// + /// Initializes a new instance of the class with the specified request messages. + /// + /// The agent being invoked. + /// The session associated with the agent invocation. + /// The messages to be used by the agent for this invocation. + /// or is . + public InvokingContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages) + { + this.Agent = Throw.IfNull(agent); + this.Session = session; + this.RequestMessages = Throw.IfNull(requestMessages); + } + + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + + /// + /// Gets the messages that will be used by the agent for this invocation. instances can modify + /// and return or return a new message list to add additional messages for the invocation. + /// + /// + /// A collection of instances representing the messages that will be used by the agent for this invocation. + /// + /// + /// + /// If multiple instances are used in the same invocation, each + /// will receive the messages returned by the previous allowing them to build on top of each other's context. + /// + /// + /// The first in the invocation pipeline will receive the + /// caller provided messages. + /// + /// + public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ProviderSessionState{TState}.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ProviderSessionState{TState}.cs index b88e996644..ffcec7ea11 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ProviderSessionState{TState}.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ProviderSessionState{TState}.cs @@ -2,6 +2,7 @@ using System; using System.Text.Json; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -39,8 +40,8 @@ public ProviderSessionState( string stateKey, JsonSerializerOptions? jsonSerializerOptions = null) { - this._stateInitializer = stateInitializer; - this.StateKey = stateKey; + this._stateInitializer = Throw.IfNull(stateInitializer); + this.StateKey = Throw.IfNullOrWhitespace(stateKey); this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; } diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index bc2f3fabc8..1924bc0da2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -14,7 +14,7 @@ namespace Microsoft.Agents.AI.Mem0; /// -/// Provides a Mem0 backed that persists conversation messages as memories +/// Provides a Mem0 backed that persists conversation messages as memories /// and retrieves related memories to augment the agent invocation context. /// /// @@ -22,7 +22,7 @@ namespace Microsoft.Agents.AI.Mem0; /// for new invocations using a semantic search endpoint. Retrieved memories are injected as user messages /// to the model, prefixed by a configurable context prompt. /// -public sealed class Mem0Provider : AIContextProvider +public sealed class Mem0Provider : MessageAIContextProvider { private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:"; @@ -92,7 +92,7 @@ public Mem0Provider(HttpClient httpClient, Func stateIniti }; /// - protected override async ValueTask ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default) { Throw.IfNull(context); @@ -101,7 +101,7 @@ protected override async ValueTask ProvideAIContextAsync(InvokingCont string queryText = string.Join( Environment.NewLine, - (context.AIContext.Messages ?? []) + context.RequestMessages .Where(m => !string.IsNullOrWhiteSpace(m.Text)) .Select(m => m.Text)); @@ -142,12 +142,9 @@ protected override async ValueTask ProvideAIContextAsync(InvokingCont } } - return new AIContext - { - Messages = outputMessageText is not null - ? [new ChatMessage(ChatRole.User, outputMessageText)] - : null - }; + return outputMessageText is not null + ? [new ChatMessage(ChatRole.User, outputMessageText)] + : []; } catch (ArgumentException) { @@ -166,7 +163,7 @@ protected override async ValueTask ProvideAIContextAsync(InvokingCont this.SanitizeLogData(searchScope.UserId)); } - return new AIContext(); + return []; } } diff --git a/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs b/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs index 961fd7f20e..52293db621 100644 --- a/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs @@ -151,6 +151,32 @@ public AIAgentBuilder Use( return this.Use((innerAgent, _) => new AnonymousDelegatingAIAgent(innerAgent, runFunc, runStreamingFunc)); } + /// + /// Adds one or more instances to the agent pipeline, enabling message enrichment + /// for any . + /// + /// + /// The instances to invoke before and after each agent invocation. + /// Providers are called in sequence, with each receiving the output of the previous provider. + /// + /// The with the providers added, enabling method chaining. + /// is empty. + /// + /// + /// This method wraps the inner agent with a that calls each provider's + /// in sequence before the inner agent runs, + /// and calls on each provider after the inner agent completes. + /// + /// + /// This allows any to benefit from -based + /// context enrichment, not just agents that natively support instances. + /// + /// + public AIAgentBuilder Use(MessageAIContextProvider[] providers) + { + return this.Use((innerAgent, _) => new MessageAIContextProviderAgent(innerAgent, providers)); + } + /// /// Provides an empty implementation. /// diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index a93bfd2c67..c6ca35951e 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -34,7 +34,7 @@ namespace Microsoft.Agents.AI; /// injecting them automatically on each invocation. /// /// -public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable +public sealed class ChatHistoryMemoryProvider : MessageAIContextProvider, IDisposable { private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:"; private const int DefaultMaxResults = 3; @@ -119,7 +119,7 @@ public ChatHistoryMemoryProvider( public override string StateKey => this._sessionState.StateKey; /// - protected override async ValueTask ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -147,17 +147,46 @@ Task InlineSearchAsync(string userQuestion, CancellationToken ct) }; } + return new AIContext + { + Messages = await this.ProvideMessagesAsync( + new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []), + cancellationToken).ConfigureAwait(false) + }; + } + + /// + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + // This code path is invoked using InvokingAsync on MessageAIContextProvider, which does not support tools and instructions, + // and OnDemandFunctionCalling requires tools. + if (this._searchTime != ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke) + { + throw new InvalidOperationException($"Using the {nameof(ChatHistoryMemoryProvider)} as a {nameof(MessageAIContextProvider)} is not supported when {nameof(ChatHistoryMemoryProviderOptions.SearchTime)} is set to {ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling}."); + } + + return base.InvokingCoreAsync(context, cancellationToken); + } + + /// + protected override async ValueTask> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(context); + + var state = this._sessionState.GetOrInitializeState(context.Session); + var searchScope = state.SearchScope; + try { // Get the text from the current request messages var requestText = string.Join("\n", - (context.AIContext.Messages ?? []) + (context.RequestMessages ?? []) .Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text)) .Select(m => m.Text)); if (string.IsNullOrWhiteSpace(requestText)) { - return new AIContext(); + return []; } // Search for relevant chat history @@ -165,13 +194,10 @@ Task InlineSearchAsync(string userQuestion, CancellationToken ct) if (string.IsNullOrWhiteSpace(contextText)) { - return new AIContext(); + return []; } - return new AIContext - { - Messages = [new ChatMessage(ChatRole.User, contextText)] - }; + return [new ChatMessage(ChatRole.User, contextText)]; } catch (Exception ex) { @@ -186,7 +212,7 @@ Task InlineSearchAsync(string userQuestion, CancellationToken ct) this.SanitizeLogData(searchScope.UserId)); } - return new AIContext(); + return []; } } diff --git a/dotnet/src/Microsoft.Agents.AI/MessageAIContextProviderAgent.cs b/dotnet/src/Microsoft.Agents.AI/MessageAIContextProviderAgent.cs new file mode 100644 index 0000000000..6453209edd --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI/MessageAIContextProviderAgent.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// A delegating AI agent that enriches input messages by invoking a pipeline of instances +/// before delegating to the inner agent, and notifies those providers after the inner agent completes. +/// +internal sealed class MessageAIContextProviderAgent : DelegatingAIAgent +{ + private readonly IReadOnlyList _providers; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying agent instance that will handle the core operations. + /// The message AI context providers to invoke before and after the inner agent. + public MessageAIContextProviderAgent(AIAgent innerAgent, IReadOnlyList providers) + : base(innerAgent) + { + Throw.IfNull(providers); + Throw.IfLessThanOrEqual(providers.Count, 0, nameof(providers)); + + this._providers = providers; + } + + /// + protected override async Task RunCoreAsync( + IEnumerable messages, + AgentSession? session = null, + AgentRunOptions? options = null, + CancellationToken cancellationToken = default) + { + var enrichedMessages = await this.InvokeProvidersAsync(messages, session, cancellationToken).ConfigureAwait(false); + + AgentResponse response; + try + { + response = await this.InnerAgent.RunAsync(enrichedMessages, session, options, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false); + throw; + } + + await this.NotifyProvidersOfSuccessAsync(session, enrichedMessages, response.Messages, cancellationToken).ConfigureAwait(false); + + return response; + } + + /// + protected override async IAsyncEnumerable RunCoreStreamingAsync( + IEnumerable messages, + AgentSession? session = null, + AgentRunOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var enrichedMessages = await this.InvokeProvidersAsync(messages, session, cancellationToken).ConfigureAwait(false); + + List responseUpdates = []; + + IAsyncEnumerator enumerator; + try + { + enumerator = this.InnerAgent.RunStreamingAsync(enrichedMessages, session, options, cancellationToken).GetAsyncEnumerator(cancellationToken); + } + catch (Exception ex) + { + await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false); + throw; + } + + bool hasUpdates; + try + { + hasUpdates = await enumerator.MoveNextAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false); + throw; + } + + while (hasUpdates) + { + var update = enumerator.Current; + responseUpdates.Add(update); + yield return update; + + try + { + hasUpdates = await enumerator.MoveNextAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false); + throw; + } + } + + var agentResponse = responseUpdates.ToAgentResponse(); + await this.NotifyProvidersOfSuccessAsync(session, enrichedMessages, agentResponse.Messages, cancellationToken).ConfigureAwait(false); + } + + /// + /// Invokes each provider's in sequence, + /// passing the output of each as input to the next. + /// + private async Task> InvokeProvidersAsync( + IEnumerable messages, + AgentSession? session, + CancellationToken cancellationToken) + { + var currentMessages = messages; + + foreach (var provider in this._providers) + { + var context = new MessageAIContextProvider.InvokingContext(this, session, currentMessages); + currentMessages = await provider.InvokingAsync(context, cancellationToken).ConfigureAwait(false); + } + + return currentMessages; + } + + /// + /// Notifies each provider of a successful invocation. + /// + private async Task NotifyProvidersOfSuccessAsync( + AgentSession? session, + IEnumerable requestMessages, + IEnumerable responseMessages, + CancellationToken cancellationToken) + { + var invokedContext = new AIContextProvider.InvokedContext(this, session, requestMessages, responseMessages); + + foreach (var provider in this._providers) + { + await provider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Notifies each provider of a failed invocation. + /// + private async Task NotifyProvidersOfFailureAsync( + AgentSession? session, + IEnumerable requestMessages, + Exception exception, + CancellationToken cancellationToken) + { + var invokedContext = new AIContextProvider.InvokedContext(this, session, requestMessages, exception); + + foreach (var provider in this._providers) + { + await provider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false); + } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs index 0a15f51565..dd62b0eb9b 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs @@ -32,7 +32,7 @@ namespace Microsoft.Agents.AI; /// multi-turn context to the retrieval layer without permanently altering the conversation history. /// /// -public sealed class TextSearchProvider : AIContextProvider +public sealed class TextSearchProvider : MessageAIContextProvider { private const string DefaultPluginSearchFunctionName = "Search"; private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question."; @@ -91,7 +91,7 @@ public TextSearchProvider( public override string StateKey => this._sessionState.StateKey; /// - protected override async ValueTask ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default) { if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke) { @@ -102,6 +102,30 @@ protected override async ValueTask ProvideAIContextAsync(InvokingCont }; } + return new AIContext + { + Messages = await this.ProvideMessagesAsync( + new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []), + cancellationToken).ConfigureAwait(false) + }; + } + + /// + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + // This code path is invoked using InvokingAsync on MessageAIContextProvider, which does not support tools and instructions, + // and OnDemandFunctionCalling requires tools. + if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke) + { + throw new InvalidOperationException($"Using the {nameof(TextSearchProvider)} as a {nameof(MessageAIContextProvider)} is not supported when {nameof(TextSearchProviderOptions.SearchTime)} is set to {TextSearchProviderOptions.TextSearchBehavior.OnDemandFunctionCalling}."); + } + + return base.InvokingCoreAsync(context, cancellationToken); + } + + /// + protected override async ValueTask> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default) + { // Retrieve recent messages from the session state. var recentMessagesText = this._sessionState.GetOrInitializeState(context.Session).RecentMessagesText ?? []; @@ -109,7 +133,7 @@ protected override async ValueTask ProvideAIContextAsync(InvokingCont // Aggregate text from memory + current request messages. var sbInput = new StringBuilder(); var requestMessagesText = - (context.AIContext.Messages ?? []) + (context.RequestMessages ?? []) .Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); foreach (var messageText in recentMessagesText.Concat(requestMessagesText)) { @@ -135,7 +159,7 @@ protected override async ValueTask ProvideAIContextAsync(InvokingCont if (materialized.Count == 0) { - return new AIContext(); + return []; } // Format search results @@ -146,15 +170,12 @@ protected override async ValueTask ProvideAIContextAsync(InvokingCont this._logger.LogTrace("TextSearchProvider: Search Results\nInput:{Input}\nOutput:{MessageText}", input, formatted); } - return new AIContext - { - Messages = [new ChatMessage(ChatRole.User, formatted)] - }; + return [new ChatMessage(ChatRole.User, formatted)]; } catch (Exception ex) { this._logger?.LogError(ex, "TextSearchProvider: Failed to search for data due to error"); - return new AIContext(); + return []; } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/MessageAIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/MessageAIContextProviderTests.cs new file mode 100644 index 0000000000..8c11de6b62 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/MessageAIContextProviderTests.cs @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Moq; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public class MessageAIContextProviderTests +{ + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + + #region InvokingAsync Tests + + [Fact] + public async Task InvokingAsync_NullContext_ThrowsArgumentNullExceptionAsync() + { + // Arrange + var provider = new TestMessageProvider(); + + // Act & Assert + await Assert.ThrowsAsync(() => provider.InvokingAsync(null!).AsTask()); + } + + [Fact] + public async Task InvokingAsync_ReturnsInputAndProvidedMessagesAsync() + { + // Arrange + var providedMessages = new[] { new ChatMessage(ChatRole.System, "Context message") }; + var provider = new TestMessageProvider(provideMessages: providedMessages); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "User input")]); + + // Act + var result = (await provider.InvokingAsync(context)).ToList(); + + // Assert - input messages + provided messages merged + Assert.Equal(2, result.Count); + Assert.Equal("User input", result[0].Text); + Assert.Equal("Context message", result[1].Text); + } + + [Fact] + public async Task InvokingAsync_ReturnsOnlyInputMessages_WhenNoMessagesProvidedAsync() + { + // Arrange + var provider = new DefaultMessageProvider(); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hello")]); + + // Act + var result = (await provider.InvokingAsync(context)).ToList(); + + // Assert + Assert.Single(result); + Assert.Equal("Hello", result[0].Text); + } + + [Fact] + public async Task InvokingAsync_StampsProvidedMessagesWithAIContextProviderSourceAsync() + { + // Arrange + var providedMessages = new[] { new ChatMessage(ChatRole.System, "Provided") }; + var provider = new TestMessageProvider(provideMessages: providedMessages); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []); + + // Act + var result = (await provider.InvokingAsync(context)).ToList(); + + // Assert + Assert.Single(result); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, result[0].GetAgentRequestMessageSourceType()); + } + + [Fact] + public async Task InvokingAsync_FiltersInputToExternalOnlyByDefaultAsync() + { + // Arrange + var provider = new TestMessageProvider(captureFilteredContext: true); + var externalMsg = new ChatMessage(ChatRole.User, "External"); + var chatHistoryMsg = new ChatMessage(ChatRole.User, "History") + .WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src"); + var contextProviderMsg = new ChatMessage(ChatRole.User, "ContextProvider") + .WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, "src"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg, contextProviderMsg]); + + // Act + await provider.InvokingAsync(context); + + // Assert - ProvideMessagesAsync received only External messages + Assert.NotNull(provider.LastFilteredContext); + var filteredMessages = provider.LastFilteredContext!.RequestMessages.ToList(); + Assert.Single(filteredMessages); + Assert.Equal("External", filteredMessages[0].Text); + } + + [Fact] + public async Task InvokingAsync_UsesCustomProvideInputFilterAsync() + { + // Arrange - filter that keeps all messages (not just External) + var provider = new TestMessageProvider( + captureFilteredContext: true, + provideInputMessageFilter: msgs => msgs); + var externalMsg = new ChatMessage(ChatRole.User, "External"); + var chatHistoryMsg = new ChatMessage(ChatRole.User, "History") + .WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg]); + + // Act + await provider.InvokingAsync(context); + + // Assert - ProvideMessagesAsync received ALL messages (custom filter keeps everything) + Assert.NotNull(provider.LastFilteredContext); + var filteredMessages = provider.LastFilteredContext!.RequestMessages.ToList(); + Assert.Equal(2, filteredMessages.Count); + } + + [Fact] + public async Task InvokingAsync_MergesWithOriginalUnfilteredMessagesAsync() + { + // Arrange - default filter is External-only, but the MERGED result should include + // the original unfiltered input messages plus the provided messages + var providedMessages = new[] { new ChatMessage(ChatRole.System, "Provided") }; + var provider = new TestMessageProvider(provideMessages: providedMessages); + var externalMsg = new ChatMessage(ChatRole.User, "External"); + var chatHistoryMsg = new ChatMessage(ChatRole.User, "History") + .WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg]); + + // Act + var result = (await provider.InvokingAsync(context)).ToList(); + + // Assert - original 2 input messages + 1 provided message + Assert.Equal(3, result.Count); + Assert.Equal("External", result[0].Text); + Assert.Equal("History", result[1].Text); + Assert.Equal("Provided", result[2].Text); + } + + #endregion + + #region ProvideAIContextAsync Tests + + [Fact] + public async Task ProvideAIContextAsync_PreservesInstructionsAndToolsAsync() + { + // Arrange + var providedMessages = new[] { new ChatMessage(ChatRole.System, "Context") }; + var provider = new TestMessageProvider(provideMessages: providedMessages); + var inputTool = AIFunctionFactory.Create(() => "a", "inputTool"); + var inputContext = new AIContext + { + Messages = [new ChatMessage(ChatRole.User, "Hello")], + Instructions = "Be helpful", + Tools = [inputTool] + }; + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, inputContext); + + // Act + var result = await provider.InvokingAsync(context); + + // Assert - instructions and tools are preserved + Assert.Equal("Be helpful", result.Instructions); + Assert.NotNull(result.Tools); + Assert.Single(result.Tools!); + Assert.Equal("inputTool", result.Tools!.First().Name); + + // Messages include original input + provided messages (with stamping) + var messages = result.Messages!.ToList(); + Assert.Equal(2, messages.Count); + Assert.Equal("Hello", messages[0].Text); + Assert.Equal("Context", messages[1].Text); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType()); + } + + [Fact] + public async Task ProvideAIContextAsync_PreservesNullInstructionsAndToolsAsync() + { + // Arrange + var provider = new DefaultMessageProvider(); + var inputContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hello")] }; + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, inputContext); + + // Act + var result = await provider.InvokingAsync(context); + + // Assert + Assert.Null(result.Instructions); + Assert.Null(result.Tools); + var messages = result.Messages!.ToList(); + Assert.Single(messages); + Assert.Equal("Hello", messages[0].Text); + } + + #endregion + + #region InvokingContext Tests + + [Fact] + public void InvokingContext_Constructor_ThrowsForNullAgent() + { + // Act & Assert + Assert.Throws(() => new MessageAIContextProvider.InvokingContext(null!, s_mockSession, [])); + } + + [Fact] + public void InvokingContext_Constructor_ThrowsForNullRequestMessages() + { + // Act & Assert + Assert.Throws(() => new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); + } + + [Fact] + public void InvokingContext_Constructor_AllowsNullSession() + { + // Act + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, null, []); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokingContext_Properties_Roundtrip() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + Assert.Same(s_mockSession, context.Session); + Assert.Same(messages, context.RequestMessages); + } + + [Fact] + public void InvokingContext_RequestMessages_SetterThrowsForNull() + { + // Arrange + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []); + + // Act & Assert + Assert.Throws(() => context.RequestMessages = null!); + } + + [Fact] + public void InvokingContext_RequestMessages_SetterAcceptsValidValue() + { + // Arrange + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var newMessages = new List { new(ChatRole.User, "Updated") }; + + // Act + context.RequestMessages = newMessages; + + // Assert + Assert.Same(newMessages, context.RequestMessages); + } + + #endregion + + #region GetService Tests + + [Fact] + public void GetService_ReturnsProviderForMessageAIContextProviderType() + { + // Arrange + var provider = new TestMessageProvider(); + + // Act & Assert + Assert.Same(provider, provider.GetService(typeof(MessageAIContextProvider))); + Assert.Same(provider, provider.GetService(typeof(AIContextProvider))); + Assert.Same(provider, provider.GetService(typeof(TestMessageProvider))); + } + + #endregion + + #region Test helpers + + private sealed class TestMessageProvider : MessageAIContextProvider + { + private readonly IEnumerable? _provideMessages; + private readonly bool _captureFilteredContext; + + public InvokingContext? LastFilteredContext { get; private set; } + + public TestMessageProvider( + IEnumerable? provideMessages = null, + bool captureFilteredContext = false, + Func, IEnumerable>? provideInputMessageFilter = null, + Func, IEnumerable>? storeInputMessageFilter = null) + : base(provideInputMessageFilter, storeInputMessageFilter) + { + this._provideMessages = provideMessages; + this._captureFilteredContext = captureFilteredContext; + } + + protected override ValueTask> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + if (this._captureFilteredContext) + { + this.LastFilteredContext = context; + } + + return new(this._provideMessages ?? []); + } + } + + /// + /// A provider that uses only base class defaults (no overrides of ProvideMessagesAsync). + /// + private sealed class DefaultMessageProvider : MessageAIContextProvider; + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ProviderSessionStateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ProviderSessionStateTests.cs index da3992bf7f..89cf109f7e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ProviderSessionStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ProviderSessionStateTests.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; + namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// @@ -7,6 +9,56 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class ProviderSessionStateTests { + #region Constructor Tests + + [Fact] + public void Constructor_ThrowsForNullStateInitializer() + { + // Act & Assert + Assert.Throws(() => new ProviderSessionState(null!, "test-key")); + } + + [Fact] + public void Constructor_ThrowsForNullStateKey() + { + // Act & Assert + Assert.Throws(() => new ProviderSessionState(_ => new TestState(), null!)); + } + + [Theory] + [InlineData("")] + [InlineData(" ")] + public void Constructor_ThrowsForEmptyOrWhitespaceStateKey(string stateKey) + { + // Act & Assert + Assert.Throws(() => new ProviderSessionState(_ => new TestState(), stateKey)); + } + + [Fact] + public void Constructor_AcceptsNullJsonSerializerOptions() + { + // Act - should not throw + var sessionState = new ProviderSessionState(_ => new TestState(), "test-key", jsonSerializerOptions: null); + + // Assert - instance is created and functional + Assert.Equal("test-key", sessionState.StateKey); + } + + [Fact] + public void Constructor_AcceptsCustomJsonSerializerOptions() + { + // Arrange + var customOptions = new System.Text.Json.JsonSerializerOptions(); + + // Act - should not throw + var sessionState = new ProviderSessionState(_ => new TestState(), "test-key", customOptions); + + // Assert - instance is created and functional + Assert.Equal("test-key", sessionState.StateKey); + } + + #endregion + #region GetOrInitializeState Tests [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index 7ad77b7df5..02e18f324e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -547,6 +547,87 @@ public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync() Assert.Equal(2, memoryPosts.Count); } + #region MessageAIContextProvider.InvokingAsync Tests + + [Fact] + public async Task MessageInvokingAsync_SearchesAndReturnsMergedMessagesAsync() + { + // Arrange + this._handler.EnqueueJsonResponse("[ { \"id\": \"1\", \"memory\": \"Name is Caoimhe\", \"hash\": \"h\", \"metadata\": null, \"score\": 0.9, \"created_at\": \"2023-01-01T00:00:00Z\", \"updated_at\": null, \"user_id\": \"u\", \"app_id\": null, \"agent_id\": \"agent\", \"thread_id\": \"session\" } ]"); + var storageScope = new Mem0ProviderScope + { + ApplicationId = "app", + AgentId = "agent", + ThreadId = "session", + UserId = "user" + }; + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); + + var inputMsg = new ChatMessage(ChatRole.User, "What is my name?"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [inputMsg]); + + // Act + var messages = (await sut.InvokingAsync(context)).ToList(); + + // Assert - input message + memory message, with stamping + Assert.Equal(2, messages.Count); + Assert.Equal("What is my name?", messages[0].Text); + Assert.Contains("Name is Caoimhe", messages[1].Text); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType()); + } + + [Fact] + public async Task MessageInvokingAsync_NoMemories_ReturnsOnlyInputMessagesAsync() + { + // Arrange + this._handler.EnqueueJsonResponse("[]"); + var storageScope = new Mem0ProviderScope + { + UserId = "user" + }; + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); + + var inputMsg = new ChatMessage(ChatRole.User, "Hello"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [inputMsg]); + + // Act + var messages = (await sut.InvokingAsync(context)).ToList(); + + // Assert + Assert.Single(messages); + Assert.Equal("Hello", messages[0].Text); + } + + [Fact] + public async Task MessageInvokingAsync_DefaultFilter_ExcludesNonExternalMessagesAsync() + { + // Arrange + this._handler.EnqueueJsonResponse("[]"); + var storageScope = new Mem0ProviderScope + { + UserId = "user" + }; + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); + + var externalMsg = new ChatMessage(ChatRole.User, "External question"); + var historyMsg = new ChatMessage(ChatRole.User, "History message") + .WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [externalMsg, historyMsg]); + + // Act + await sut.InvokingAsync(context); + + // Assert - Only External message used for search query + var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post && ContainsOrdinal(r.RequestMessage.RequestUri!.AbsoluteUri, "/v1/memories/search/")); + using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody); + Assert.Equal("External question", doc.RootElement.GetProperty("query").GetString()); + } + + #endregion + private static bool ContainsOrdinal(string source, string value) => source.IndexOf(value, StringComparison.Ordinal) >= 0; public void Dispose() diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 602fe40e08..46c56fc483 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -743,7 +743,7 @@ public async Task StateBag_RoundtripRestoresMessagesAsync() SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 4 }); - await newProvider.InvokingAsync(new(s_mockAgent, restoredSession, new AIContext()), CancellationToken.None); // Trigger search to read memory. + await newProvider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, restoredSession, new AIContext()), CancellationToken.None); // Trigger search to read memory. // Assert Assert.NotNull(capturedInput); @@ -769,7 +769,7 @@ public async Task InvokingAsync_WithEmptyStateBag_ShouldHaveNoMessagesAsync() SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 3 }); - await provider.InvokingAsync(new(s_mockAgent, session, new AIContext()), CancellationToken.None); + await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext()), CancellationToken.None); // Assert Assert.NotNull(capturedInput); @@ -778,6 +778,101 @@ public async Task InvokingAsync_WithEmptyStateBag_ShouldHaveNoMessagesAsync() #endregion + #region MessageAIContextProvider.InvokingAsync Tests + + [Fact] + public async Task MessageInvokingAsync_BeforeAIInvoke_SearchesAndReturnsMergedMessagesAsync() + { + // Arrange + List results = + [ + new() { SourceName = "Doc1", Text = "Content of Doc1" } + ]; + + Task> SearchDelegateAsync(string input, CancellationToken ct) + => Task.FromResult>(results); + + var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions + { + SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke + }); + + var inputMsg = new ChatMessage(ChatRole.User, "Question?"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]); + + // Act + var messages = (await provider.InvokingAsync(context)).ToList(); + + // Assert - input message + search result message, with stamping + Assert.Equal(2, messages.Count); + Assert.Equal("Question?", messages[0].Text); + Assert.Contains("Content of Doc1", messages[1].Text); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType()); + } + + [Fact] + public async Task MessageInvokingAsync_OnDemand_ThrowsInvalidOperationExceptionAsync() + { + // Arrange + var provider = new TextSearchProvider(this.NoResultSearchAsync, new TextSearchProviderOptions + { + SearchTime = TextSearchProviderOptions.TextSearchBehavior.OnDemandFunctionCalling, + }); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); + + // Act & Assert + await Assert.ThrowsAsync(() => provider.InvokingAsync(context).AsTask()); + } + + [Fact] + public async Task MessageInvokingAsync_BeforeAIInvoke_NoResults_ReturnsOnlyInputMessagesAsync() + { + // Arrange + var provider = new TextSearchProvider(this.NoResultSearchAsync, new TextSearchProviderOptions + { + SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke + }); + var inputMsg = new ChatMessage(ChatRole.User, "Hello"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]); + + // Act + var messages = (await provider.InvokingAsync(context)).ToList(); + + // Assert + Assert.Single(messages); + Assert.Equal("Hello", messages[0].Text); + } + + [Fact] + public async Task MessageInvokingAsync_BeforeAIInvoke_DefaultFilter_ExcludesNonExternalMessagesAsync() + { + // Arrange + string? capturedInput = null; + Task> SearchDelegateAsync(string input, CancellationToken ct) + { + capturedInput = input; + return Task.FromResult>([]); + } + + var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions + { + SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke + }); + + var externalMsg = new ChatMessage(ChatRole.User, "External message"); + var historyMsg = new ChatMessage(ChatRole.System, "From history") + .WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [externalMsg, historyMsg]); + + // Act + await provider.InvokingAsync(context); + + // Assert - Only External message used for search query + Assert.Equal("External message", capturedInput); + } + + #endregion + private Task> NoResultSearchAsync(string input, CancellationToken ct) { return Task.FromResult>([]); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index be73260b15..ff5d709202 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -710,6 +710,147 @@ public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync() #endregion + #region MessageAIContextProvider.InvokingAsync Tests + + [Fact] + public async Task MessageInvokingAsync_BeforeAIInvoke_SearchesAndReturnsMergedMessagesAsync() + { + // Arrange + var storedItems = new List>> + { + new( + new Dictionary + { + ["MessageId"] = "msg-1", + ["Content"] = "Previous message", + ["Role"] = ChatRole.User.ToString(), + ["CreatedAt"] = "2023-01-01T00:00:00.0000000+00:00" + }, + 0.9f) + }; + + this._vectorStoreCollectionMock + .Setup(c => c.SearchAsync( + It.IsAny(), + It.IsAny(), + It.IsAny>>(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(storedItems)); + + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), + options: new ChatHistoryMemoryProviderOptions + { + SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke + }); + + var inputMsg = new ChatMessage(ChatRole.User, "What was discussed?"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]); + + // Act + var messages = (await provider.InvokingAsync(context)).ToList(); + + // Assert - input message + search result message, with stamping + Assert.Equal(2, messages.Count); + Assert.Equal("What was discussed?", messages[0].Text); + Assert.Contains("Previous message", messages[1].Text); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType()); + } + + [Fact] + public async Task MessageInvokingAsync_OnDemand_ThrowsInvalidOperationExceptionAsync() + { + // Arrange + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), + options: new ChatHistoryMemoryProviderOptions + { + SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling + }); + + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); + + // Act & Assert + await Assert.ThrowsAsync(() => provider.InvokingAsync(context).AsTask()); + } + + [Fact] + public async Task MessageInvokingAsync_BeforeAIInvoke_NoResults_ReturnsOnlyInputMessagesAsync() + { + // Arrange + this._vectorStoreCollectionMock + .Setup(c => c.SearchAsync( + It.IsAny(), + It.IsAny(), + It.IsAny>>(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(new List>>())); + + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), + options: new ChatHistoryMemoryProviderOptions + { + SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke + }); + + var inputMsg = new ChatMessage(ChatRole.User, "Hello"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]); + + // Act + var messages = (await provider.InvokingAsync(context)).ToList(); + + // Assert + Assert.Single(messages); + Assert.Equal("Hello", messages[0].Text); + } + + [Fact] + public async Task MessageInvokingAsync_BeforeAIInvoke_DefaultFilter_ExcludesNonExternalMessagesAsync() + { + // Arrange + string? capturedQuery = null; + this._vectorStoreCollectionMock + .Setup(c => c.SearchAsync( + It.IsAny(), + It.IsAny(), + It.IsAny>>(), + It.IsAny())) + .Callback>, CancellationToken>((query, _, _, _) => capturedQuery = query) + .Returns(ToAsyncEnumerableAsync(new List>>())); + + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), + options: new ChatHistoryMemoryProviderOptions + { + SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke + }); + + var externalMsg = new ChatMessage(ChatRole.User, "External message"); + var historyMsg = new ChatMessage(ChatRole.System, "From history") + .WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src"); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [externalMsg, historyMsg]); + + // Act + await provider.InvokingAsync(context); + + // Assert - Only External message used for search query + Assert.Equal("External message", capturedQuery); + } + + #endregion + private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values) { await Task.Yield(); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/MessageAIContextProviderAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/MessageAIContextProviderAgentTests.cs new file mode 100644 index 0000000000..ed2f82321d --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/MessageAIContextProviderAgentTests.cs @@ -0,0 +1,469 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Moq; + +namespace Microsoft.Agents.AI.UnitTests; + +/// +/// Unit tests for the class and +/// the builder extension. +/// +public class MessageAIContextProviderAgentTests +{ + private static readonly AgentSession s_mockSession = new Mock().Object; + + #region Constructor Tests + + [Fact] + public void Constructor_NullInnerAgent_ThrowsArgumentNullException() + { + // Arrange + var provider = new TestProvider(); + + // Act & Assert + Assert.Throws(() => new MessageAIContextProviderAgent(null!, [provider])); + } + + [Fact] + public void Constructor_NullProviders_ThrowsArgumentNullException() + { + // Arrange + var agent = CreateTestAgent(); + + // Act & Assert + Assert.Throws(() => new MessageAIContextProviderAgent(agent, null!)); + } + + [Fact] + public void Constructor_EmptyProviders_ThrowsArgumentOutOfRangeException() + { + // Arrange + var agent = CreateTestAgent(); + + // Act & Assert + Assert.Throws(() => new MessageAIContextProviderAgent(agent, [])); + } + + #endregion + + #region RunAsync Tests + + [Fact] + public async Task RunAsync_SingleProvider_EnrichesMessagesAndDelegatesToInnerAgentAsync() + { + // Arrange + var contextMessage = new ChatMessage(ChatRole.System, "Extra context"); + var provider = new TestProvider(provideMessages: [contextMessage]); + + IEnumerable? capturedMessages = null; + var innerAgent = CreateTestAgent( + runFunc: (messages, _, _, _) => + { + capturedMessages = messages; + return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")])); + }); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider]); + + // Act + await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession); + + // Assert - inner agent received enriched messages (input + provider's message) + Assert.NotNull(capturedMessages); + var messageList = capturedMessages!.ToList(); + Assert.Equal(2, messageList.Count); + Assert.Equal("Hello", messageList[0].Text); + Assert.Contains("Extra context", messageList[1].Text); + } + + [Fact] + public async Task RunAsync_MultipleProviders_CalledInSequenceAsync() + { + // Arrange + var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]); + var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")]); + + IEnumerable? capturedMessages = null; + var innerAgent = CreateTestAgent( + runFunc: (messages, _, _, _) => + { + capturedMessages = messages; + return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")])); + }); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]); + + // Act + await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession); + + // Assert - inner agent received messages from both providers in sequence + Assert.NotNull(capturedMessages); + var messageList = capturedMessages!.ToList(); + Assert.Equal(3, messageList.Count); + Assert.Equal("Hello", messageList[0].Text); + Assert.Contains("From provider 1", messageList[1].Text); + Assert.Contains("From provider 2", messageList[2].Text); + } + + [Fact] + public async Task RunAsync_SequentialProviders_EachReceivesPreviousOutputAsync() + { + // Arrange - provider 2 captures the filtered messages it receives in ProvideMessagesAsync. + // The default filter only includes External messages, so provider 1's stamped messages + // (marked as AIContextProvider) are filtered out before reaching provider 2's ProvideMessagesAsync. + // However, the full unfiltered output from provider 1 is passed to provider 2's InvokingAsync, + // and the inner agent receives the full merged output from both providers. + IEnumerable? provider2ReceivedMessages = null; + var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]); + var provider2 = new TestProvider( + provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")], + onInvoking: messages => provider2ReceivedMessages = messages.ToList()); + + var innerAgent = CreateTestAgent( + runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]))); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]); + + // Act + await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession); + + // Assert - provider 2's ProvideMessagesAsync received only External messages (filtered) + Assert.NotNull(provider2ReceivedMessages); + var received = provider2ReceivedMessages!.ToList(); + Assert.Single(received); + Assert.Equal("Hello", received[0].Text); + } + + [Fact] + public async Task RunAsync_OnSuccess_InvokedAsyncCalledOnAllProvidersAsync() + { + // Arrange + var provider1 = new TestProvider(); + var provider2 = new TestProvider(); + var innerAgent = CreateTestAgent( + runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]))); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]); + + // Act + await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession); + + // Assert + Assert.True(provider1.InvokedAsyncCalled); + Assert.True(provider2.InvokedAsyncCalled); + Assert.Null(provider1.LastInvokedContext!.InvokeException); + Assert.Null(provider2.LastInvokedContext!.InvokeException); + } + + [Fact] + public async Task RunAsync_OnFailure_InvokedAsyncCalledWithExceptionAsync() + { + // Arrange + var provider = new TestProvider(); + var expectedException = new InvalidOperationException("Agent failed"); + var innerAgent = CreateTestAgent( + runFunc: (_, _, _, _) => throw expectedException); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider]); + + // Act & Assert + await Assert.ThrowsAsync(() => + agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession)); + + Assert.True(provider.InvokedAsyncCalled); + Assert.Same(expectedException, provider.LastInvokedContext!.InvokeException); + } + + [Fact] + public async Task RunAsync_OnSuccess_InvokedContextContainsResponseMessagesAsync() + { + // Arrange + var provider = new TestProvider(); + var responseMessage = new ChatMessage(ChatRole.Assistant, "Response text"); + var innerAgent = CreateTestAgent( + runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([responseMessage]))); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider]); + + // Act + await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession); + + // Assert + Assert.NotNull(provider.LastInvokedContext?.ResponseMessages); + Assert.Contains(provider.LastInvokedContext!.ResponseMessages!, m => m.Text == "Response text"); + } + + #endregion + + #region RunStreamingAsync Tests + + [Fact] + public async Task RunStreamingAsync_SingleProvider_EnrichesMessagesAndStreamsAsync() + { + // Arrange + var contextMessage = new ChatMessage(ChatRole.System, "Extra context"); + var provider = new TestProvider(provideMessages: [contextMessage]); + + IEnumerable? capturedMessages = null; + var innerAgent = CreateTestAgent( + runStreamingFunc: (messages, _, _, _) => + { + capturedMessages = messages; + return ToAsyncEnumerableAsync( + new AgentResponseUpdate(ChatRole.Assistant, "Part1"), + new AgentResponseUpdate(ChatRole.Assistant, "Part2")); + }); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider]); + + // Act + var updates = new List(); + await foreach (var update in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession)) + { + updates.Add(update); + } + + // Assert - streaming updates received + Assert.Equal(2, updates.Count); + // Assert - inner agent received enriched messages + Assert.NotNull(capturedMessages); + var messageList = capturedMessages!.ToList(); + Assert.Equal(2, messageList.Count); + } + + [Fact] + public async Task RunStreamingAsync_OnSuccess_InvokedAsyncCalledAfterAllUpdatesAsync() + { + // Arrange + var provider = new TestProvider(); + var innerAgent = CreateTestAgent( + runStreamingFunc: (_, _, _, _) => ToAsyncEnumerableAsync( + new AgentResponseUpdate(ChatRole.Assistant, "Response"))); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider]); + + // Act - consume all updates + await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession)) + { + } + + // Assert + Assert.True(provider.InvokedAsyncCalled); + Assert.Null(provider.LastInvokedContext!.InvokeException); + } + + [Fact] + public async Task RunStreamingAsync_OnSuccess_InvokedContextContainsAccumulatedResponseAsync() + { + // Arrange + var provider = new TestProvider(); + var innerAgent = CreateTestAgent( + runStreamingFunc: (_, _, _, _) => ToAsyncEnumerableAsync( + new AgentResponseUpdate(ChatRole.Assistant, "Hello "), + new AgentResponseUpdate(ChatRole.Assistant, "World"))); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider]); + + // Act - consume all updates + await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession)) + { + } + + // Assert - InvokedAsync received the accumulated response messages + Assert.NotNull(provider.LastInvokedContext?.ResponseMessages); + var responseMessages = provider.LastInvokedContext!.ResponseMessages!.ToList(); + Assert.True(responseMessages.Count > 0); + } + + [Fact] + public async Task RunStreamingAsync_OnFailure_InvokedAsyncCalledWithExceptionAsync() + { + // Arrange + var provider = new TestProvider(); + var expectedException = new InvalidOperationException("Stream failed"); + var innerAgent = CreateTestAgent( + runStreamingFunc: (_, _, _, _) => throw expectedException); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider]); + + // Act & Assert + await Assert.ThrowsAsync(async () => + { + await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession)) + { + } + }); + + Assert.True(provider.InvokedAsyncCalled); + Assert.Same(expectedException, provider.LastInvokedContext!.InvokeException); + } + + [Fact] + public async Task RunStreamingAsync_MultipleProviders_CalledInSequenceAsync() + { + // Arrange + var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]); + var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")]); + + IEnumerable? capturedMessages = null; + var innerAgent = CreateTestAgent( + runStreamingFunc: (messages, _, _, _) => + { + capturedMessages = messages; + return ToAsyncEnumerableAsync(new AgentResponseUpdate(ChatRole.Assistant, "Response")); + }); + + var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]); + + // Act + await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession)) + { + } + + // Assert + Assert.NotNull(capturedMessages); + var messageList = capturedMessages!.ToList(); + Assert.Equal(3, messageList.Count); + Assert.Equal("Hello", messageList[0].Text); + Assert.Contains("From provider 1", messageList[1].Text); + Assert.Contains("From provider 2", messageList[2].Text); + } + + #endregion + + #region Builder Extension Tests + + [Fact] + public async Task UseExtension_CreatesWorkingPipelineAsync() + { + // Arrange + var contextMessage = new ChatMessage(ChatRole.System, "Pipeline context"); + var provider = new TestProvider(provideMessages: [contextMessage]); + + IEnumerable? capturedMessages = null; + var innerAgent = CreateTestAgent( + runFunc: (messages, _, _, _) => + { + capturedMessages = messages; + return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")])); + }); + + var pipeline = new AIAgentBuilder(innerAgent) + .Use([provider]) + .Build(); + + // Act + await pipeline.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession); + + // Assert + Assert.NotNull(capturedMessages); + var messageList = capturedMessages!.ToList(); + Assert.Equal(2, messageList.Count); + Assert.Equal("Hello", messageList[0].Text); + Assert.Contains("Pipeline context", messageList[1].Text); + } + + [Fact] + public async Task UseExtension_MultipleProviders_AllAppliedAsync() + { + // Arrange + var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "P1")]); + var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "P2")]); + + IEnumerable? capturedMessages = null; + var innerAgent = CreateTestAgent( + runFunc: (messages, _, _, _) => + { + capturedMessages = messages; + return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")])); + }); + + var pipeline = new AIAgentBuilder(innerAgent) + .Use([provider1, provider2]) + .Build(); + + // Act + await pipeline.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession); + + // Assert + Assert.NotNull(capturedMessages); + var messageList = capturedMessages!.ToList(); + Assert.Equal(3, messageList.Count); + } + + #endregion + + #region Helpers + + private static TestAIAgent CreateTestAgent( + Func, AgentSession?, AgentRunOptions?, CancellationToken, Task>? runFunc = null, + Func, AgentSession?, AgentRunOptions?, CancellationToken, IAsyncEnumerable>? runStreamingFunc = null) + { + var agent = new TestAIAgent(); + if (runFunc is not null) + { + agent.RunAsyncFunc = runFunc; + } + + if (runStreamingFunc is not null) + { + agent.RunStreamingAsyncFunc = runStreamingFunc; + } + + return agent; + } + + private static async IAsyncEnumerable ToAsyncEnumerableAsync(params AgentResponseUpdate[] updates) + { + foreach (var update in updates) + { + yield return update; + } + + await Task.CompletedTask; + } + + /// + /// A test implementation of that records invocation calls. + /// + private sealed class TestProvider : MessageAIContextProvider + { + private readonly IEnumerable _provideMessages; + private readonly Action>? _onInvoking; + + public bool InvokedAsyncCalled { get; private set; } + + public InvokedContext? LastInvokedContext { get; private set; } + + public TestProvider( + IEnumerable? provideMessages = null, + Action>? onInvoking = null) + { + this._provideMessages = provideMessages ?? []; + this._onInvoking = onInvoking; + } + + protected override ValueTask> ProvideMessagesAsync( + InvokingContext context, + CancellationToken cancellationToken = default) + { + this._onInvoking?.Invoke(context.RequestMessages); + return new ValueTask>(this._provideMessages); + } + + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) + { + this.InvokedAsyncCalled = true; + this.LastInvokedContext = context; + return default; + } + } + + #endregion +}