diff --git a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs index f686db781..6598f961d 100644 --- a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs +++ b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Text.RegularExpressions; @@ -30,8 +29,7 @@ namespace Microsoft.KernelMemory.MemoryDb.AzureAISearch; /// * support custom schema /// * support custom Azure AI Search logic /// -[Experimental("KMEXP03")] -public sealed class AzureAISearchMemory : IMemoryDb +public class AzureAISearchMemory : IMemoryDb, IMemoryDbBatchUpsert { private readonly ITextEmbeddingGenerator _embeddingGenerator; private readonly ILogger _log; @@ -101,6 +99,8 @@ public AzureAISearchMemory( /// public Task CreateIndexAsync(string index, int vectorSize, CancellationToken cancellationToken = default) { + // Vectors cannot be less than 2 - TODO: use different index schema + vectorSize = Math.Max(2, vectorSize); return this.CreateIndexAsync(index, AzureAISearchMemoryRecord.GetSchema(vectorSize), cancellationToken); } @@ -126,14 +126,25 @@ public Task DeleteIndexAsync(string index, CancellationToken cancellationToken = /// public async Task UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default) + { + var result = this.BatchUpsertAsync(index, new[] { record }, cancellationToken); + var id = await result.SingleAsync(cancellationToken).ConfigureAwait(false); + return id; + } + + /// + public async IAsyncEnumerable BatchUpsertAsync( + string index, + IEnumerable records, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var client = this.GetSearchClient(index); - AzureAISearchMemoryRecord localRecord = AzureAISearchMemoryRecord.FromMemoryRecord(record); + var localRecords = records.Select(AzureAISearchMemoryRecord.FromMemoryRecord); try { await client.IndexDocumentsAsync( - IndexDocumentsBatch.Upload(new[] { localRecord }), + IndexDocumentsBatch.Upload(localRecords), new IndexDocumentsOptions { ThrowOnAnyError = true }, cancellationToken: cancellationToken).ConfigureAwait(false); } @@ -142,7 +153,10 @@ await client.IndexDocumentsAsync( throw new IndexNotFound(e.Message, e); } - return record.Id; + foreach (var record in records) + { + yield return record.Id; + } } /// @@ -354,25 +368,6 @@ private async Task DoesIndexExistAsync(string index, CancellationToken can return false; } - private async IAsyncEnumerable UpsertBatchAsync( - string index, - IEnumerable records, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - var client = this.GetSearchClient(index); - - foreach (MemoryRecord record in records) - { - var localRecord = AzureAISearchMemoryRecord.FromMemoryRecord(record); - await client.IndexDocumentsAsync( - IndexDocumentsBatch.Upload(new[] { localRecord }), - new IndexDocumentsOptions { ThrowOnAnyError = true }, - cancellationToken: cancellationToken).ConfigureAwait(false); - - yield return record.Id; - } - } - /// /// Index names cannot contain special chars. We use this rule to replace a few common ones /// with an underscore and reduce the chance of errors. If other special chars are used, we leave it @@ -625,7 +620,7 @@ at Azure.Search.Documents.SearchClient.SearchInternal[T](SearchOptions options, private static double ScoreToCosineSimilarity(double score) { - return 2 - 1 / score; + return 2 - (1 / score); } private static double CosineSimilarityToScore(double similarity) diff --git a/service/Core/Configuration/KernelMemoryConfig.cs b/service/Core/Configuration/KernelMemoryConfig.cs index 559bbf517..a1373e075 100644 --- a/service/Core/Configuration/KernelMemoryConfig.cs +++ b/service/Core/Configuration/KernelMemoryConfig.cs @@ -54,6 +54,12 @@ public class DistributedOrchestrationConfig /// public List MemoryDbTypes { get; set; } = new(); + /// + /// How many memory DB records to insert at once when extracting memories + /// from uploaded documents (used only if the Memory Db supports batching). + /// + public int MemoryDbUpsertBatchSize { get; set; } = 1; + /// /// The OCR service used to recognize text in images. /// diff --git a/service/Core/Handlers/SaveRecordsHandler.cs b/service/Core/Handlers/SaveRecordsHandler.cs index 8a0c5400e..6d5f03e54 100644 --- a/service/Core/Handlers/SaveRecordsHandler.cs +++ b/service/Core/Handlers/SaveRecordsHandler.cs @@ -45,8 +45,12 @@ private static string GetRecordId(string documentId, string partId) private readonly IPipelineOrchestrator _orchestrator; private readonly List _memoryDbs; + private readonly List _memoryDbsWithSingleUpsert; + private readonly List _memoryDbsWithBatchUpsert; private readonly ILogger _log; private readonly bool _embeddingGenerationEnabled; + private readonly int _upsertBatchSize; + private readonly bool _usingBatchUpsert; /// public string StepName { get; } @@ -57,10 +61,12 @@ private static string GetRecordId(string documentId, string partId) /// /// Pipeline step for which the handler will be invoked /// Current orchestrator used by the pipeline, giving access to content and other helps. + /// Configuration settings /// Application logger public SaveRecordsHandler( string stepName, IPipelineOrchestrator orchestrator, + KernelMemoryConfig? config = null, ILogger? log = null) { this.StepName = stepName; @@ -70,6 +76,8 @@ public SaveRecordsHandler( this._orchestrator = orchestrator; this._memoryDbs = orchestrator.GetMemoryDbs(); + this._upsertBatchSize = (config ?? new KernelMemoryConfig()).DataIngestion.MemoryDbUpsertBatchSize; + if (this._memoryDbs.Count < 1) { this._log.LogError("Handler {0} NOT ready, no memory DB configured", stepName); @@ -78,6 +86,20 @@ public SaveRecordsHandler( { this._log.LogInformation("Handler {0} ready, {1} vector storages", stepName, this._memoryDbs.Count); } + + // Ideally we want to call MarkProcessedBy(this) after storing each memory record, to avoid unnecessary + // duplicate upserts in case of transient errors. However, if there's a DB supporting batch upserts + // this optimization is not available (without further refactoring, marking each file for each memory DB). + // Here we split the list of DBs in two lists, those supporting batching and those not, prioritizing + // the single upsert if possible, to have the best retry strategy when possible. + this._memoryDbsWithSingleUpsert = this._memoryDbs; + this._memoryDbsWithBatchUpsert = new List(); + if (this._upsertBatchSize > 1) + { + this._memoryDbsWithSingleUpsert = this._memoryDbs.Where(x => x is not IMemoryDbBatchUpsert).ToList(); + this._memoryDbsWithBatchUpsert = this._memoryDbs.Where(x => x is IMemoryDbBatchUpsert).ToList(); + this._usingBatchUpsert = this._memoryDbsWithBatchUpsert.Count > 0; + } } /// @@ -89,172 +111,187 @@ public SaveRecordsHandler( await this.DeletePreviousRecordsAsync(pipeline, cancellationToken).ConfigureAwait(false); pipeline.PreviousExecutionsToPurge = new List(); - return this._embeddingGenerationEnabled - ? await this.SaveEmbeddingsAsync(pipeline, cancellationToken).ConfigureAwait(false) - : await this.SavePartitionsAsync(pipeline, cancellationToken).ConfigureAwait(false); - } - - /// - /// Loop through all the EMBEDDINGS generated, creating a memory record for each one - /// - public async Task<(bool success, DataPipeline updatedPipeline)> SaveEmbeddingsAsync( - DataPipeline pipeline, CancellationToken cancellationToken = default) - { - var embeddingsFound = false; + var recordsFound = false; // TODO: replace with ConditionalWeakTable indexing on this._memoryDbs var createdIndexes = new HashSet(); - // For each embedding file => For each Memory DB => Upsert record - foreach (FileDetailsWithRecordId embeddingFile in GetListOfEmbeddingFiles(pipeline)) + // Case 1 (_embeddingGenerationEnabled = true): Loop through all the EMBEDDINGS generated, creating a memory record for each one + // Case 2 (_embeddingGenerationEnabled = false): Loop through all the PARTITIONS and SYNTHETIC chunks, creating a memory record for each one + var sourceFiles = this._embeddingGenerationEnabled + ? GetListOfEmbeddingFiles(pipeline).Chunk(this._upsertBatchSize) + : GetListOfPartitionAndSyntheticFiles(pipeline).Chunk(this._upsertBatchSize); + + foreach (FileDetailsWithRecordId[] files in sourceFiles) { - if (embeddingFile.File.AlreadyProcessedBy(this)) - { - embeddingsFound = true; - this._log.LogTrace("File {0} already processed by this handler", embeddingFile.File.Name); - continue; - } + if (files.Length == 0) { continue; } - string vectorJson = await this._orchestrator.ReadTextFileAsync(pipeline, embeddingFile.File.Name, cancellationToken).ConfigureAwait(false); - EmbeddingFileContent? embeddingData = JsonSerializer.Deserialize(vectorJson.RemoveBOM().Trim()); - if (embeddingData == null) + // List of records to upsert, used only when batching + var records = new List(); + foreach (FileDetailsWithRecordId file in files) { - throw new OrchestrationException($"Unable to deserialize embedding file {embeddingFile.File.Name}"); - } + if (file.File.AlreadyProcessedBy(this)) + { + recordsFound = true; + this._log.LogTrace("File {0} already processed by this handler", file.File.Name); + continue; + } - embeddingsFound = true; - - DataPipeline.FileDetails fileDetails = pipeline.GetFile(embeddingFile.File.ParentId); - string partitionContent = await this._orchestrator.ReadTextFileAsync(pipeline, embeddingData.SourceFileName, cancellationToken).ConfigureAwait(false); - string url = await this.GetSourceUrlAsync(pipeline, fileDetails, cancellationToken).ConfigureAwait(false); - - var record = PrepareRecord( - pipeline: pipeline, - recordId: embeddingFile.RecordId, - fileName: fileDetails.Name, - url: url, - fileId: embeddingFile.File.ParentId, - partitionFileId: embeddingFile.File.SourcePartitionId, - partitionContent: partitionContent, - partitionNumber: embeddingFile.File.PartitionNumber, - sectionNumber: embeddingFile.File.SectionNumber, - partitionEmbedding: embeddingData.Vector, - embeddingGeneratorProvider: embeddingData.GeneratorProvider, - embeddingGeneratorName: embeddingData.GeneratorName, - embeddingFile.File.Tags); - - foreach (IMemoryDb client in this._memoryDbs) - { - try + MemoryRecord record; + DataPipeline.FileDetails fileDetails = pipeline.GetFile(file.File.ParentId); + + // Get source URL (only for web pages) + string webPageUrl = await this.GetSourceUrlAsync(pipeline, fileDetails, cancellationToken).ConfigureAwait(false); + + if (this._embeddingGenerationEnabled) { - await this.CreateIndexOnceAsync(client, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken).ConfigureAwait(false); + recordsFound = true; + + // Read vector data from embedding file + string vectorJson = await this._orchestrator.ReadTextFileAsync(pipeline, file.File.Name, cancellationToken).ConfigureAwait(false); + EmbeddingFileContent? embeddingData = JsonSerializer.Deserialize(vectorJson.RemoveBOM().Trim()); + if (embeddingData == null) { throw new OrchestrationException($"Unable to deserialize embedding file {file.File.Name}"); } + + // Get text partition content + string partitionContent = await this._orchestrator.ReadTextFileAsync(pipeline, embeddingData.SourceFileName, cancellationToken).ConfigureAwait(false); - this._log.LogTrace("Saving record {0} in index '{1}'", record.Id, pipeline.Index); - await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + // Prepare record, including embedding details + record = PrepareRecord( + pipeline: pipeline, + recordId: file.RecordId, + fileName: fileDetails.Name, + url: webPageUrl, + fileId: file.File.ParentId, + partitionFileId: file.File.SourcePartitionId, + partitionContent: partitionContent, + partitionNumber: file.File.PartitionNumber, + sectionNumber: file.File.SectionNumber, + partitionEmbedding: embeddingData.Vector, + embeddingGeneratorProvider: embeddingData.GeneratorProvider, + embeddingGeneratorName: embeddingData.GeneratorName, + file.File.Tags); } - catch (IndexNotFound e) + else { - this._log.LogWarning(e, "Index {0} not found, attempting to create it", pipeline.Index); - await this.CreateIndexOnceAsync(client, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken, true).ConfigureAwait(false); + switch (file.File.MimeType) + { + case MimeTypes.PlainText: + case MimeTypes.MarkDown: + recordsFound = true; + + // Get text partition content + string partitionContent = await this._orchestrator.ReadTextFileAsync(pipeline, file.File.Name, cancellationToken).ConfigureAwait(false); + + // Prepare record, without embedding data + record = PrepareRecord( + pipeline: pipeline, + recordId: file.RecordId, + fileName: fileDetails.Name, + url: webPageUrl, + fileId: file.File.ParentId, + partitionFileId: file.File.Id, + partitionContent: partitionContent, + partitionNumber: fileDetails.PartitionNumber, + sectionNumber: fileDetails.SectionNumber, + partitionEmbedding: new Embedding(), + embeddingGeneratorProvider: "", + embeddingGeneratorName: "", + file.File.Tags); + break; + + default: + this._log.LogWarning("File {0} cannot be used to generate embedding, type not supported", file.File.Name); + // skip record + continue; + } + } + + records.Add(record); - this._log.LogTrace("Retry: Saving record {0} in index '{1}'", record.Id, pipeline.Index); - await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + foreach (IMemoryDb db in this._memoryDbsWithSingleUpsert) + { + await this.CreateIndexOnceAsync(db, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken).ConfigureAwait(false); + await this.SaveRecordAsync(pipeline, db, record, createdIndexes, cancellationToken).ConfigureAwait(false); } + + // If possible mark the file as processed now, so in case of retries it won't be processed again + if (!this._usingBatchUpsert) { file.File.MarkProcessedBy(this); } } - embeddingFile.File.MarkProcessedBy(this); + if (this._usingBatchUpsert) + { + if (records.Count > 0) + { + foreach (IMemoryDb db in this._memoryDbsWithBatchUpsert) + { + await this.CreateIndexOnceAsync(db, createdIndexes, pipeline.Index, records[0].Vector.Length, cancellationToken).ConfigureAwait(false); + await this.SaveRecordsBatchAsync(pipeline, db, records, createdIndexes, cancellationToken).ConfigureAwait(false); + } + } + + foreach (FileDetailsWithRecordId file in files) + { + file.File.MarkProcessedBy(this); + } + } } - if (!embeddingsFound) + if (!recordsFound) { - this._log.LogWarning("Pipeline '{0}/{1}': embeddings not found, cannot save embeddings, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId); + this._log.LogWarning("Pipeline '{0}/{1}': step {2}: no records found, cannot save, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId, this.StepName); } return (true, pipeline); } - /// - /// Loop through all the PARTITIONS and SYNTHETIC chunks, creating a memory record for each one - /// - public async Task<(bool success, DataPipeline updatedPipeline)> SavePartitionsAsync( - DataPipeline pipeline, CancellationToken cancellationToken = default) + private static IEnumerable GetListOfEmbeddingFiles(DataPipeline pipeline) { - var partitionsFound = false; + return pipeline.Files.SelectMany(f1 => f1.GeneratedFiles.Where( + f2 => f2.Value.ArtifactType == DataPipeline.ArtifactTypes.TextEmbeddingVector) + .Select(x => new FileDetailsWithRecordId(pipeline, x.Value))); + } - // TODO: replace with ConditionalWeakTable indexing on this._memoryDbs - var createdIndexes = new HashSet(); + private static IEnumerable GetListOfPartitionAndSyntheticFiles(DataPipeline pipeline) + { + return pipeline.Files.SelectMany(f1 => f1.GeneratedFiles.Where( + f2 => f2.Value.ArtifactType is DataPipeline.ArtifactTypes.TextPartition or DataPipeline.ArtifactTypes.SyntheticData) + .Select(x => new FileDetailsWithRecordId(pipeline, x.Value))); + } - // Create records only for partitions (text chunks) and synthetic data - foreach (FileDetailsWithRecordId file in GetListOfPartitionAndSyntheticFiles(pipeline)) + private async Task SaveRecordAsync(DataPipeline pipeline, IMemoryDb db, MemoryRecord record, HashSet createdIndexes, CancellationToken cancellationToken) + { + try { - if (file.File.AlreadyProcessedBy(this)) - { - partitionsFound = true; - this._log.LogTrace("File {0} already processed by this handler", file.File.Name); - continue; - } - - partitionsFound = true; - - switch (file.File.MimeType) - { - case MimeTypes.PlainText: - case MimeTypes.MarkDown: - - DataPipeline.FileDetails partitionFileDetails = pipeline.GetFile(file.File.ParentId); - string partitionContent = await this._orchestrator.ReadTextFileAsync(pipeline, file.File.Name, cancellationToken).ConfigureAwait(false); - string url = await this.GetSourceUrlAsync(pipeline, partitionFileDetails, cancellationToken).ConfigureAwait(false); - - var record = PrepareRecord( - pipeline: pipeline, - recordId: file.RecordId, - fileName: partitionFileDetails.Name, - url: url, - fileId: file.File.ParentId, - partitionFileId: file.File.Id, - partitionContent: partitionContent, - partitionNumber: partitionFileDetails.PartitionNumber, - sectionNumber: partitionFileDetails.SectionNumber, - partitionEmbedding: new Embedding(), - embeddingGeneratorProvider: "", - embeddingGeneratorName: "", - file.File.Tags); - - foreach (IMemoryDb client in this._memoryDbs) - { - try - { - await this.CreateIndexOnceAsync(client, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken).ConfigureAwait(false); - - this._log.LogTrace("Saving record {0} in index '{1}'", record.Id, pipeline.Index); - await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); - } - catch (IndexNotFound e) - { - this._log.LogWarning(e, "Index {0} not found, attempting to create it", pipeline.Index); - await this.CreateIndexOnceAsync(client, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken, true).ConfigureAwait(false); - - this._log.LogTrace("Retry: saving record {0} in index '{1}'", record.Id, pipeline.Index); - await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); - } - } - - break; - - default: - this._log.LogWarning("File {0} cannot be used to generate embedding, type not supported", file.File.Name); - continue; - } + this._log.LogTrace("Saving record {0} in index '{1}'", record.Id, pipeline.Index); + await db.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + } + catch (IndexNotFound e) + { + this._log.LogWarning(e, "Index {0} not found, attempting to create it", pipeline.Index); + await this.CreateIndexOnceAsync(db, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken, true).ConfigureAwait(false); - file.File.MarkProcessedBy(this); + this._log.LogTrace("Retry: saving record {0} in index '{1}'", record.Id, pipeline.Index); + await db.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); } + } - if (!partitionsFound) + private async Task SaveRecordsBatchAsync(DataPipeline pipeline, IMemoryDb db, List records, HashSet createdIndexes, CancellationToken cancellationToken) + { + var dbBatch = ((IMemoryDbBatchUpsert)db); + ArgumentNullExceptionEx.ThrowIfNull(dbBatch, nameof(dbBatch), $"{db.GetType().FullName} doesn't implement {nameof(IMemoryDbBatchUpsert)}"); + try { - this._log.LogWarning("Pipeline '{0}/{1}': partitions and synthetic records not found, cannot save, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId); + this._log.LogTrace("Saving batch of {0} records in index '{1}'", records.Count, pipeline.Index); + await dbBatch.BatchUpsertAsync(pipeline.Index, records, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); } + catch (IndexNotFound e) + { + this._log.LogWarning(e, "Index {0} not found, attempting to create it", pipeline.Index); + await this.CreateIndexOnceAsync(db, createdIndexes, pipeline.Index, records[0].Vector.Length, cancellationToken, true).ConfigureAwait(false); - return (true, pipeline); + this._log.LogTrace("Retry: Saving batch of {0} records in index '{1}'", records.Count, pipeline.Index); + await dbBatch.BatchUpsertAsync(pipeline.Index, records, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); + } } private async Task DeletePreviousRecordsAsync(DataPipeline pipeline, CancellationToken cancellationToken) @@ -285,20 +322,6 @@ private async Task DeletePreviousRecordsAsync(DataPipeline pipeline, Cancellatio } } - private static IEnumerable GetListOfEmbeddingFiles(DataPipeline pipeline) - { - return pipeline.Files.SelectMany(f1 => f1.GeneratedFiles.Where( - f2 => f2.Value.ArtifactType == DataPipeline.ArtifactTypes.TextEmbeddingVector) - .Select(x => new FileDetailsWithRecordId(pipeline, x.Value))); - } - - private static IEnumerable GetListOfPartitionAndSyntheticFiles(DataPipeline pipeline) - { - return pipeline.Files.SelectMany(f1 => f1.GeneratedFiles.Where( - f2 => f2.Value.ArtifactType == DataPipeline.ArtifactTypes.TextPartition || f2.Value.ArtifactType == DataPipeline.ArtifactTypes.SyntheticData) - .Select(x => new FileDetailsWithRecordId(pipeline, x.Value))); - } - private async Task CreateIndexOnceAsync( IMemoryDb client, HashSet createdIndexes, diff --git a/service/Service/appsettings.json b/service/Service/appsettings.json index 4a9c07526..9eb47ff8a 100644 --- a/service/Service/appsettings.json +++ b/service/Service/appsettings.json @@ -144,6 +144,9 @@ "MemoryDbTypes": [ "SimpleVectorDb" ], + // How many memory DB records to insert at once when extracting memories from + // uploaded documents (used only if the Memory Db supports batching). + "MemoryDbUpsertBatchSize": 1, // "None" or "AzureAIDocIntel" "ImageOcrType": "None", // Partitioning / Chunking settings