Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ private static List<ChatMessageContentPart> ToOpenAIChatContent(IList<AIContent>

case DataContent dataContent when dataContent.MediaType.StartsWith("application/pdf", StringComparison.OrdinalIgnoreCase):
return ChatMessageContentPart.CreateFilePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, $"{Guid.NewGuid():N}.pdf");

case AIContent when content.RawRepresentation is ChatMessageContentPart rawContentPart:
return rawContentPart;
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using OpenAI.Responses;
using static Microsoft.Extensions.AI.OpenAIChatClient;

#pragma warning disable S907 // "goto" statement should not be used
#pragma warning disable S1067 // Expressions should not be too complex
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable S3604 // Member initializer values should not be redundant
Expand Down Expand Up @@ -87,12 +88,13 @@ public async Task<ChatResponse> GetResponseAsync(
// Convert and return the results.
ChatResponse response = new()
{
ResponseId = openAIResponse.Id,
ConversationId = openAIResponse.Id,
CreatedAt = openAIResponse.CreatedAt,
FinishReason = ToFinishReason(openAIResponse.IncompleteStatusDetails?.Reason),
Messages = [new(ChatRole.Assistant, [])],
ModelId = openAIResponse.Model,
RawRepresentation = openAIResponse,
ResponseId = openAIResponse.Id,
Usage = ToUsageDetails(openAIResponse),
};

Expand Down Expand Up @@ -125,12 +127,20 @@ public async Task<ChatResponse> GetResponseAsync(

case FunctionCallResponseItem functionCall:
response.FinishReason ??= ChatFinishReason.ToolCalls;
message.Contents.Add(
FunctionCallContent.CreateFromParsedArguments(
functionCall.FunctionArguments.ToMemory(),
functionCall.CallId,
functionCall.FunctionName,
static json => JsonSerializer.Deserialize(json.Span, ResponseClientJsonContext.Default.IDictionaryStringObject)!));
var fcc = FunctionCallContent.CreateFromParsedArguments(
functionCall.FunctionArguments.ToMemory(),
functionCall.CallId,
functionCall.FunctionName,
static json => JsonSerializer.Deserialize(json.Span, ResponseClientJsonContext.Default.IDictionaryStringObject)!);
fcc.RawRepresentation = outputItem;
message.Contents.Add(fcc);
break;

default:
message.Contents.Add(new()
{
RawRepresentation = outputItem,
});
break;
}
}
Expand Down Expand Up @@ -170,20 +180,21 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
createdAt = createdUpdate.Response.CreatedAt;
responseId = createdUpdate.Response.Id;
modelId = createdUpdate.Response.Model;
break;
goto default;

case StreamingResponseCompletedUpdate completedUpdate:
yield return new()
{
Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
CreatedAt = createdAt,
ResponseId = responseId,
ConversationId = responseId,
FinishReason =
ToFinishReason(completedUpdate.Response?.IncompleteStatusDetails?.Reason) ??
(functionCallInfos is not null ? ChatFinishReason.ToolCalls : ChatFinishReason.Stop),
Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
ConversationId = responseId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
RawRepresentation = streamingUpdate,
ResponseId = responseId,
Role = lastRole,
};
break;
Expand All @@ -200,23 +211,24 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
break;
}

break;
goto default;

case StreamingResponseOutputItemDoneUpdate outputItemDoneUpdate:
_ = outputIndexToMessages.Remove(outputItemDoneUpdate.OutputIndex);
break;
goto default;

case StreamingResponseOutputTextDeltaUpdate outputTextDeltaUpdate:
_ = outputIndexToMessages.TryGetValue(outputTextDeltaUpdate.OutputIndex, out MessageResponseItem? messageItem);
lastMessageId = messageItem?.Id;
lastRole = ToChatRole(messageItem?.Role);
yield return new ChatResponseUpdate(lastRole, outputTextDeltaUpdate.Delta)
{
ConversationId = responseId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
RawRepresentation = streamingUpdate,
ResponseId = responseId,
ConversationId = responseId,
};
break;

Expand All @@ -227,7 +239,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
_ = (callInfo.Arguments ??= new()).Append(functionCallArgumentsDeltaUpdate.Delta);
}

break;
goto default;
}

