Skip to content

Commit

Permalink
utilize competitive iterator api to perform pruning
Browse files Browse the repository at this point in the history
Signed-off-by: bowenlan-amzn <bowenlan23@gmail.com>
  • Loading branch information
bowenlan-amzn committed May 24, 2024
1 parent a3cb39d commit c2775ac
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,19 @@
package org.opensearch.search.aggregations.metrics;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator;
Expand All @@ -59,6 +69,7 @@
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;

Expand Down Expand Up @@ -137,8 +148,10 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
// only use ordinals if they don't increase memory usage by more than 25%
if (ordinalsMemoryUsage < countsMemoryUsage / 4) {
ordinalsCollectorsUsed++;
return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
context, ctx, fieldContext, source);
// return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
// context, ctx, fieldContext, source);
return new CompetitiveCollector(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
source, ctx, context, fieldContext);
}
ordinalsCollectorsOverheadTooHigh++;
}
Expand Down Expand Up @@ -217,6 +230,105 @@ abstract static class Collector extends LeafBucketCollector implements Releasabl

}

private static class CompetitiveCollector extends Collector {

private final Collector delegate;
private final DisiPriorityQueue pq;

CompetitiveCollector(Collector delegate, ValuesSource.Bytes.WithOrdinals source, LeafReaderContext ctx,
SearchContext context, FieldContext fieldContext) throws IOException {
this.delegate = delegate;

final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
TermsEnum terms = ordinalValues.termsEnum();
Map<BytesRef, Scorer> postingMap = new HashMap<>();
while (terms.next() != null) {
BytesRef term = terms.term();

TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), term));
Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f);
Scorer scorer = subWeight.scorer(ctx);

postingMap.put(term, scorer);
}
this.pq = new DisiPriorityQueue(postingMap.size());
for (Map.Entry<BytesRef, Scorer> entry : postingMap.entrySet()) {
pq.add(new DisiWrapper(entry.getValue()));
}
}

@Override
public void close() {
delegate.close();
}

@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
delegate.collect(doc, owningBucketOrd);
}

@Override
public DocIdSetIterator competitiveIterator() throws IOException {
return new DisjunctionDISIWithPruning(pq);
}

@Override
public void postCollect() throws IOException {
delegate.postCollect();
}
}

private static class DisjunctionDISIWithPruning extends DocIdSetIterator {

final DisiPriorityQueue queue;

public DisjunctionDISIWithPruning(DisiPriorityQueue queue) {
this.queue = queue;
}

@Override
public int docID() {
return queue.top().doc;
}

@Override
public int nextDoc() throws IOException {
// don't expect this to be called
throw new UnsupportedOperationException();
}

@Override
public int advance(int target) throws IOException {
// more than advance to the next doc >= target
// we also do the pruning of current doc here

DisiWrapper top = queue.top();

// after collecting the doc, before advancing to target
// we can safely remove all the iterators that having this doc
if (top.doc != -1) {
int curTopDoc = top.doc;
do {
top.doc = top.approximation.advance(Integer.MAX_VALUE);
top = queue.updateTop();
} while (top.doc == curTopDoc);
}

if (top.doc >= target) return top.doc;
do {
top.doc = top.approximation.advance(target);
top = queue.updateTop();
} while (top.doc < target);
return top.doc;
}

@Override
public long cost() {
// don't expect this to be called
throw new UnsupportedOperationException();
}
}

/**
* Empty Collector for the Cardinality agg
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermsEnum;
Expand Down Expand Up @@ -52,7 +53,7 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector {
// this logic should be pluggable depending on the type of leaf bucket collector by CardinalityAggregator
TermsEnum terms = ordinalValues.termsEnum();
Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f);
Map<Long, Boolean> found = new HashMap<>();
Map<Long, Boolean> found = new HashMap<>(); // ord : found or not
List<Scorer> subScorers = new ArrayList<>();
while (terms.next() != null && !found.containsKey(terms.ord())) {
// TODO can we get rid of terms previously encountered in other segments?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ public void testDynamicPruningOrdinalCollector() throws IOException {

MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName);
final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName);
testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
testAggregation(aggregationBuilder,
new TermQuery(new Term(filterFieldName, "foo")),
iw -> {
iw.addDocument(asList(
new KeywordField(fieldName, "1", Field.Store.NO),
new KeywordField(fieldName, "2", Field.Store.NO),
Expand Down Expand Up @@ -142,10 +144,12 @@ public void testDynamicPruningOrdinalCollector() throws IOException {
new KeywordField(filterFieldName, "bar", Field.Store.NO),
new SortedSetDocValuesField(fieldName, new BytesRef("5"))
));
}, card -> {
assertEquals(3.0, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, fieldType);
},
card -> {
assertEquals(3.0, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
},
fieldType);
}

public void testNoMatchingField() throws IOException {
Expand Down

0 comments on commit c2775ac

Please sign in to comment.