Skip to content

Commit ecf894e

Browse files
rogerbarreto authored and jeffhandley committed
Add OpenAIRequestPolicies extension hook for MEAI OpenAI clients (#7495)
* Add OpenAIRequestPolicies extension hook for MEAI OpenAI clients

  Introduces a new experimental sealed type OpenAIRequestPolicies retrievable via IChatClient.GetService<T>() / IEmbeddingGenerator.GetService<T>() on the three Microsoft.Extensions.AI OpenAI clients that go through the ToRequestOptions chokepoint (OpenAIChatClient, OpenAIEmbeddingGenerator, OpenAIResponsesChatClient). The type exposes a single AddPolicy(PipelinePolicy, PipelinePosition = PerCall) method so downstream SDKs that receive a customer-built IChatClient (and therefore cannot reconfigure the underlying OpenAIClient pipeline) can append their own pipeline policies, for example to stamp or replace the User-Agent header. Customer policies run after MEAI's internal user-agent policy, so Headers.Set replaces and Headers.Add stacks.

* Address review: capture request headers instead of asserting on exception message

  Replaces ThrowUserAgentExceptionHandler with a CapturingUserAgentHandler that records request.Headers.UserAgent.ToString() and asserts on the captured value, addressing jozkee's feedback that Message-based assertions were too loose. Also surfaced that the runtime User-Agent is 'OpenAI/x.y.z (...) MEAI/x.y.z' (OpenAI prepends, MEAI appends), so the no-policy test now uses Contains rather than StartsWith.
1 parent 4b131b2 commit ecf894e

8 files changed

Lines changed: 385 additions & 9 deletions

File tree

src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.json

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"Name": "Microsoft.Extensions.AI.OpenAI, Version=10.5.0.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35",
2+
"Name": "Microsoft.Extensions.AI.OpenAI, Version=10.6.0.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35",
33
"Types": [
44
{
55
"Type": "static class OpenAI.Assistants.MicrosoftExtensionsAIAssistantsExtensions",
@@ -208,6 +208,20 @@
208208
"Stage": "Experimental"
209209
}
210210
]
211+
},
212+
{
213+
"Type": "sealed class Microsoft.Extensions.AI.OpenAIRequestPolicies",
214+
"Stage": "Experimental",
215+
"Methods": [
216+
{
217+
"Member": "Microsoft.Extensions.AI.OpenAIRequestPolicies.OpenAIRequestPolicies();",
218+
"Stage": "Experimental"
219+
},
220+
{
221+
"Member": "void Microsoft.Extensions.AI.OpenAIRequestPolicies.AddPolicy(System.ClientModel.Primitives.PipelinePolicy policy, System.ClientModel.Primitives.PipelinePosition position = System.ClientModel.Primitives.PipelinePosition.PerCall);",
222+
"Stage": "Experimental"
223+
}
224+
]
211225
}
212226
]
213227
}

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#pragma warning disable CA1308 // Normalize strings to uppercase
2323
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
2424
#pragma warning disable SA1204 // Static elements should appear before instance elements
25+
#pragma warning disable MEAI001 // OpenAIRequestPolicies is experimental
2526

2627
namespace Microsoft.Extensions.AI;
2728

@@ -55,6 +56,9 @@ internal sealed partial class OpenAIChatClient : IChatClient
5556
/// <summary>The underlying <see cref="ChatClient" />.</summary>
5657
private readonly ChatClient _chatClient;
5758

