Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 67 additions & 36 deletions clients/dotnet/WebClient/MemoryWebClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,39 @@ namespace Microsoft.KernelMemory;

#pragma warning disable CA2234 // using string URIs is ok

/// <summary>
/// Kernel Memory web service client
/// </summary>
public sealed class MemoryWebClient : IKernelMemory
{
private static readonly JsonSerializerOptions s_caseInsensitiveJsonOptions = new() { PropertyNameCaseInsensitive = true };

private readonly HttpClient _client;

/// <summary>
/// New instance of web client to use Kernel Memory web service
/// </summary>
/// <param name="endpoint">Kernel Memory web service endpoint</param>
/// <param name="apiKey">Kernel Memory web service API Key (if configured)</param>
/// <param name="apiKeyHeader">Name of HTTP header to use to send API Key</param>
public MemoryWebClient(string endpoint, string? apiKey = "", string apiKeyHeader = "Authorization")
: this(endpoint, new HttpClient(), apiKey: apiKey, apiKeyHeader: apiKeyHeader)
{
}

/// <summary>
/// New instance of web client to use Kernel Memory web service
/// </summary>
/// <param name="endpoint">Kernel Memory web service endpoint</param>
/// <param name="client">Custom HTTP Client to use (note: BaseAddress is overwritten)</param>
/// <param name="apiKey">Kernel Memory web service API Key (if configured)</param>
/// <param name="apiKeyHeader">Name of HTTP header to use to send API Key</param>
public MemoryWebClient(string endpoint, HttpClient client, string? apiKey = "", string apiKeyHeader = "Authorization")
{
ArgumentNullExceptionEx.ThrowIfNullOrWhiteSpace(endpoint, nameof(endpoint), "Kernel Memory endpoint is empty");

this._client = client;
this._client.BaseAddress = new Uri(endpoint);
this._client.BaseAddress = new Uri(endpoint.CleanBaseAddress());

if (!string.IsNullOrEmpty(apiKey))
{
Expand Down Expand Up @@ -101,16 +119,19 @@ public async Task<string> ImportTextAsync(
IEnumerable<string>? steps = null,
CancellationToken cancellationToken = default)
{
using Stream content = new MemoryStream(Encoding.UTF8.GetBytes(text));
return await this.ImportDocumentAsync(
content,
fileName: "content.txt",
documentId: documentId,
tags,
index: index,
steps: steps,
cancellationToken)
.ConfigureAwait(false);
Stream content = new MemoryStream(Encoding.UTF8.GetBytes(text));
await using (content.ConfigureAwait(false))
{
return await this.ImportDocumentAsync(
content,
fileName: "content.txt",
documentId: documentId,
tags,
index: index,
steps: steps,
cancellationToken)
.ConfigureAwait(false);
}
}

/// <inheritdoc />
Expand All @@ -125,23 +146,26 @@ public async Task<string> ImportWebPageAsync(
var uri = new Uri(url);
Verify.ValidateUrl(uri.AbsoluteUri, requireHttps: false, allowReservedIp: false, allowQuery: true);

using Stream content = new MemoryStream(Encoding.UTF8.GetBytes(uri.AbsoluteUri));
return await this.ImportDocumentAsync(
content,
fileName: "content.url",
documentId: documentId,
tags,
index: index,
steps: steps,
cancellationToken)
.ConfigureAwait(false);
Stream content = new MemoryStream(Encoding.UTF8.GetBytes(uri.AbsoluteUri));
await using (content.ConfigureAwait(false))
{
return await this.ImportDocumentAsync(
content,
fileName: "content.url",
documentId: documentId,
tags,
index: index,
steps: steps,
cancellationToken)
.ConfigureAwait(false);
}
}

/// <inheritdoc />
public async Task<IEnumerable<IndexDetails>> ListIndexesAsync(CancellationToken cancellationToken = default)
{
const string URL = Constants.HttpIndexesEndpoint;
HttpResponseMessage? response = await this._client.GetAsync(URL, cancellationToken).ConfigureAwait(false);
var url = Constants.HttpIndexesEndpoint.CleanUrlPath();
HttpResponseMessage response = await this._client.GetAsync(url, cancellationToken).ConfigureAwait(false);

response.EnsureSuccessStatusCode();

Expand All @@ -154,8 +178,10 @@ public async Task<IEnumerable<IndexDetails>> ListIndexesAsync(CancellationToken
/// <inheritdoc />
public async Task DeleteIndexAsync(string? index = null, CancellationToken cancellationToken = default)
{
var url = Constants.HttpDeleteIndexEndpointWithParams.Replace(Constants.HttpIndexPlaceholder, index, StringComparison.OrdinalIgnoreCase);
HttpResponseMessage? response = await this._client.DeleteAsync(url, cancellationToken).ConfigureAwait(false);
var url = Constants.HttpDeleteIndexEndpointWithParams
.Replace(Constants.HttpIndexPlaceholder, index, StringComparison.OrdinalIgnoreCase)
.CleanUrlPath();
HttpResponseMessage response = await this._client.DeleteAsync(url, cancellationToken).ConfigureAwait(false);

// No error if the index doesn't exist
if (response.StatusCode == HttpStatusCode.NotFound)
Expand Down Expand Up @@ -183,8 +209,9 @@ public async Task DeleteDocumentAsync(string documentId, string? index = null, C

var url = Constants.HttpDeleteDocumentEndpointWithParams
.Replace(Constants.HttpIndexPlaceholder, index, StringComparison.OrdinalIgnoreCase)
.Replace(Constants.HttpDocumentIdPlaceholder, documentId, StringComparison.OrdinalIgnoreCase);
HttpResponseMessage? response = await this._client.DeleteAsync(url, cancellationToken).ConfigureAwait(false);
.Replace(Constants.HttpDocumentIdPlaceholder, documentId, StringComparison.OrdinalIgnoreCase)
.CleanUrlPath();
HttpResponseMessage response = await this._client.DeleteAsync(url, cancellationToken).ConfigureAwait(false);

// No error if the document doesn't exist
if (response.StatusCode == HttpStatusCode.NotFound)
Expand Down Expand Up @@ -220,7 +247,8 @@ public async Task<bool> IsDocumentReadyAsync(
{
var url = Constants.HttpUploadStatusEndpointWithParams
.Replace(Constants.HttpIndexPlaceholder, index, StringComparison.OrdinalIgnoreCase)
.Replace(Constants.HttpDocumentIdPlaceholder, documentId, StringComparison.OrdinalIgnoreCase);
.Replace(Constants.HttpDocumentIdPlaceholder, documentId, StringComparison.OrdinalIgnoreCase)
.CleanUrlPath();
HttpResponseMessage? response = await this._client.GetAsync(url, cancellationToken).ConfigureAwait(false);
if (response.StatusCode == HttpStatusCode.NotFound)
{
Expand All @@ -242,12 +270,13 @@ public async Task<StreamableFileContent> ExportFileAsync(
string? index = null,
CancellationToken cancellationToken = default)
{
string requestUri = Constants.HttpDownloadEndpointWithParams
var url = Constants.HttpDownloadEndpointWithParams
.Replace(Constants.HttpIndexPlaceholder, index, StringComparison.OrdinalIgnoreCase)
.Replace(Constants.HttpDocumentIdPlaceholder, documentId, StringComparison.OrdinalIgnoreCase)
.Replace(Constants.HttpFilenamePlaceholder, fileName, StringComparison.OrdinalIgnoreCase);
.Replace(Constants.HttpFilenamePlaceholder, fileName, StringComparison.OrdinalIgnoreCase)
.CleanUrlPath();

HttpResponseMessage httpResponse = await this._client.GetAsync(requestUri, cancellationToken).ConfigureAwait(false);
HttpResponseMessage httpResponse = await this._client.GetAsync(url, cancellationToken).ConfigureAwait(false);
ArgumentNullExceptionEx.ThrowIfNull(httpResponse, nameof(httpResponse), "KernelMemory HTTP response is NULL");

httpResponse.EnsureSuccessStatusCode();
Expand Down Expand Up @@ -288,7 +317,8 @@ public async Task<SearchResult> SearchAsync(
};
using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json");

HttpResponseMessage? response = await this._client.PostAsync(Constants.HttpSearchEndpoint, content, cancellationToken).ConfigureAwait(false);
var url = Constants.HttpSearchEndpoint.CleanUrlPath();
HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();

var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -320,7 +350,8 @@ public async Task<MemoryAnswer> AskAsync(
};
using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json");

HttpResponseMessage? response = await this._client.PostAsync(Constants.HttpAskEndpoint, content, cancellationToken).ConfigureAwait(false);
var url = Constants.HttpAskEndpoint.CleanUrlPath();
HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();

var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -382,7 +413,7 @@ private async Task<string> ImportInternalAsync(
using StringContent indexContent = new(index);
using (StringContent documentIdContent = new(uploadRequest.DocumentId))
{
List<IDisposable> disposables = new();
List<IDisposable> disposables = [];
formData.Add(indexContent, Constants.WebServiceIndexField);
formData.Add(documentIdContent, Constants.WebServiceDocumentIdField);

Expand Down Expand Up @@ -422,7 +453,8 @@ private async Task<string> ImportInternalAsync(
// Send HTTP request
try
{
HttpResponseMessage? response = await this._client.PostAsync("/upload", formData, cancellationToken).ConfigureAwait(false);
var url = Constants.HttpUploadEndpoint.CleanUrlPath();
HttpResponseMessage response = await this._client.PostAsync(url, formData, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
}
catch (HttpRequestException e) when (e.Data.Contains("StatusCode"))
Expand All @@ -435,7 +467,6 @@ private async Task<string> ImportInternalAsync(
}
finally
{
formData.Dispose();
foreach (var disposable in disposables)
{
disposable.Dispose();
Expand Down
20 changes: 20 additions & 0 deletions clients/dotnet/WebClient/StringExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.KernelMemory;

internal static class StringExtensions
{
public static string CleanBaseAddress(this string endpoint)
{
ArgumentNullExceptionEx.ThrowIfNull(endpoint, nameof(endpoint), "Kernel Memory API endpoint is NULL");

return endpoint.TrimEnd('/') + '/';
}

public static string CleanUrlPath(this string path)
{
if (string.IsNullOrWhiteSpace(path)) { path = "/"; }

return path.TrimStart('/');
}
}