Skip to content

Commit 4f629bb

Browse files
authored
Add DataContent.Base64Data (#6365)
1 parent 43ca269 commit 4f629bb

File tree

5 files changed

+95
-76
lines changed

5 files changed

+95
-76
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,22 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
#if NET
6+
using System.Buffers;
7+
using System.Buffers.Text;
8+
#endif
59
using System.Diagnostics;
610
using System.Diagnostics.CodeAnalysis;
11+
#if !NET
12+
using System.Runtime.InteropServices;
13+
#endif
714
using System.Text.Json.Serialization;
815
using Microsoft.Shared.Diagnostics;
916

1017
#pragma warning disable S3996 // URI properties should not be strings
1118
#pragma warning disable CA1054 // URI-like parameters should not be strings
1219
#pragma warning disable CA1056 // URI-like properties should not be strings
20+
#pragma warning disable CA1307 // Specify StringComparison for clarity
1321

1422
namespace Microsoft.Extensions.AI;
1523

@@ -70,39 +78,35 @@ public DataContent(Uri uri, string? mediaType = null)
7078
[JsonConstructor]
7179
public DataContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? mediaType = null)
7280
{
81+
// Store and validate the data URI.
7382
_uri = Throw.IfNullOrWhitespace(uri);
74-
7583
if (!uri.StartsWith(DataUriParser.Scheme, StringComparison.OrdinalIgnoreCase))
7684
{
7785
Throw.ArgumentException(nameof(uri), "The provided URI is not a data URI.");
7886
}
7987

88+
// Parse the data URI to extract the data and media type.
8089
_dataUri = DataUriParser.Parse(uri.AsMemory());
8190

91+
// Validate and store the media type.
92+
mediaType ??= _dataUri.MediaType;
8293
if (mediaType is null)
8394
{
84-
mediaType = _dataUri.MediaType;
85-
if (mediaType is null)
86-
{
87-
Throw.ArgumentNullException(nameof(mediaType), $"{nameof(uri)} did not contain a media type, and {nameof(mediaType)} was not provided.");
88-
}
89-
}
90-
else
91-
{
92-
if (mediaType != _dataUri.MediaType)
93-
{
94-
// If the data URI contains a media type that's different from a non-null media type
95-
// explicitly provided, prefer the one explicitly provided as an override.
96-
97-
// Extract the bytes from the data URI and null out the uri.
98-
// Then we'll lazily recreate it later if needed based on the updated media type.
99-
_data = _dataUri.ToByteArray();
100-
_dataUri = null;
101-
_uri = null;
102-
}
95+
Throw.ArgumentNullException(nameof(mediaType), $"{nameof(uri)} did not contain a media type, and {nameof(mediaType)} was not provided.");
10396
}
10497

10598
MediaType = DataUriParser.ThrowIfInvalidMediaType(mediaType);
99+
100+
if (!_dataUri.IsBase64 || mediaType != _dataUri.MediaType)
101+
{
102+
// In rare cases, the data URI may contain non-base64 data, in which case we
103+
// want to normalize it to base64. The supplied media type may also be different
104+
// from the one in the data URI. In either case, we extract the bytes from the data URI
105+
// and then throw away the uri; we'll recreate it lazily in the canonical form.
106+
_data = _dataUri.ToByteArray();
107+
_dataUri = null;
108+
_uri = null;
109+
}
106110
}
107111

108112
/// <summary>
@@ -134,9 +138,8 @@ public DataContent(ReadOnlyMemory<byte> data, string mediaType)
134138

