Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,25 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
objSchema.InsertAtStart(TypePropertyName, "string");
}

// Include the type keyword in nullable enum types
if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type)?.IsEnum is true && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type) is Type nullableElement)
{
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
// Account for bug https://github.com/dotnet/runtime/issues/117493
// null not inserted in the type keyword for root-level Nullable<T> types.
if (objSchema.TryGetPropertyValue(TypePropertyName, out JsonNode? typeKeyWord) &&
typeKeyWord?.GetValueKind() is JsonValueKind.String)
{
string typeValue = typeKeyWord.GetValue<string>()!;
if (typeValue is not "null")
{
objSchema[TypePropertyName] = new JsonArray { (JsonNode)typeValue, (JsonNode)"null" };
}
}

// Include the type keyword in nullable enum types
if (nullableElement.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
{
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
}
}

// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
Expand Down Expand Up @@ -605,7 +620,7 @@ private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateCont
{
numericType = null;

if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray { Count: 2 } typeArray)
if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray)
{
bool allowString = false;

Expand All @@ -617,11 +632,23 @@ private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateCont
switch (type)
{
case "integer" or "number":
if (numericType is not null)
{
// Conflicting numeric type
return false;
}

numericType = type;
break;
case "string":
allowString = true;
break;
case "null":
// Nullable integer.
break;
default:
// keyword is not valid in the context of numeric types.
return false;
}
}
}
Expand Down Expand Up @@ -665,7 +692,7 @@ private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)

if (defaultValue is null || (defaultValue == DBNull.Value && parameterType != typeof(DBNull)))
{
return parameterType.IsValueType
return parameterType.IsValueType && Nullable.GetUnderlyingType(parameterType) is null
#if NET
? RuntimeHelpers.GetUninitializedObject(parameterType)
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ private static class ReflectionHelpers
public static bool IsBuiltInConverter(JsonConverter converter) =>
converter.GetType().Assembly == typeof(JsonConverter).Assembly;

public static bool CanBeNull(Type type) => !type.IsValueType || Nullable.GetUnderlyingType(type) is not null;

public static Type GetElementType(JsonTypeInfo typeInfo)
{
Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Enumerable or JsonTypeInfoKind.Dictionary, "TypeInfo must be of collection type");
Expand Down
16 changes: 10 additions & 6 deletions src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -452,20 +452,24 @@ JsonSchema CompleteSchema(ref GenerationState state, JsonSchema schema)

bool IsNullableSchema(ref GenerationState state)
{
// A schema is marked as nullable if either
// A schema is marked as nullable if either:
// 1. We have a schema for a property where either the getter or setter are marked as nullable.
// 2. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable
// 2. We have a schema for a Nullable<T> type.
// 3. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable.

if (propertyInfo != null || parameterInfo != null)
{
return !isNonNullableType;
}
else

if (Nullable.GetUnderlyingType(typeInfo.Type) is not null)
{
return ReflectionHelpers.CanBeNull(typeInfo.Type) &&
!parentPolymorphicTypeIsNonNullable &&
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
return true;
}

return !typeInfo.Type.IsValueType &&
!parentPolymorphicTypeIsNonNullable &&
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,29 @@ public static void EqualFunctionCallParameters(
public static void EqualFunctionCallResults(object? expected, object? actual, JsonSerializerOptions? options = null)
=> AreJsonEquivalentValues(expected, actual, options);

private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
/// <summary>
/// Asserts that the two JSON values are equal.
/// </summary>
public static void EqualJsonValues(JsonElement expectedJson, JsonElement actualJson, string? propertyName = null)
{
options ??= AIJsonUtilities.DefaultOptions;
JsonElement expectedElement = NormalizeToElement(expected, options);
JsonElement actualElement = NormalizeToElement(actual, options);
if (!JsonNode.DeepEquals(
JsonSerializer.SerializeToNode(expectedElement, AIJsonUtilities.DefaultOptions),
JsonSerializer.SerializeToNode(actualElement, AIJsonUtilities.DefaultOptions)))
JsonSerializer.SerializeToNode(expectedJson, AIJsonUtilities.DefaultOptions),
JsonSerializer.SerializeToNode(actualJson, AIJsonUtilities.DefaultOptions)))
{
string message = propertyName is null
? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"
: $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}";
? $"JSON result does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}"
: $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}";

throw new XunitException(message);
}
}