case StreamingResponseFunctionCallArgumentsDoneUpdate functionCallOutputDoneUpdate:
Expand All @@ -246,25 +258,23 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
lastRole = ChatRole.Assistant;
yield return new ChatResponseUpdate(lastRole, [fci])
{
ConversationId = responseId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
RawRepresentation = streamingUpdate,
ResponseId = responseId,
ConversationId = responseId,
};

break;
}

break;
goto default;
}

case StreamingResponseErrorUpdate errorUpdate:
yield return new ChatResponseUpdate
{
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
ResponseId = responseId,
Role = lastRole,
ConversationId = responseId,
Contents =
[
Expand All @@ -274,6 +284,12 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
Details = errorUpdate.Param,
}
],
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
RawRepresentation = streamingUpdate,
ResponseId = responseId,
Role = lastRole,
};
break;

Expand All @@ -283,12 +299,26 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
RawRepresentation = streamingUpdate,
ResponseId = responseId,
Role = lastRole,
ConversationId = responseId,
Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
};
break;

default:
yield return new ChatResponseUpdate
{
ConversationId = responseId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
RawRepresentation = streamingUpdate,
ResponseId = responseId,
Role = lastRole,
};
break;
}
}
}
Expand Down Expand Up @@ -487,6 +517,10 @@ private static IEnumerable<ResponseItem> ToOpenAIResponseItems(
callContent.Arguments,
AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object?>)))));
break;

case AIContent when item.RawRepresentation is ResponseItem rawRep:
yield return rawRep;
break;
}
}

Expand Down Expand Up @@ -530,11 +564,25 @@ private static List<AIContent> ToAIContents(IEnumerable<ResponseContentPart> con
switch (part.Kind)
{
case ResponseContentPartKind.OutputText:
results.Add(new TextContent(part.Text));
results.Add(new TextContent(part.Text)
{
RawRepresentation = part,
});
break;

case ResponseContentPartKind.Refusal:
results.Add(new ErrorContent(part.Refusal) { ErrorCode = nameof(ResponseContentPartKind.Refusal) });
results.Add(new ErrorContent(part.Refusal)
{
ErrorCode = nameof(ResponseContentPartKind.Refusal),
RawRepresentation = part,
});
break;

default:
results.Add(new()
{
RawRepresentation = part,
});
break;
}
}
Expand Down Expand Up @@ -570,6 +618,10 @@ private static List<ResponseContentPart> ToOpenAIResponsesContent(IList<AIConten
case ErrorContent errorContent when errorContent.ErrorCode == nameof(ResponseContentPartKind.Refusal):
parts.Add(ResponseContentPart.CreateRefusalPart(errorContent.Message));
break;

case AIContent when content.RawRepresentation is ResponseContentPart rawRep:
parts.Add(rawRep);
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,19 +264,24 @@ public async Task BasicRequestResponse_Streaming()
Assert.Equal("Hello! How can I assist you today?", string.Concat(updates.Select(u => u.Text)));

var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_741_892_091);
Assert.Equal(10, updates.Count);
Assert.Equal(17, updates.Count);

for (int i = 0; i < updates.Count; i++)
{
Assert.Equal("resp_67d329fbc87c81919f8952fe71dafc96029dabe3ee19bb77", updates[i].ResponseId);
Assert.Equal("resp_67d329fbc87c81919f8952fe71dafc96029dabe3ee19bb77", updates[i].ConversationId);
Assert.Equal(createdAt, updates[i].CreatedAt);
Assert.Equal("gpt-4o-mini-2024-07-18", updates[i].ModelId);
Assert.Equal(ChatRole.Assistant, updates[i].Role);
Assert.Null(updates[i].AdditionalProperties);
Assert.Equal(i == 10 ? 0 : 1, updates[i].Contents.Count);
Assert.Equal((i >= 4 && i <= 12) || i == 16 ? 1 : 0, updates[i].Contents.Count);
Assert.Equal(i < updates.Count - 1 ? null : ChatFinishReason.Stop, updates[i].FinishReason);
}

for (int i = 4; i < updates.Count; i++)
{
Assert.Equal(ChatRole.Assistant, updates[i].Role);
}

UsageContent usage = updates.SelectMany(u => u.Contents).OfType<UsageContent>().Single();
Assert.Equal(26, usage.Details.InputTokenCount);
Assert.Equal(10, usage.Details.OutputTokenCount);
Expand Down
Loading