135139
/// <summary>Gets the data URI for this <see cref="DataContent"/>.</summary>
136140
/// <remarks>
137-
/// The returned URI is always a valid URI string, even if the instance was constructed from a <see cref="ReadOnlyMemory{Byte}"/>
138-
/// or from a <see cref="System.Uri"/>. In the case of a <see cref="ReadOnlyMemory{T}"/>, this property returns a data URI containing
139-
/// that data.
141+
/// The returned URI is always a valid data URI string, even if the instance was constructed from a <see cref="ReadOnlyMemory{Byte}"/>
142+
/// or from a <see cref="System.Uri"/>.
140143
/// </remarks>
141144
[StringSyntax(StringSyntaxAttribute.Uri)]
142145
public string Uri
@@ -145,27 +148,26 @@ public string Uri
145148
{
146149
if (_uri is null)
147150
{
148-
if (_dataUri is null)
149-
{
150-
Debug.Assert(_data is not null, "Expected _data to be initialized.");
151-
_uri = string.Concat("data:", MediaType, ";base64,", Convert.ToBase64String(_data.GetValueOrDefault()
152-
#if NET
153-
.Span));
154-
#else
155-
.Span.ToArray()));
156-
#endif
157-
}
158-
else
159-
{
160-
_uri = _dataUri.IsBase64 ?
151+
Debug.Assert(_data is not null, "Expected _data to be initialized.");
152+
ReadOnlyMemory<byte> data = _data.GetValueOrDefault();
153+
161154
#if NET
162-
$"data:{MediaType};base64,{_dataUri.Data.Span}" :
163-
$"data:{MediaType};,{_dataUri.Data.Span}";
155+
char[] array = ArrayPool<char>.Shared.Rent(
156+
"data:".Length + MediaType.Length + ";base64,".Length + Base64.GetMaxEncodedToUtf8Length(data.Length));
157+
158+
bool wrote = array.AsSpan().TryWrite($"data:{MediaType};base64,", out int prefixLength);
159+
wrote |= Convert.TryToBase64Chars(data.Span, array.AsSpan(prefixLength), out int dataLength);
160+
Debug.Assert(wrote, "Expected to successfully write the data URI.");
161+
_uri = array.AsSpan(0, prefixLength + dataLength).ToString();
162+
163+
ArrayPool<char>.Shared.Return(array);
164164
#else
165-
$"data:{MediaType};base64,{_dataUri.Data}" :
166-
$"data:{MediaType};,{_dataUri.Data}";
165+
string base64 = MemoryMarshal.TryGetArray(data, out ArraySegment<byte> segment) ?
166+
Convert.ToBase64String(segment.Array!, segment.Offset, segment.Count) :
167+
Convert.ToBase64String(data.ToArray());
168+
169+
_uri = $"data:{MediaType};base64,{base64}";
167170
#endif
168-
}
169171
}
170172

171173
return _uri;
@@ -205,6 +207,20 @@ public ReadOnlyMemory<byte> Data
205207
}
206208
}
207209

210+
/// <summary>Gets the data represented by this instance as a Base64 character sequence.</summary>
211+
/// <returns>The base64 representation of the data.</returns>
212+
[JsonIgnore]
213+
public ReadOnlyMemory<char> Base64Data
214+
{
215+
get
216+
{
217+
string uri = Uri;
218+
int pos = uri.IndexOf(',');
219+
Debug.Assert(pos >= 0, "Expected comma to be present in the URI.");
220+
return uri.AsMemory(pos + 1);
221+
}
222+
}
223+
208224
/// <summary>Gets a string representing this instance to display in the debugger.</summary>
209225
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
210226
private string DebuggerDisplay

src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/AIContentExtensions.cs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,6 @@ internal static bool IsImageWithSupportedFormat(this AIContent content) =>
1313
(content is UriContent uriContent && IsSupportedImageFormat(uriContent.MediaType)) ||
1414
(content is DataContent dataContent && IsSupportedImageFormat(dataContent.MediaType));
1515

16-
internal static bool IsUriBase64Encoded(this DataContent dataContent)
17-
{
18-
ReadOnlyMemory<char> uri = dataContent.Uri.AsMemory();
19-
20-
int commaIndex = uri.Span.IndexOf(',');
21-
if (commaIndex == -1)
22-
{
23-
return false;
24-
}
25-
26-
ReadOnlyMemory<char> metadata = uri.Slice(0, commaIndex);
27-
28-
bool isBase64Encoded = metadata.Span.EndsWith(";base64".AsSpan(), StringComparison.OrdinalIgnoreCase);
29-
return isBase64Encoded;
30-
}
31-
3216
private static bool IsSupportedImageFormat(string mediaType)
3317
{
3418
// 'image/jpeg' is the official MIME type for JPEG. However, some systems recognize 'image/jpg' as well.

src/Libraries/Microsoft.Extensions.AI.Evaluation.Safety/ContentSafetyServicePayloadUtilities.cs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -343,25 +343,13 @@ IEnumerable<JsonObject> GetContents(ChatMessage message)
343343
}
344344
else if (content is DataContent dataContent && dataContent.HasTopLevelMediaType("image"))
345345
{
346-
string url;
347-
if (dataContent.IsUriBase64Encoded())
348-
{
349-
url = dataContent.Uri;
350-
}
351-
else
352-
{
353-
BinaryData imageBytes = BinaryData.FromBytes(dataContent.Data);
354-
string base64ImageData = Convert.ToBase64String(imageBytes.ToArray());
355-
url = $"data:{dataContent.MediaType};base64,{base64ImageData}";
356-
}
357-
358346
yield return new JsonObject
359347
{
360348
["type"] = "image_url",
361349
["image_url"] =
362350
new JsonObject
363351
{
364-
["url"] = url
352+
["url"] = dataContent.Uri
365353
}
366354
};
367355
}

src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -402,12 +402,7 @@ private IEnumerable<OllamaChatRequestMessage> ToOllamaChatRequestMessages(ChatMe
402402
if (item is DataContent dataContent && dataContent.HasTopLevelMediaType("image"))
403403
{
404404
IList<string> images = currentTextMessage?.Images ?? [];
405-
images.Add(Convert.ToBase64String(dataContent.Data
406-
#if NET
407-
.Span));
408-
#else
409-
.ToArray()));
410-
#endif
405+
images.Add(dataContent.Base64Data.ToString());
411406