59+
/// <summary>Caller-registered policies applied to every <see cref="RequestOptions"/>.</summary>
60+
private readonly OpenAIRequestPolicies _requestPolicies = new();
61+
5862
/// <summary>Initializes a new instance of the <see cref="OpenAIChatClient"/> class for the specified <see cref="ChatClient"/>.</summary>
5963
/// <param name="chatClient">The underlying client.</param>
6064
/// <exception cref="ArgumentNullException"><paramref name="chatClient"/> is <see langword="null"/>.</exception>
@@ -76,6 +80,7 @@ public OpenAIChatClient(ChatClient chatClient)
7680
serviceKey is not null ? null :
7781
serviceType == typeof(ChatClientMetadata) ? _metadata :
7882
serviceType == typeof(ChatClient) ? _chatClient :
83+
serviceType == typeof(OpenAIRequestPolicies) ? _requestPolicies :
7984
serviceType.IsInstanceOfType(this) ? this :
8085
null;
8186
}
@@ -94,7 +99,7 @@ public async Task<ChatResponse> GetResponseAsync(
9499

95100
// Make the call to OpenAI.
96101
var task = _completeChatAsync is not null ?
97-
_completeChatAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: false)) :
102+
_completeChatAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: false, _requestPolicies)) :
98103
_chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken);
99104
var response = await task.ConfigureAwait(false);
100105

@@ -115,7 +120,7 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
115120

116121
// Make the call to OpenAI.
117122
var chatCompletionUpdates = _completeChatStreamingAsync is not null ?
118-
_completeChatStreamingAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: true)) :
123+
_completeChatStreamingAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: true, _requestPolicies)) :
119124
_chatClient.CompleteChatStreamingAsync(openAIChatMessages, openAIOptions, cancellationToken);
120125

121126
return FromOpenAIStreamingChatCompletionAsync(chatCompletionUpdates, openAIOptions, cancellationToken);

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using OpenAI.Embeddings;
1414

1515
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
16+
#pragma warning disable MEAI001 // OpenAIRequestPolicies is experimental
1617

1718
namespace Microsoft.Extensions.AI;
1819

@@ -40,6 +41,9 @@ internal sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator<string, Emb
4041
/// <summary>The number of dimensions produced by the generator.</summary>
4142
private readonly int? _dimensions;
4243

44+
/// <summary>Caller-registered policies applied to every <see cref="RequestOptions"/>.</summary>
45+
private readonly OpenAIRequestPolicies _requestPolicies = new();
46+
4347
/// <summary>Initializes a new instance of the <see cref="OpenAIEmbeddingGenerator"/> class.</summary>
4448
/// <param name="embeddingClient">The underlying client.</param>
4549
/// <param name="defaultModelDimensions">The number of dimensions to generate in each embedding.</param>
@@ -66,7 +70,7 @@ public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerab
6670
OpenAI.Embeddings.EmbeddingGenerationOptions? openAIOptions = ToOpenAIOptions(options);
6771

6872
var t = _generateEmbeddingsAsync is not null ?
69-
_generateEmbeddingsAsync(_embeddingClient, values, openAIOptions, cancellationToken.ToRequestOptions(streaming: false)) :
73+
_generateEmbeddingsAsync(_embeddingClient, values, openAIOptions, cancellationToken.ToRequestOptions(streaming: false, _requestPolicies)) :
7074
_embeddingClient.GenerateEmbeddingsAsync(values, openAIOptions, cancellationToken);
7175
var embeddings = (await t.ConfigureAwait(false)).Value;
7276

