Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ Improvements

* GITHUB#15453: Avoid unnecessary sorting and instantiations in readMapOfStrings. (Benjamin Lerer)

* GITHUB#15565: Introduce AllGroupHeadsCollectorManager to parallelize search when using AllGroupHeadsCollector. (Binlong Gao)

Optimizations
---------------------
* GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public abstract class AllGroupHeadsCollector<T> extends SimpleCollector {

private final GroupSelector<T> groupSelector;
protected final Sort sort;
protected final boolean fillSortValues;

protected final int[] reversed;
protected final int compIDXEnd;
Expand All @@ -62,15 +63,29 @@ public abstract class AllGroupHeadsCollector<T> extends SimpleCollector {
* @param <T> the group value type
*/
public static <T> AllGroupHeadsCollector<T> newCollector(GroupSelector<T> selector, Sort sort) {
return newCollector(selector, sort, false);
}

/**
* Create a new AllGroupHeadsCollector based on the type of within-group Sort required
*
* @param selector a GroupSelector to define the groups
* @param sort the within-group sort to use to choose the group head document
* @param fillSortValues whether to store sort values for merging across collectors
* @param <T> the group value type
*/
public static <T> AllGroupHeadsCollector<T> newCollector(
GroupSelector<T> selector, Sort sort, boolean fillSortValues) {
if (sort.equals(Sort.RELEVANCE)) {
return new ScoringGroupHeadsCollector<>(selector, sort);
return new ScoringGroupHeadsCollector<>(selector, sort, fillSortValues);
}
return new SortingGroupHeadsCollector<>(selector, sort);
return new SortingGroupHeadsCollector<>(selector, sort, fillSortValues);
}

private AllGroupHeadsCollector(GroupSelector<T> selector, Sort sort) {
private AllGroupHeadsCollector(GroupSelector<T> selector, Sort sort, boolean fillSortValues) {
this.groupSelector = selector;
this.sort = sort;
this.fillSortValues = fillSortValues;
this.reversed = new int[sort.getSort().length];
final SortField[] sortFields = sort.getSort();
for (int i = 0; i < sortFields.length; i++) {
Expand Down Expand Up @@ -126,6 +141,17 @@ protected Collection<? extends GroupHead<T>> getCollectedGroupHeads() {
return heads.values();
}

/**
* Returns the sort values for a given group.
*
* @param groupValue the group value
* @return the sort values, or null if not available
*/
public Object[] getSortValues(T groupValue) {
GroupHead<T> head = heads.get(groupValue);
return head != null ? head.getSortValues() : null;
}

@Override
public void collect(int doc) throws IOException {
groupSelector.advanceTo(doc);
Expand Down Expand Up @@ -232,40 +258,60 @@ protected void setNextReader(LeafReaderContext ctx) throws IOException {
* @throws IOException If I/O related errors occur
*/
protected abstract void updateDocHead(int doc) throws IOException;

/**
* Returns the sort values for this group head.
*
* @return the sort values, or null if not stored
*/
protected Object[] getSortValues() {
return null;
}
}

/** General implementation using a {@link FieldComparator} to select the group head */
private static class SortingGroupHeadsCollector<T> extends AllGroupHeadsCollector<T> {

protected SortingGroupHeadsCollector(GroupSelector<T> selector, Sort sort) {
super(selector, sort);
protected SortingGroupHeadsCollector(
GroupSelector<T> selector, Sort sort, boolean fillSortValues) {
super(selector, sort, fillSortValues);
}

@Override
protected GroupHead<T> newGroupHead(int doc, T value, LeafReaderContext ctx, Scorable scorer)
throws IOException {
return new SortingGroupHead<>(sort, value, doc, ctx, scorer);
return new SortingGroupHead<>(sort, value, doc, ctx, scorer, fillSortValues);
}
}

private static class SortingGroupHead<T> extends GroupHead<T> {

final FieldComparator[] comparators;
final LeafFieldComparator[] leafComparators;
final Object[] sortValues;

protected SortingGroupHead(
Sort sort, T groupValue, int doc, LeafReaderContext context, Scorable scorer)
Sort sort,
T groupValue,
int doc,
LeafReaderContext context,
Scorable scorer,
boolean fillSortValues)
throws IOException {
super(groupValue, doc, context.docBase);
final SortField[] sortFields = sort.getSort();
comparators = new FieldComparator[sortFields.length];
leafComparators = new LeafFieldComparator[sortFields.length];
sortValues = fillSortValues ? new Object[sortFields.length] : null;
for (int i = 0; i < sortFields.length; i++) {
comparators[i] = sortFields[i].getComparator(1, Pruning.NONE);
leafComparators[i] = comparators[i].getLeafComparator(context);
leafComparators[i].setScorer(scorer);
leafComparators[i].copy(0, doc);
leafComparators[i].setBottom(0);
if (fillSortValues) {
sortValues[i] = comparators[i].value(0);
}
}
}

Expand All @@ -291,38 +337,50 @@ public int compare(int compIDX, int doc) throws IOException {

@Override
public void updateDocHead(int doc) throws IOException {
for (LeafFieldComparator comparator : leafComparators) {
comparator.copy(0, doc);
comparator.setBottom(0);
for (int i = 0; i < leafComparators.length; i++) {
leafComparators[i].copy(0, doc);
leafComparators[i].setBottom(0);
if (sortValues != null) {
sortValues[i] = comparators[i].value(0);
}
}
this.doc = doc + docBase;
}

@Override
protected Object[] getSortValues() {
return sortValues;
}
}

/** Specialized implementation for sorting by score */
private static class ScoringGroupHeadsCollector<T> extends AllGroupHeadsCollector<T> {

protected ScoringGroupHeadsCollector(GroupSelector<T> selector, Sort sort) {
super(selector, sort);
protected ScoringGroupHeadsCollector(
GroupSelector<T> selector, Sort sort, boolean fillSortValues) {
super(selector, sort, fillSortValues);
}

@Override
protected GroupHead<T> newGroupHead(
int doc, T value, LeafReaderContext context, Scorable scorer) throws IOException {
return new ScoringGroupHead<>(scorer, value, doc, context.docBase);
return new ScoringGroupHead<>(scorer, value, doc, context.docBase, fillSortValues);
}
}

private static class ScoringGroupHead<T> extends GroupHead<T> {

private Scorable scorer;
private float topScore;
private final Object[] sortValues;

protected ScoringGroupHead(Scorable scorer, T groupValue, int doc, int docBase)
protected ScoringGroupHead(
Scorable scorer, T groupValue, int doc, int docBase, boolean fillSortValues)
throws IOException {
super(groupValue, doc, docBase);
this.scorer = scorer;
this.topScore = scorer.score();
this.sortValues = fillSortValues ? new Object[] {topScore} : null;
}

@Override
Expand All @@ -344,6 +402,14 @@ protected int compare(int compIDX, int doc) throws IOException {
@Override
protected void updateDocHead(int doc) throws IOException {
this.doc = doc + docBase;
if (sortValues != null) {
sortValues[0] = topScore;
}
}

@Override
protected Object[] getSortValues() {
return sortValues;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.grouping;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

/**
* A CollectorManager implementation for AllGroupHeadsCollector.
*
* @lucene.experimental
*/
public class AllGroupHeadsCollectorManager
implements CollectorManager<
AllGroupHeadsCollector<?>, AllGroupHeadsCollectorManager.GroupHeadsResult> {

/** Result wrapper that allows retrieving group heads as int[] or Bits. */
public static class GroupHeadsResult {
private final int[] groupHeads;

GroupHeadsResult(int[] groupHeads) {
this.groupHeads = groupHeads;
}

public int[] retrieveGroupHeads() {
return groupHeads;
}

public Bits retrieveGroupHeads(int maxDoc) {
FixedBitSet result = new FixedBitSet(maxDoc);
for (int docId : groupHeads) {
result.set(docId);
}
return result;
}
}

private static class GroupHeadWithValues {
int doc;
final Object[] sortValues;

GroupHeadWithValues(int doc, Object[] sortValues) {
this.doc = doc;
this.sortValues = sortValues;
}
}

private final String groupField;
private final ValueSource valueSource;
private final Map<Object, Object> valueSourceContext;
private final Sort sortWithinGroup;

/** Creates a new AllGroupHeadsCollectorManager for TermGroupSelector. */
public AllGroupHeadsCollectorManager(String groupField, Sort sortWithinGroup) {
this.groupField = groupField;
this.valueSource = null;
this.valueSourceContext = null;
this.sortWithinGroup = sortWithinGroup;
}

/** Creates a new AllGroupHeadsCollectorManager for ValueSourceGroupSelector. */
public AllGroupHeadsCollectorManager(
ValueSource valueSource, Map<Object, Object> valueSourceContext, Sort sortWithinGroup) {
this.groupField = null;
this.valueSource = valueSource;
this.valueSourceContext = valueSourceContext;
this.sortWithinGroup = sortWithinGroup;
}

@Override
public AllGroupHeadsCollector<?> newCollector() throws IOException {
GroupSelector<?> newGroupSelector;
if (groupField != null) {
newGroupSelector = new TermGroupSelector(groupField);
} else {
newGroupSelector = new ValueSourceGroupSelector(valueSource, valueSourceContext);
}

return AllGroupHeadsCollector.newCollector(newGroupSelector, sortWithinGroup, true);
}

@Override
public GroupHeadsResult reduce(Collection<AllGroupHeadsCollector<?>> collectors) {
if (collectors.isEmpty()) {
return new GroupHeadsResult(new int[0]);
}

if (collectors.size() == 1) {
return new GroupHeadsResult(collectors.iterator().next().retrieveGroupHeads());
}

Map<Object, GroupHeadWithValues> mergedHeads = new HashMap<>();
SortField[] sortFields = sortWithinGroup.getSort();

for (AllGroupHeadsCollector<?> collector : collectors) {
mergeCollectorHeads(collector, mergedHeads, sortFields);
}

return new GroupHeadsResult(mergedHeads.values().stream().mapToInt(h -> h.doc).toArray());
}

@SuppressWarnings("unchecked")
private <T> void mergeCollectorHeads(
AllGroupHeadsCollector<T> collector,
Map<Object, GroupHeadWithValues> mergedHeads,
SortField[] sortFields) {
Collection<AllGroupHeadsCollector.GroupHead<T>> heads =
(Collection<AllGroupHeadsCollector.GroupHead<T>>) collector.getCollectedGroupHeads();
for (AllGroupHeadsCollector.GroupHead<T> head : heads) {
Object[] sortValues = collector.getSortValues(head.groupValue);
GroupHeadWithValues existing = mergedHeads.get(head.groupValue);
if (existing == null) {
mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues));
} else if (sortValues != null && existing.sortValues != null) {
int cmp = compareValues(sortValues, existing.sortValues, sortFields);
if (cmp > 0 || (cmp == 0 && head.doc < existing.doc)) {
mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues));
}
}
}
}

@SuppressWarnings({"unchecked", "rawtypes"})
private int compareValues(Object[] values1, Object[] values2, SortField[] sortFields) {
for (int i = 0; i < sortFields.length; i++) {
int cmp = 0;
if (values1[i] == null) {
cmp = values2[i] == null ? 0 : -1;
} else if (values2[i] == null) {
cmp = 1;
} else if (values1[i] instanceof Comparable) {
cmp = ((Comparable) values1[i]).compareTo(values2[i]);
}
if (cmp != 0) {
// For SCORE type, natural order is descending (higher is better)
// For other types, natural order is ascending (lower is better)
// reverse=true flips the natural order
boolean naturalDescending = sortFields[i].getType() == SortField.Type.SCORE;
boolean wantDescending = naturalDescending != sortFields[i].getReverse();
return wantDescending ? cmp : -cmp;
}
}
return 0;
}
}
Loading