412407
if (currentTextMessage is not null)
413408
{

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Text;
56
using System.Text.Json;
67
using Xunit;
78

@@ -66,21 +67,27 @@ public void Ctor_ValidMediaType_Roundtrips(string mediaType)
6667
{
6768
var content = new DataContent("data:image/png;base64,aGVsbG8=", mediaType);
6869
Assert.Equal(mediaType, content.MediaType);
70+
Assert.Equal("aGVsbG8=", content.Base64Data.ToString());
6971

7072
content = new DataContent("data:,", mediaType);
7173
Assert.Equal(mediaType, content.MediaType);
74+
Assert.Equal("", content.Base64Data.ToString());
7275

7376
content = new DataContent("data:text/plain,", mediaType);
7477
Assert.Equal(mediaType, content.MediaType);
78+
Assert.Equal("", content.Base64Data.ToString());
7579

7680
content = new DataContent(new Uri("data:text/plain,"), mediaType);
7781
Assert.Equal(mediaType, content.MediaType);
82+
Assert.Equal("", content.Base64Data.ToString());
7883

7984
content = new DataContent(new byte[] { 0, 1, 2 }, mediaType);
8085
Assert.Equal(mediaType, content.MediaType);
86+
Assert.Equal("AAEC", content.Base64Data.ToString());
8187

8288
content = new DataContent(content.Uri);
8389
Assert.Equal(mediaType, content.MediaType);
90+
Assert.Equal("AAEC", content.Base64Data.ToString());
8491
}
8592

8693
[Fact]
@@ -91,10 +98,12 @@ public void Ctor_NoMediaType_Roundtrips()
9198
content = new DataContent("data:image/png;base64,aGVsbG8=");
99
Assert.Equal("data:image/png;base64,aGVsbG8=", content.Uri);
93100
Assert.Equal("image/png", content.MediaType);
101+
Assert.Equal("aGVsbG8=", content.Base64Data.ToString());
94102

95103
content = new DataContent(new Uri("data:image/png;base64,aGVsbG8="));
104
Assert.Equal("data:image/png;base64,aGVsbG8=", content.Uri);
97105
Assert.Equal("image/png", content.MediaType);
106+
Assert.Equal("aGVsbG8=", content.Base64Data.ToString());
98107
}
99108

100109
[Fact]
@@ -128,6 +137,7 @@ public void Deserialize_MatchesExpectedData()
128137

129138
Assert.Equal("data:application/octet-stream;base64,AQIDBA==", content.Uri);
130139
Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.ToArray());
140+
Assert.Equal("AQIDBA==", content.Base64Data.ToString());
131141
Assert.Equal("application/octet-stream", content.MediaType);
132142

133143
// Uri referenced content-only
@@ -150,6 +160,7 @@ public void Deserialize_MatchesExpectedData()
150160

151161
Assert.Equal("data:audio/wav;base64,AQIDBA==", content.Uri);
152162
Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.ToArray());
163+
Assert.Equal("AQIDBA==", content.Base64Data.ToString());
153164
Assert.Equal("audio/wav", content.MediaType);
154165
Assert.Equal("value", content.AdditionalProperties!["key"]!.ToString());
155166
}
@@ -224,4 +235,29 @@ public void HasMediaTypePrefix_ReturnsFalse(string mediaType, string prefix)
224235
var content = new DataContent("data:application/octet-stream;base64,AQIDBA==", mediaType);
225236
Assert.False(content.HasTopLevelMediaType(prefix));
226237
}
238+
239+
[Fact]
240+
public void Data_Roundtrips()
241+
{
242+
Random rand = new(42);
243+
for (int length = 0; length < 100; length++)
244+
{
245+
byte[] data = new byte[length];
246+
rand.NextBytes(data);
247+
248+
var content = new DataContent(data, "application/octet-stream");
249+
Assert.Equal(data, content.Data.ToArray());
250+
Assert.Equal(Convert.ToBase64String(data), content.Base64Data.ToString());
251+
Assert.Equal($"data:application/octet-stream;base64,{Convert.ToBase64String(data)}", content.Uri);
252+
}
253+
}
254+
255+
[Fact]
256+
public void NonBase64Data_Normalized()
257+
{
258+
var content = new DataContent("data:text/plain,hello world");
259+
Assert.Equal("data:text/plain;base64,aGVsbG8gd29ybGQ=", content.Uri);
260+
Assert.Equal("aGVsbG8gd29ybGQ=", content.Base64Data.ToString());
261+
Assert.Equal("hello world", Encoding.ASCII.GetString(content.Data.ToArray()));
262+
}
227263
}

0 commit comments

Comments
 (0)