@@ -104,6 +108,7 @@ void IDisposable.Dispose()
104108
serviceKey is not null ? null :
105109
serviceType == typeof(EmbeddingGeneratorMetadata) ? _metadata :
106110
serviceType == typeof(EmbeddingClient) ? _embeddingClient :
111+
serviceType == typeof(OpenAIRequestPolicies) ? _requestPolicies :
107112
serviceType.IsInstanceOfType(this) ? this :
108113
null;
109114
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.ClientModel.Primitives;
6+
using System.Diagnostics.CodeAnalysis;
7+
using System.Threading;
8+
using Microsoft.Shared.DiagnosticIds;
9+
using Microsoft.Shared.Diagnostics;
10+
11+
namespace Microsoft.Extensions.AI;
12+
13+
/// <summary>
14+
/// Provides an extension hook for adding <see cref="PipelinePolicy"/> instances to the
15+
/// <see cref="RequestOptions"/> built by Microsoft.Extensions.AI for every outbound OpenAI request
16+
/// made through the owning <c>IChatClient</c> or <c>IEmbeddingGenerator</c>.
17+
/// </summary>
18+
/// <remarks>
19+
/// <para>
20+
/// Retrieve the instance via <see cref="IChatClient.GetService(System.Type, object?)"/>
21+
/// (or the equivalent on other Microsoft.Extensions.AI client interfaces) using
22+
/// <see cref="OpenAIRequestPolicies"/> as the service type. The instance is per-client and
23+
/// reachable through any <c>ChatClientBuilder</c> decorator chain.
24+
/// </para>
25+
/// <para>
26+
/// Customer-registered policies are appended <em>after</em> Microsoft.Extensions.AI's own internal
27+
/// policies, so a policy that calls <c>message.Request.Headers.Set("User-Agent", ...)</c>
28+
/// replaces the existing value, while one that calls <c>Headers.Add(...)</c> stacks an
29+
/// additional value.
30+
/// </para>
31+
/// <para>
32+
/// Registration is intended for one-time configuration at startup, but is safe to call
33+
/// concurrently with in-flight requests.
34+
/// </para>
35+
/// </remarks>
36+
[Experimental(DiagnosticIds.Experiments.AIOpenAIRequestPolicies, UrlFormat = DiagnosticIds.UrlFormat)]
37+
public sealed class OpenAIRequestPolicies
38+
{
39+
/// <summary>Shared zero-length array used as the initial value of <see cref="_entries"/>.</summary>
private static readonly Entry[] _empty = Array.Empty<Entry>();
40+
41+
/// <summary>
/// The registered policies, in registration order. The array is never modified in place:
/// <see cref="AddPolicy"/> builds a longer copy and swaps the reference, and
/// <see cref="ApplyTo"/> only reads it.
/// </summary>
private Entry[] _entries = _empty;
42+
43+
/// <summary>Initializes a new instance of the <see cref="OpenAIRequestPolicies"/> class.</summary>
44+
public OpenAIRequestPolicies()
45+
{
46+
}
47+
48+
/// <summary>
49+
/// Adds a <see cref="PipelinePolicy"/> to be applied to every <see cref="RequestOptions"/>
50+
/// produced for outbound OpenAI requests by the owning Microsoft.Extensions.AI client.
51+
/// </summary>
52+
/// <param name="policy">The pipeline policy to register. Must not be <see langword="null"/>.</param>
53+
/// <param name="position">
54+
/// The position in the pipeline at which to place the policy. Defaults to
55+
/// <see cref="PipelinePosition.PerCall"/>, which runs the policy once per logical request
56+
/// (for example, to stamp a User-Agent or correlation header).
57+
/// </param>
58+
/// <exception cref="ArgumentNullException"><paramref name="policy"/> is <see langword="null"/>.</exception>
/// <remarks>
/// Policies are appended in call order and are not de-duplicated: registering the same
/// instance twice stores two entries, and <see cref="ApplyTo"/> adds it twice. This type
/// exposes no way to remove a policy once it has been added.
/// </remarks>
59+
public void AddPolicy(PipelinePolicy policy, PipelinePosition position = PipelinePosition.PerCall)
60+
{
61+
_ = Throw.IfNull(policy);
62+
63+
var newEntry = new Entry(policy, position);
64+
65+
// Lock-free append: copy-on-write with CAS retry.
// Each attempt copies the current array into a new one that is one element longer and
// tries to publish it. Readers therefore always see a fully populated array.
66+
while (true)
67+
{
68+
var current = Volatile.Read(ref _entries);
69+
var updated = new Entry[current.Length + 1];
70+
Array.Copy(current, updated, current.Length);
71+
updated[current.Length] = newEntry;
72+
73+
// The exchange only succeeds if _entries still references the array we copied from.
// Otherwise another caller published first, so loop and rebuild from the newer array.
if (Interlocked.CompareExchange(ref _entries, updated, current) == current)
74+
{
75+
return;
76+
}
77+
}
78+
}
79+
80+
/// <summary>
81+
/// Applies all registered policies to the supplied <see cref="RequestOptions"/>.
82+
/// Called by the Microsoft.Extensions.AI OpenAI clients after their own internal policies
83+
/// have been registered.
84+
/// </summary>
/// <param name="requestOptions">The options instance that receives each registered policy, in registration order.</param>
85+
internal void ApplyTo(RequestOptions requestOptions)
86+
{
87+
// Read the array reference once so the loop walks a single array, even if AddPolicy
// publishes a replacement while this method is running.
var snapshot = Volatile.Read(ref _entries);
88+
for (int i = 0; i < snapshot.Length; i++)
89+
{
90+
var entry = snapshot[i];
91+
requestOptions.AddPolicy(entry.Policy, entry.Position);
92+
}
93+
}
94+
95+
/// <summary>Pairs a registered policy with the pipeline position it was registered for.</summary>
private readonly struct Entry
96+
{
97+
/// <summary>Initializes a new <see cref="Entry"/> from a policy and its pipeline position.</summary>
public Entry(PipelinePolicy policy, PipelinePosition position)
98+
{
99+
Policy = policy;
100+
Position = position;
101+
}
102+
103+
/// <summary>Gets the registered policy.</summary>
public PipelinePolicy Policy { get; }
104+
/// <summary>Gets the pipeline position passed to <see cref="RequestOptions.AddPolicy"/> for <see cref="Policy"/>.</summary>
public PipelinePosition Position { get; }
105+
}
106+
}

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponsesChatClient.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
2525
#pragma warning disable S3254 // Default parameter values should not be passed as arguments
2626
#pragma warning disable SA1204 // Static elements should appear before instance elements
27+
#pragma warning disable MEAI001 // OpenAIRequestPolicies is experimental
2728

