Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Concurrent;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
Expand All @@ -19,6 +20,7 @@
#pragma warning disable S1075 // URIs should not be hardcoded
#pragma warning disable SA1118 // Parameter should not span multiple lines
#pragma warning disable S109 // Magic numbers should not be used
#pragma warning disable S1067 // Expressions should not be too complex

namespace Microsoft.Extensions.AI;

Expand All @@ -41,6 +43,12 @@ public static partial class AIJsonUtilities
/// <summary>The uri used when populating the $schema keyword in inferred schemas.</summary>
private const string SchemaKeywordUri = "https://json-schema.org/draft/2020-12/schema";

/// <summary>The maximum number of schema entries to cache per JsonSerializerOptions instance.</summary>
private const int InnerCacheSoftLimit = 512;

/// <summary>A global cache for generated schemas, weakly keyed on JsonSerializerOptions instances.</summary>
private static readonly ConditionalWeakTable<JsonSerializerOptions, ConcurrentDictionary<JsonSchemaCacheKey, JsonElement>> _schemaCache = new();

// List of keywords used by JsonSchemaExporter but explicitly disallowed by some AI vendors.
// cf. https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
private static readonly string[] _schemaKeywordsDisallowedByAIVendors = ["minLength", "maxLength", "pattern", "format"];
Expand All @@ -65,58 +73,64 @@ public static JsonElement CreateFunctionJsonSchema(
serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
title ??= method.Name;
description ??= method.GetCustomAttribute<DescriptionAttribute>()?.Description;

JsonObject parameterSchemas = new();
JsonArray? requiredProperties = null;
foreach (ParameterInfo parameter in method.GetParameters())
JsonSchemaCacheKey cacheKey = new(member: method, title, description, hasDefaultValue: false, defaultValue: null, inferenceOptions);
return GetOrAddSchema(serializerOptions, cacheKey, CreateSchema);

static JsonElement CreateSchema(JsonSchemaCacheKey key, JsonSerializerOptions serializerOptions)
{
if (string.IsNullOrWhiteSpace(parameter.Name))
JsonObject parameterSchemas = new();
JsonArray? requiredProperties = null;
foreach (ParameterInfo parameter in ((MethodBase)key.Member!).GetParameters())
{
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
if (string.IsNullOrWhiteSpace(parameter.Name))
{
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
}

JsonNode parameterSchema = CreateJsonSchemaCore(
type: parameter.ParameterType,
parameterName: parameter.Name,
description: parameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
hasDefaultValue: parameter.HasDefaultValue,
defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null,
serializerOptions,
key.Options);

parameterSchemas.Add(parameter.Name, parameterSchema);
if (!parameter.IsOptional)
{
(requiredProperties ??= []).Add((JsonNode)parameter.Name);
}
}

JsonNode parameterSchema = CreateJsonSchemaCore(
type: parameter.ParameterType,
parameterName: parameter.Name,
description: parameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
hasDefaultValue: parameter.HasDefaultValue,
defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null,
serializerOptions,
inferenceOptions);

parameterSchemas.Add(parameter.Name, parameterSchema);
if (!parameter.IsOptional)
JsonObject schema = new();
if (key.Options.IncludeSchemaKeyword)
{
(requiredProperties ??= []).Add((JsonNode)parameter.Name);
schema[SchemaPropertyName] = SchemaKeywordUri;
}
}

JsonObject schema = new();
if (inferenceOptions.IncludeSchemaKeyword)
{
schema[SchemaPropertyName] = SchemaKeywordUri;
}
if (!string.IsNullOrWhiteSpace(key.Title))
{
schema[TitlePropertyName] = key.Title;
}

if (!string.IsNullOrWhiteSpace(title))
{
schema[TitlePropertyName] = title;
}
string? description = key.Description ?? key.Member.GetCustomAttribute<DescriptionAttribute>()?.Description;
if (!string.IsNullOrWhiteSpace(description))
{
schema[DescriptionPropertyName] = description;
}

if (!string.IsNullOrWhiteSpace(description))
{
schema[DescriptionPropertyName] = description;
}
schema[TypePropertyName] = "object"; // Method schemas always hardcode the type as "object".
schema[PropertiesPropertyName] = parameterSchemas;

schema[TypePropertyName] = "object"; // Method schemas always hardcode the type as "object".
schema[PropertiesPropertyName] = parameterSchemas;
if (requiredProperties is not null)
{
schema[RequiredPropertyName] = requiredProperties;
}

if (requiredProperties is not null)
{
schema[RequiredPropertyName] = requiredProperties;
return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode);
}

return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode);
}