private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
{
options ??= AIJsonUtilities.DefaultOptions;
JsonElement expectedElement = NormalizeToElement(expected, options);
JsonElement actualElement = NormalizeToElement(actual, options);
EqualJsonValues(expectedElement, actualElement, propertyName);

static JsonElement NormalizeToElement(object? value, JsonSerializerOptions options)
=> value is JsonElement e ? e : JsonSerializer.SerializeToElement(value, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -854,6 +855,38 @@ public async Task AIFunctionFactory_DefaultDefaultParameter()
Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString());
}

[Fact]
public async Task AIFunctionFactory_NullableParameters()
{
Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value);

AIFunction f = AIFunctionFactory.Create(
(int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(),
serializerOptions: JsonContext.Default.Options);

JsonElement expectedSchema = JsonDocument.Parse("""
{
"type": "object",
"properties": {
"limit": {
"type": ["integer", "null"],
"default": null
},
"from": {
"type": ["string", "null"],
"format": "date-time",
"default": null
}
}
}
""").RootElement;

AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema);

object? result = await f.InvokeAsync();
Assert.Contains("[1,1,1,1]", result?.ToString());
}

[Fact]
public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute()
{
Expand Down Expand Up @@ -959,5 +992,7 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() =>
[JsonSerializable(typeof(Guid))]
[JsonSerializable(typeof(StructWithDefaultCtor))]
[JsonSerializable(typeof(B))]
[JsonSerializable(typeof(int?))]
[JsonSerializable(typeof(DateTime?))]
private partial class JsonContext : JsonSerializerContext;
}
10 changes: 8 additions & 2 deletions test/Shared/JsonSchemaExporter/TestData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ internal sealed record TestData<T>(
T? Value,
[StringSyntax(StringSyntaxAttribute.Json)] string ExpectedJsonSchema,
IEnumerable<T?>? AdditionalValues = null,
object? ExporterOptions = null,
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
System.Text.Json.Schema.JsonSchemaExporterOptions? ExporterOptions = null,
#endif
JsonSerializerOptions? Options = null,
bool WritesNumbersAsStrings = false)
: ITestData
Expand All @@ -22,7 +24,9 @@ internal sealed record TestData<T>(

public Type Type => typeof(T);
object? ITestData.Value => Value;
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
object? ITestData.ExporterOptions => ExporterOptions;
#endif
JsonNode ITestData.ExpectedJsonSchema { get; } =
JsonNode.Parse(ExpectedJsonSchema, documentOptions: _schemaParseOptions)
?? throw new ArgumentNullException("schema must not be null");
Expand All @@ -32,7 +36,7 @@ IEnumerable<ITestData> ITestData.GetTestDataForAllValues()
yield return this;

if (default(T) is null &&
#if NET9_0_OR_GREATER
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
ExporterOptions is System.Text.Json.Schema.JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable: false } &&
#endif
Value is not null)
Expand All @@ -58,7 +62,9 @@ public interface ITestData

JsonNode ExpectedJsonSchema { get; }

#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
object? ExporterOptions { get; }
#endif

JsonSerializerOptions? Options { get; }

