diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
index 29920409fbd..4878239f35b 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
@@ -80,6 +80,16 @@ public static partial class AIFunctionFactory
/// The handling of such parameters may be overridden via .
///
///
+ ///
+ ///
+ /// When the is constructed, it may be passed an via
+ /// . Any parameter that can be satisfied by that
+ /// according to will not be included in the generated JSON schema and will be resolved
+ /// from the provided to via ,
+ /// rather than from the argument collection. The handling of such parameters may be overridden via
+ /// .
+ ///
+ ///
///
/// All other parameter types are, by default, bound from the dictionary passed into
/// and are included in the generated JSON schema. This may be overridden by the provided
@@ -168,6 +178,15 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio
/// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter.
///
///
+ ///
+ ///
+ /// When the is constructed, it may be passed an via
+ /// . Any parameter that can be satisfied by that
+ /// according to will not be included in the generated JSON schema and will be resolved
+ /// from the provided to via ,
+ /// rather than from the argument collection.
+ ///
+ ///
///
/// All other parameter types are bound from the dictionary passed into
/// and are included in the generated JSON schema.
@@ -260,6 +279,16 @@ public static AIFunction Create(Delegate method, string? name = null, string? de
/// The handling of such parameters may be overridden via .
///
///
+ ///
+ ///
+ /// When the is constructed, it may be passed an via
+ /// . Any parameter that can be satisfied by that
+ /// according to will not be included in the generated JSON schema and will be resolved
+ /// from the provided to via ,
+ /// rather than from the argument collection. The handling of such parameters may be overridden via
+ /// .
+ ///
+ ///
///
/// All other parameter types are, by default, bound from the dictionary passed into
/// and are included in the generated JSON schema. This may be overridden by the provided
@@ -357,6 +386,15 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac
/// is allowed to be ; otherwise,
/// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter.
///
+ ///
+ ///
+ /// When the is constructed, it may be passed an via
+ /// . Any parameter that can be satisfied by that
+ /// according to will not be included in the generated JSON schema and will be resolved
+ /// from the provided to via ,
+ /// rather than from the argument collection.
+ ///
+ ///
///
///
/// All other parameter types are bound from the dictionary passed into
@@ -465,6 +503,16 @@ public static AIFunction Create(MethodInfo method, object? target, string? name
/// The handling of such parameters may be overridden via .
///
///
+ ///
+ ///
+ /// When the is constructed, it may be passed an via
+ /// . Any parameter that can be satisfied by that
+ /// according to will not be included in the generated JSON schema and will be resolved
+ /// from the provided to via ,
+ /// rather than from the argument collection. The handling of such parameters may be overridden via
+ /// .
+ ///
+ ///
///
/// All other parameter types are, by default, bound from the dictionary passed into
/// and are included in the generated JSON schema. This may be overridden by the provided
@@ -661,7 +709,7 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu
serializerOptions.MakeReadOnly();
ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions);
- DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, schemaOptions);
+ DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, options.Services, schemaOptions);
if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor))
{
return descriptor;
@@ -688,6 +736,8 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
}
}
+ IServiceProviderIsService? serviceProviderIsService = key.Services?.GetService();
+
// Use that binding information to impact the schema generation.
AIJsonSchemaCreateOptions schemaOptions = key.SchemaOptions with
{
@@ -714,6 +764,14 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
return false;
}
+ // We assume that if the services used to create the function support a particular type,
+ // so too do the services that will be passed into InvokeAsync. This is the same basic assumption
+ // made in ASP.NET.
+ if (serviceProviderIsService?.IsService(parameterInfo.ParameterType) is true)
+ {
+ return false;
+ }
+
// If there was an existing IncludeParameter delegate, now defer to it as we've
// excluded everything we need to exclude.
if (key.SchemaOptions.IncludeParameter is { } existingIncludeParameter)
@@ -735,7 +793,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
options = default;
}
- ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]);
+ ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i], serviceProviderIsService);
}
// Get a marshaling delegate for the return value.
@@ -805,7 +863,8 @@ static bool IsAsyncMethod(MethodInfo method)
private static Func GetParameterMarshaller(
JsonSerializerOptions serializerOptions,
AIFunctionFactoryOptions.ParameterBindingOptions bindingOptions,
- ParameterInfo parameter)
+ ParameterInfo parameter,
+ IServiceProviderIsService? serviceProviderIsService)
{
if (string.IsNullOrWhiteSpace(parameter.Name))
{
@@ -831,28 +890,28 @@ static bool IsAsyncMethod(MethodInfo method)
// We're now into default handling of everything else.
- // For AIFunctionArgument parameters, we bind to the arguments passed directly to InvokeAsync.
+ // For AIFunctionArgument parameters, we bind to the arguments passed to InvokeAsync.
if (parameterType == typeof(AIFunctionArguments))
{
return static (arguments, _) => arguments;
}
- // For IServiceProvider parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments.
+ // For IServiceProvider parameters, we bind to the services passed to InvokeAsync via AIFunctionArguments.
if (parameterType == typeof(IServiceProvider))
{
return (arguments, _) =>
{
IServiceProvider? services = arguments.Services;
- if (services is null && !parameter.HasDefaultValue)
+ if (!parameter.HasDefaultValue && services is null)
{
- Throw.ArgumentException(nameof(arguments), $"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter.");
+ ThrowNullServices(parameter.Name);
}
return services;
};
}
- // For [FromKeyedServices] parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments.
+ // For [FromKeyedServices] parameters, we resolve from the services passed to InvokeAsync via AIFunctionArguments.
if (parameter.GetCustomAttribute(inherit: true) is { } keyedAttr)
{
return (arguments, _) =>
@@ -864,7 +923,38 @@ static bool IsAsyncMethod(MethodInfo method)
if (!parameter.HasDefaultValue)
{
- Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found.");
+ if (arguments.Services is null)
+ {
+ ThrowNullServices(parameter.Name);
+ }
+
+ Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found for parameter '{parameter.Name}'.");
+ }
+
+ return parameter.DefaultValue;
+ };
+ }
+
+ // For any parameters that are satisfiable from the IServiceProvider, we resolve from the services passed to InvokeAsync
+ // via AIFunctionArguments. This is determined by the same same IServiceProviderIsService instance used to determine whether
+ // the parameter should be included in the schema.
+ if (serviceProviderIsService?.IsService(parameterType) is true)
+ {
+ return (arguments, _) =>
+ {
+ if (arguments.Services?.GetService(parameterType) is { } service)
+ {
+ return service;
+ }
+
+ if (!parameter.HasDefaultValue)
+ {
+ if (arguments.Services is null)
+ {
+ ThrowNullServices(parameter.Name);
+ }
+
+ Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' was found for parameter '{parameter.Name}'.");
}
return parameter.DefaultValue;
@@ -873,7 +963,7 @@ static bool IsAsyncMethod(MethodInfo method)
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
- JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
+ JsonTypeInfo? typeInfo = serializerOptions.GetTypeInfo(parameterType);
return (arguments, _) =>
{
// If the parameter has an argument specified in the dictionary, return that argument.
@@ -907,12 +997,16 @@ static bool IsAsyncMethod(MethodInfo method)
// If the parameter is required and there's no argument specified for it, throw.
if (!parameter.HasDefaultValue)
{
- Throw.ArgumentException(nameof(arguments), $"Missing required parameter '{parameter.Name}' for method '{parameter.Member.Name}'.");
+ Throw.ArgumentException(nameof(arguments), $"The arguments dictionary is missing a value for the required parameter '{parameter.Name}'.");
}
// Otherwise, use the optional parameter's default value.
return parameter.DefaultValue;
};
+
+ // Throws an ArgumentNullException indicating that AIFunctionArguments.Services must be provided.
+ static void ThrowNullServices(string parameterName) =>
+ Throw.ArgumentNullException($"arguments.{nameof(AIFunctionArguments.Services)}", $"Services are required for parameter '{parameterName}'.");
}
///
@@ -1075,6 +1169,7 @@ private record struct DescriptorKey(
string? Description,
Func? GetBindParameterOptions,
Func