Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 21 additions & 26 deletions extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.RegularExpressions;
Expand All @@ -30,8 +29,7 @@ namespace Microsoft.KernelMemory.MemoryDb.AzureAISearch;
/// * support custom schema
/// * support custom Azure AI Search logic
/// </summary>
[Experimental("KMEXP03")]
public sealed class AzureAISearchMemory : IMemoryDb
public class AzureAISearchMemory : IMemoryDb, IMemoryDbBatchUpsert
{
private readonly ITextEmbeddingGenerator _embeddingGenerator;
private readonly ILogger<AzureAISearchMemory> _log;
Expand Down Expand Up @@ -101,6 +99,8 @@ public AzureAISearchMemory(
/// <inheritdoc />
public Task CreateIndexAsync(string index, int vectorSize, CancellationToken cancellationToken = default)
{
// Vectors cannot be less than 2 - TODO: use different index schema
vectorSize = Math.Max(2, vectorSize);
return this.CreateIndexAsync(index, AzureAISearchMemoryRecord.GetSchema(vectorSize), cancellationToken);
}

Expand All @@ -126,14 +126,25 @@ public Task DeleteIndexAsync(string index, CancellationToken cancellationToken =

/// <inheritdoc />
public async Task<string> UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default)
{
var result = this.BatchUpsertAsync(index, new[] { record }, cancellationToken);
var id = await result.SingleAsync(cancellationToken).ConfigureAwait(false);
return id;
}

/// <inheritdoc />
public async IAsyncEnumerable<string> BatchUpsertAsync(
string index,
IEnumerable<MemoryRecord> records,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var client = this.GetSearchClient(index);
AzureAISearchMemoryRecord localRecord = AzureAISearchMemoryRecord.FromMemoryRecord(record);
var localRecords = records.Select(AzureAISearchMemoryRecord.FromMemoryRecord);

try
{
await client.IndexDocumentsAsync(
IndexDocumentsBatch.Upload(new[] { localRecord }),
IndexDocumentsBatch.Upload(localRecords),
new IndexDocumentsOptions { ThrowOnAnyError = true },
cancellationToken: cancellationToken).ConfigureAwait(false);
}
Expand All @@ -142,7 +153,10 @@ await client.IndexDocumentsAsync(
throw new IndexNotFound(e.Message, e);
}

return record.Id;
foreach (var record in records)
{
yield return record.Id;
}
}

/// <inheritdoc />
Expand Down Expand Up @@ -354,25 +368,6 @@ private async Task<bool> DoesIndexExistAsync(string index, CancellationToken can
return false;
}

private async IAsyncEnumerable<string> UpsertBatchAsync(
string index,
IEnumerable<MemoryRecord> records,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var client = this.GetSearchClient(index);

foreach (MemoryRecord record in records)
{
var localRecord = AzureAISearchMemoryRecord.FromMemoryRecord(record);
await client.IndexDocumentsAsync(
IndexDocumentsBatch.Upload(new[] { localRecord }),
new IndexDocumentsOptions { ThrowOnAnyError = true },
cancellationToken: cancellationToken).ConfigureAwait(false);

yield return record.Id;
}
}

/// <summary>
/// Index names cannot contain special chars. We use this rule to replace a few common ones
/// with an underscore and reduce the chance of errors. If other special chars are used, we leave it
Expand Down Expand Up @@ -625,7 +620,7 @@ at Azure.Search.Documents.SearchClient.SearchInternal[T](SearchOptions options,

private static double ScoreToCosineSimilarity(double score)
{
return 2 - 1 / score;
return 2 - (1 / score);
}

private static double CosineSimilarityToScore(double similarity)
Expand Down
6 changes: 6 additions & 0 deletions service/Core/Configuration/KernelMemoryConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ public class DistributedOrchestrationConfig
/// </summary>
public List<string> MemoryDbTypes { get; set; } = new();

/// <summary>
/// How many memory DB records to insert at once when extracting memories
/// from uploaded documents (used only if the Memory Db supports batching).
/// </summary>
public int MemoryDbUpsertBatchSize { get; set; } = 1;

/// <summary>
/// The OCR service used to recognize text in images.
/// </summary>
Expand Down
Loading