Expand Down
47 changes: 21 additions & 26 deletions test/Shared/JsonSchemaExporter/TestTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
#if NET9_0_OR_GREATER
using System.Reflection;
#endif
using System.Text.Json;
using System.Text.Json.Nodes;
#if NET9_0_OR_GREATER
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
using System.Text.Json.Schema;
#endif
using System.Text.Json.Serialization;
Expand Down Expand Up @@ -135,6 +132,21 @@ public static IEnumerable<ITestData> GetTestDataCore()
}
""");

#if !NET9_0 && TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
// Regression test for https://github.com/dotnet/runtime/issues/117493
yield return new TestData<int?>(
Value: 42,
AdditionalValues: [null],
ExpectedJsonSchema: """{"type":["integer","null"]}""",
ExporterOptions: new() { TreatNullObliviousAsNonNullable = true });

yield return new TestData<DateTimeOffset?>(
Value: DateTimeOffset.MinValue,
AdditionalValues: [null],
ExpectedJsonSchema: """{"type":["string","null"],"format":"date-time"}""",
ExporterOptions: new() { TreatNullObliviousAsNonNullable = true });
#endif

// User-defined POCOs
yield return new TestData<SimplePoco>(
Value: new() { String = "string", StringNullable = "string", Int = 42, Double = 3.14, Boolean = true },
Expand All @@ -152,7 +164,7 @@ public static IEnumerable<ITestData> GetTestDataCore()
}
""");

#if NET9_0_OR_GREATER
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
// Same as above but with nullable types set to non-nullable
yield return new TestData<SimplePoco>(
Value: new() { String = "string", StringNullable = "string", Int = 42, Double = 3.14, Boolean = true },
Expand Down Expand Up @@ -311,7 +323,7 @@ public static IEnumerable<ITestData> GetTestDataCore()
}
""");

#if NET9_0_OR_GREATER
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
// Same as above but with non-nullable reference types by default.
yield return new TestData<PocoWithRecursiveMembers>(
Value: new() { Value = 1, Next = new() { Value = 2, Next = new() { Value = 3 } } },
Expand Down Expand Up @@ -761,7 +773,7 @@ of the type which points to the first occurrence. */
}
""");

#if NET9_0_OR_GREATER
#if TEST
yield return new TestData<ClassWithComponentModelAttributes>(
Value: new("string", -1),
ExpectedJsonSchema: """
Expand Down Expand Up @@ -1164,7 +1176,7 @@ public readonly struct StructDictionary<TKey, TValue>(IEnumerable<KeyValuePair<T
public int Count => _dictionary.Count;
public bool ContainsKey(TKey key) => _dictionary.ContainsKey(key);
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() => _dictionary.GetEnumerator();
#if NETCOREAPP
#if NET
public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) => _dictionary.TryGetValue(key, out value);
#else
public bool TryGetValue(TKey key, out TValue value) => _dictionary.TryGetValue(key, out value);
Expand Down Expand Up @@ -1249,6 +1261,7 @@ public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions
[JsonSerializable(typeof(IntEnum?))]
[JsonSerializable(typeof(StringEnum?))]
[JsonSerializable(typeof(SimpleRecordStruct?))]
[JsonSerializable(typeof(DateTimeOffset?))]
// User-defined POCOs
[JsonSerializable(typeof(SimplePoco))]
[JsonSerializable(typeof(SimpleRecord))]
Expand Down Expand Up @@ -1299,22 +1312,4 @@ public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions
[JsonSerializable(typeof(StructDictionary<string, int>))]
[JsonSerializable(typeof(XElement))]
public partial class TestTypesContext : JsonSerializerContext;

#if NET9_0_OR_GREATER
private static TAttribute? ResolveAttribute<TAttribute>(this JsonSchemaExporterContext ctx)
where TAttribute : Attribute
{
// Resolve attributes from locations in the following order:
// 1. Property-level attributes
// 2. Parameter-level attributes and
// 3. Type-level attributes.
return
GetAttrs(ctx.PropertyInfo?.AttributeProvider) ??
GetAttrs(ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider) ??
GetAttrs(ctx.TypeInfo.Type);

static TAttribute? GetAttrs(ICustomAttributeProvider? provider) =>
(TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit: false).FirstOrDefault();
}
#endif
}
Loading