2829
namespace Microsoft.Extensions.AI;
2930

@@ -59,6 +60,9 @@ private static readonly Func<ResponsesClient, GetResponseOptions, RequestOptions
5960
/// <summary>The default model ID to use for the chat client.</summary>
6061
private readonly string? _defaultModelId;
6162

63+
/// <summary>Caller-registered policies applied to every <see cref="RequestOptions"/>.</summary>
64+
private readonly OpenAIRequestPolicies _requestPolicies = new();
65+
6266
/// <summary>Initializes a new instance of the <see cref="OpenAIResponsesChatClient"/> class for the specified <see cref="ResponsesClient"/>.</summary>
6367
/// <param name="responseClient">The underlying client.</param>
6468
/// <param name="defaultModelId">The default model ID to use for the chat client.</param>
@@ -82,6 +86,7 @@ public OpenAIResponsesChatClient(ResponsesClient responseClient, string? default
8286
serviceKey is not null ? null :
8387
serviceType == typeof(ChatClientMetadata) ? _metadata :
8488
serviceType == typeof(ResponsesClient) ? _responseClient :
89+
serviceType == typeof(OpenAIRequestPolicies) ? _requestPolicies :
8590
serviceType.IsInstanceOfType(this) ? this :
8691
null;
8792
}
@@ -100,7 +105,7 @@ public async Task<ChatResponse> GetResponseAsync(
100105
// Provided continuation token signals that an existing background response should be fetched.
101106
if (GetContinuationToken(messages, options) is { } token)
102107
{
103-
var getTask = _responseClient.GetResponseAsync(token.ResponseId, include: null, stream: null, startingAfter: null, includeObfuscation: null, cancellationToken.ToRequestOptions(streaming: false));
108+
var getTask = _responseClient.GetResponseAsync(token.ResponseId, include: null, stream: null, startingAfter: null, includeObfuscation: null, cancellationToken.ToRequestOptions(streaming: false, _requestPolicies));
104109
var response = (ResponseResult)await getTask.ConfigureAwait(false);
105110
return FromOpenAIResponse(response, openAIOptions, openAIConversationId);
106111
}
@@ -111,7 +116,7 @@ public async Task<ChatResponse> GetResponseAsync(
111116
}
112117

113118
// Make the call to the ResponsesClient.
114-
var createTask = _responseClient.CreateResponseAsync((BinaryContent)openAIOptions, cancellationToken.ToRequestOptions(streaming: false));
119+
var createTask = _responseClient.CreateResponseAsync((BinaryContent)openAIOptions, cancellationToken.ToRequestOptions(streaming: false, _requestPolicies));
115120
var openAIResponsesResult = (ResponseResult)await createTask.ConfigureAwait(false);
116121

117122
// Convert the response to a ChatResponse.
@@ -330,7 +335,7 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
330335

331336
Debug.Assert(_getResponseStreamingAsync is not null, $"Unable to find {nameof(_getResponseStreamingAsync)} method");
332337
IAsyncEnumerable<StreamingResponseUpdate> getUpdates = _getResponseStreamingAsync is not null ?
333-
_getResponseStreamingAsync(_responseClient, getOptions, cancellationToken.ToRequestOptions(streaming: true)) :
338+
_getResponseStreamingAsync(_responseClient, getOptions, cancellationToken.ToRequestOptions(streaming: true, _requestPolicies)) :
334339
_responseClient.GetResponseStreamingAsync(getOptions, cancellationToken);
335340

336341
return FromOpenAIStreamingResponseUpdatesAsync(getUpdates, openAIOptions, openAIConversationId, token.ResponseId, cancellationToken);
@@ -343,7 +348,7 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
343348

344349
Debug.Assert(_createResponseStreamingAsync is not null, $"Unable to find {nameof(_createResponseStreamingAsync)} method");
345350
AsyncCollectionResult<StreamingResponseUpdate> createUpdates = _createResponseStreamingAsync is not null ?
346-
_createResponseStreamingAsync(_responseClient, openAIOptions, cancellationToken.ToRequestOptions(streaming: true)) :
351+
_createResponseStreamingAsync(_responseClient, openAIOptions, cancellationToken.ToRequestOptions(streaming: true, _requestPolicies)) :
347352
_responseClient.CreateResponseStreamingAsync(openAIOptions, cancellationToken);
348353

349354
return FromOpenAIStreamingResponseUpdatesAsync(createUpdates, openAIOptions, openAIConversationId, cancellationToken: cancellationToken);

src/Libraries/Microsoft.Extensions.AI.OpenAI/RequestOptionsExtensions.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,23 @@
88
using System.Threading.Tasks;
99

1010
#pragma warning disable CA1307 // Specify StringComparison
11+
#pragma warning disable MEAI001 // OpenAIRequestPolicies is experimental
1112

1213
namespace Microsoft.Extensions.AI;
1314

1415
/// <summary>Provides utility methods for creating <see cref="RequestOptions"/>.</summary>
1516
internal static class RequestOptionsExtensions
1617
{
1718
/// <summary>Creates a <see cref="RequestOptions"/> configured for use with OpenAI.</summary>
18-
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming)
19+
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming) =>
20+
ToRequestOptions(cancellationToken, streaming, policies: null);
21+
22+
/// <summary>
23+
/// Creates a <see cref="RequestOptions"/> configured for use with OpenAI, applying any
24+
/// caller-registered <see cref="OpenAIRequestPolicies"/> after Microsoft.Extensions.AI's own
25+
/// internal policies.
26+
/// </summary>
27+
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming, OpenAIRequestPolicies? policies)
1928
{
2029
RequestOptions requestOptions = new()
2130
{
@@ -25,6 +34,8 @@ public static RequestOptions ToRequestOptions(this CancellationToken cancellatio
2534

2635
requestOptions.AddPolicy(MeaiUserAgentPolicy.Instance, PipelinePosition.PerCall);
2736

37+
policies?.ApplyTo(requestOptions);
38+
2839
return requestOptions;
2940
}
3041

src/Shared/DiagnosticIds/DiagnosticIds.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ internal static class Experiments
5959
internal const string AIToolSearch = AIExperiments;
6060
internal const string AIRealTime = AIExperiments;
6161
internal const string AIFiles = AIExperiments;
62+
internal const string AIOpenAIRequestPolicies = AIExperiments;
6263

6364
// These diagnostic IDs are defined by the OpenAI package for its experimental APIs.
6465
// We use the same IDs so consumers do not need to suppress additional diagnostics

0 commit comments

Comments
 (0)