/// <summary>Creates a JSON schema for the specified type.</summary>
Expand All @@ -137,22 +151,19 @@ public static JsonElement CreateJsonSchema(
{
serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
JsonNode schema = CreateJsonSchemaCore(type, parameterName: null, description, hasDefaultValue, defaultValue, serializerOptions, inferenceOptions);
return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode);
}

/// <summary>Gets the default JSON schema to be used by types or functions.</summary>
internal static JsonElement DefaultJsonSchema { get; } = ParseJsonElement("{}"u8);

/// <summary>Validates the provided JSON schema document.</summary>
internal static void ValidateSchemaDocument(JsonElement document, [CallerArgumentExpression("document")] string? paramName = null)
{
if (document.ValueKind is not JsonValueKind.Object or JsonValueKind.False or JsonValueKind.True)
JsonSchemaCacheKey cacheKey = new(member: type, title: null, description, hasDefaultValue, defaultValue, inferenceOptions);
return GetOrAddSchema(serializerOptions, cacheKey, CreateSchema);
static JsonElement CreateSchema(JsonSchemaCacheKey key, JsonSerializerOptions serializerOptions)
{
Throw.ArgumentException(paramName ?? "schema", "The schema document must be an object or a boolean value.");
JsonNode schema = CreateJsonSchemaCore((Type?)key.Member, parameterName: null, key.Description, key.HasDefaultValue, key.DefaultValue, serializerOptions, key.Options);
return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode);
}
}

/// <summary>Gets the default JSON schema to be used by types or functions.</summary>
internal static JsonElement DefaultJsonSchema { get; } = ParseJsonElement("{}"u8);

#if !NET9_0_OR_GREATER
[UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access",
Justification = "Pre STJ-9 schema extraction can fail with a runtime exception if certain reflection metadata have been trimmed. " +
Expand Down Expand Up @@ -411,6 +422,68 @@ private static int IndexOf(this JsonObject jsonObject, string key)
return -1;
}
#endif

private static JsonElement GetOrAddSchema(JsonSerializerOptions serializerOptions, JsonSchemaCacheKey cacheKey, Func<JsonSchemaCacheKey, JsonSerializerOptions, JsonElement> schemaFactory)
{
ConcurrentDictionary<JsonSchemaCacheKey, JsonElement> innerCache = _schemaCache.GetOrCreateValue(serializerOptions);
if (innerCache.TryGetValue(cacheKey, out JsonElement schema))
{
return schema;
}

if (innerCache.Count >= InnerCacheSoftLimit)
{
return schemaFactory(cacheKey, serializerOptions);
}

#if NET
return innerCache.GetOrAdd(cacheKey, schemaFactory, serializerOptions);
#else
return innerCache.GetOrAdd(cacheKey, cacheKey => schemaFactory(cacheKey, serializerOptions));
#endif
}

private readonly struct JsonSchemaCacheKey : IEquatable<JsonSchemaCacheKey>
{
public JsonSchemaCacheKey(MemberInfo? member, string? title, string? description, bool hasDefaultValue, object? defaultValue, AIJsonSchemaCreateOptions options)
{
Debug.Assert(member is Type or MethodBase or null, "Must be type or method");
Member = member;
Title = title;
Description = description;
HasDefaultValue = hasDefaultValue;
DefaultValue = defaultValue;
Options = options;
}

public MemberInfo? Member { get; }
public string? Title { get; }
public string? Description { get; }
public bool HasDefaultValue { get; }
public object? DefaultValue { get; }
public AIJsonSchemaCreateOptions Options { get; }

public override bool Equals(object? obj) => obj is JsonSchemaCacheKey key && Equals(key);
public bool Equals(JsonSchemaCacheKey other) =>
Member == other.Member &&
Title == other.Title &&
Description == other.Description &&
HasDefaultValue == other.HasDefaultValue &&
Equals(DefaultValue, other.DefaultValue) &&
Options.TransformSchemaNode == other.Options.TransformSchemaNode &&
Options.IncludeTypeInEnumSchemas == other.Options.IncludeTypeInEnumSchemas &&
Options.DisallowAdditionalProperties == other.Options.DisallowAdditionalProperties &&
Options.IncludeSchemaKeyword == other.Options.IncludeSchemaKeyword &&
Options.RequireAllProperties == other.Options.RequireAllProperties;

public override int GetHashCode() =>
(Member, Title, Description, HasDefaultValue, DefaultValue,
Options.TransformSchemaNode, Options.IncludeTypeInEnumSchemas,
Options.DisallowAdditionalProperties, Options.IncludeSchemaKeyword,
Options.RequireAllProperties)
.GetHashCode();
}

private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
{
Utf8JsonReader reader = new(utf8Json);
Expand Down