@@ -401,16 +401,14 @@ private static TryCatch<List<IQueryPipelineStage>> CreateQueryPipelineStages(
401401 ITrace trace ,
402402 CancellationToken cancellationToken )
403403 {
404- // Collate results should return an IEnumerable<HybridSearchQueryResult>
405404 // Sort and coalesce the results on _rid
406405 // After sorting, each HybridSearchQueryResult has a fixed index in the list
407406 // This index can be used as the key for the ranking array
408407 // Now create an array (per dimension) of tuples (score, index) and sort it by score
409- // The index of the tuple in the sorted array is the rank of the document
410- // Create an array of tuples of ranks for each dimension
408+ // We can use these sorted arrays to compute the ranks. Identical scores get the same rank
411409 // Create an array of tuples of (RRF scores, index) for each document using the ranks
410+ // Use the ranks array to compute the RRF scores
412411 // Sort the array by RRF scores
413- // Emit the documents in the sorted order by using the index in the tuple
414412
415413 TryCatch < ( List < HybridSearchQueryResult > queryResults , QueryPage emptyPage ) > tryGetResults = await PrefetchInParallelAsync (
416414 queryPipelineStages ,
@@ -433,7 +431,7 @@ private static TryCatch<List<IQueryPipelineStage>> CreateQueryPipelineStages(
433431
434432 queryResults . Sort ( ( x , y ) => string . CompareOrdinal ( x . Rid . Value , y . Rid . Value ) ) ;
435433
436- UniqueRids ( queryResults ) ;
434+ CoalesceDuplicateRids ( queryResults ) ;
437435
438436 TryCatch < IReadOnlyList < List < ScoreTuple > > > tryGetComponentScores = RetrieveComponentScores ( queryResults , queryPipelineStages . Count ) ;
439437 if ( tryGetComponentScores . Failed )
@@ -450,7 +448,7 @@ private static TryCatch<List<IQueryPipelineStage>> CreateQueryPipelineStages(
450448
451449 int [ , ] ranks = ComputeRanks ( componentScores ) ;
452450
453- ComputeRRFScores ( ranks , queryResults ) ;
451+ ComputeRrfScores ( ranks , queryResults ) ;
454452
455453 HybridSearchDebugTraceHelpers . TraceQueryResultsWithRanks ( queryResults , ranks ) ;
456454
@@ -475,26 +473,39 @@ private static TryCatch<List<IQueryPipelineStage>> CreateQueryPipelineStages(
475473
476474 await ParallelPrefetch . PrefetchInParallelAsync ( prefetchers , maxConcurrency , trace , cancellationToken ) ;
477475
478- double requestCharge = 0 ;
476+ int queryResultCount = 0 ;
477+ List < IReadOnlyList < QueryPage > > prefetchedPageLists = new List < IReadOnlyList < QueryPage > > ( prefetchers . Count ) ;
479478 QueryPageParameters queryPageParameters = null ;
480- List < HybridSearchQueryResult > queryResults = new List < HybridSearchQueryResult > ( ) ;
481479 foreach ( QueryPipelineStagePrefetcher prefetcher in prefetchers )
482480 {
483- TryCatch < IReadOnlyList < QueryPage > > tryGetResults = await prefetcher . GetResultAsync ( trace , cancellationToken ) ;
481+ TryCatch < ( IReadOnlyList < QueryPage > , int ) > tryGetResults = await prefetcher . GetResultAsync ( trace , cancellationToken ) ;
484482 if ( tryGetResults . Failed )
485483 {
486484 return TryCatch < ( List < HybridSearchQueryResult > queryResults , QueryPage emptyPage ) > . FromException ( tryGetResults . Exception ) ;
487485 }
488486
489- foreach ( QueryPage queryPage in tryGetResults . Result )
487+ ( IReadOnlyList < QueryPage > queryPages , int documentCount ) = tryGetResults . Result ;
488+ prefetchedPageLists . Add ( queryPages ) ;
489+ queryResultCount += documentCount ;
490+
491+ if ( queryPageParameters == null && queryPages . Count > 0 )
490492 {
491- requestCharge += queryPage . RequestCharge ;
492- queryPageParameters ?? = new QueryPageParameters (
493+ QueryPage queryPage = queryPages [ 0 ] ;
494+ queryPageParameters = new QueryPageParameters (
493495 activityId : queryPage . ActivityId ,
494496 cosmosQueryExecutionInfo : queryPage . CosmosQueryExecutionInfo ,
495497 distributionPlanSpec : queryPage . DistributionPlanSpec ,
496498 additionalHeaders : queryPage . AdditionalHeaders ) ;
499+ }
500+ }
497501
502+ List < HybridSearchQueryResult > queryResults = new List < HybridSearchQueryResult > ( queryResultCount ) ;
503+ double requestCharge = 0 ;
504+ foreach ( IReadOnlyList < QueryPage > queryPages in prefetchedPageLists )
505+ {
506+ foreach ( QueryPage queryPage in queryPages )
507+ {
508+ requestCharge += queryPage . RequestCharge ;
498509 foreach ( CosmosElement document in queryPage . Documents )
499510 {
500511 HybridSearchQueryResult hybridSearchQueryResult = HybridSearchQueryResult . Create ( document ) ;
@@ -517,7 +528,7 @@ private static TryCatch<List<IQueryPipelineStage>> CreateQueryPipelineStages(
517528 return TryCatch < ( List < HybridSearchQueryResult > queryResults , QueryPage emptyPage ) > . FromResult ( ( queryResults , emptyPage ) ) ;
518529 }
519530
520- private static void UniqueRids ( List < HybridSearchQueryResult > queryResults )
531+ private static void CoalesceDuplicateRids ( List < HybridSearchQueryResult > queryResults )
521532 {
522533 int writeIndex = 0 ;
523534 for ( int readIndex = 1 ; readIndex < queryResults . Count ; ++ readIndex )
@@ -581,7 +592,7 @@ private static TryCatch<IReadOnlyList<List<ScoreTuple>>> RetrieveComponentScores
581592 return ranks ;
582593 }
583594
584- private static void ComputeRRFScores (
595+ private static void ComputeRrfScores (
585596 int [ , ] ranks ,
586597 List < HybridSearchQueryResult > queryResults )
587598 {
@@ -729,7 +740,7 @@ private sealed class QueryPipelineStagePrefetcher : IPrefetcher
729740 {
730741 private readonly IQueryPipelineStage queryPipelineStage ;
731742
732- private TryCatch < IReadOnlyList < QueryPage > > result ;
743+ private TryCatch < ( IReadOnlyList < QueryPage > , int ) > result ;
733744
734745 private bool prefetched ;
735746
@@ -740,25 +751,27 @@ public QueryPipelineStagePrefetcher(IQueryPipelineStage queryPipelineStage)
740751
741752 public async ValueTask PrefetchAsync ( ITrace trace , CancellationToken cancellationToken )
742753 {
743- List < QueryPage > result = new List < QueryPage > ( ) ;
754+ int documentCount = 0 ;
755+ List < QueryPage > pages = new List < QueryPage > ( ) ;
744756 while ( await this . queryPipelineStage . MoveNextAsync ( trace , cancellationToken ) )
745757 {
746758 TryCatch < QueryPage > tryCatchQueryPage = this . queryPipelineStage . Current ;
747759 if ( tryCatchQueryPage . Failed )
748760 {
749- this . result = TryCatch < IReadOnlyList < QueryPage > > . FromException ( tryCatchQueryPage . Exception ) ;
761+ this . result = TryCatch < ( IReadOnlyList < QueryPage > , int ) > . FromException ( tryCatchQueryPage . Exception ) ;
750762 this . prefetched = true ;
751763 return ;
752764 }
753765
754- result . Add ( tryCatchQueryPage . Result ) ;
766+ pages . Add ( tryCatchQueryPage . Result ) ;
767+ documentCount += tryCatchQueryPage . Result . Documents . Count ;
755768 }
756769
757- this . result = TryCatch < IReadOnlyList < QueryPage > > . FromResult ( result ) ;
770+ this . result = TryCatch < ( IReadOnlyList < QueryPage > , int ) > . FromResult ( ( pages , documentCount ) ) ;
758771 this . prefetched = true ;
759772 }
760773
761- public async ValueTask < TryCatch < IReadOnlyList < QueryPage > > > GetResultAsync ( ITrace trace , CancellationToken cancellationToken )
774+ public async ValueTask < TryCatch < ( IReadOnlyList < QueryPage > , int ) > > GetResultAsync ( ITrace trace , CancellationToken cancellationToken )
762775 {
763776 if ( ! this . prefetched )
764777 {
0 commit comments