Skip to content

Commit

Permalink
Fix Faiss efficient filter exact search using byte vector datatype (#…
Browse files Browse the repository at this point in the history
…2165)

* Fix Faiss efficient filter exact search using byte vector datatype

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

---------

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda authored Sep 30, 2024
1 parent f16f225 commit 6d098cf
Show file tree
Hide file tree
Showing 11 changed files with 690 additions and 42 deletions.
21 changes: 19 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.KNNIterator;
import org.opensearch.knn.index.query.iterators.NestedByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedVectorIdsKNNIterator;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
Expand Down Expand Up @@ -111,21 +114,35 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea
if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedByteVectorIdsKNNIterator(
return new NestedBinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}
return new ByteVectorIdsKNNIterator(
return new BinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
);
}

if (VectorDataType.BYTE == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}
return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType);
}
final byte[] quantizedQueryVector;
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.iterators;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

/**
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162
*
* The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided
*/
public class BinaryVectorIdsKNNIterator implements KNNIterator {
protected final BitSetIterator bitSetIterator;
protected final byte[] queryVector;
protected final KNNBinaryVectorValues binaryVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public BinaryVectorIdsKNNIterator(
@Nullable final BitSet filterIdsBitSet,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType
) throws IOException {
this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryVectorValues = binaryVectorValues;
this.spaceType = spaceType;
// This cannot be moved inside nextDoc() method since it will break when we have nested field, where
// nextDoc should already be referring to next knnVectorValues
this.docId = getNextDocId();
}

public BinaryVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType)
throws IOException {
this(null, queryVector, binaryVectorValues, spaceType);
}

/**
* Advance to the next doc and update score value with score of the next doc.
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs
*
* @return next doc id
*/
@Override
public int nextDoc() throws IOException {

if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}
currentScore = computeScore();
int currentDocId = docId;
docId = getNextDocId();
return currentDocId;
}

@Override
public float score() {
return currentScore;
}

protected float computeScore() throws IOException {
final byte[] vector = binaryVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
}

protected int getNextDocId() throws IOException {
if (bitSetIterator == null) {
return binaryVectorValues.nextDoc();
}
int nextDocID = this.bitSetIterator.nextDoc();
// For filter case, advance vector values to corresponding doc id from filter bit set
if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) {
binaryVectorValues.advance(nextDocID);
}
return nextDocID;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.apache.lucene.util.BitSetIterator;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;

import java.io.IOException;

Expand All @@ -22,30 +22,30 @@
*/
public class ByteVectorIdsKNNIterator implements KNNIterator {
protected final BitSetIterator bitSetIterator;
protected final byte[] queryVector;
protected final KNNBinaryVectorValues binaryVectorValues;
protected final float[] queryVector;
protected final KNNByteVectorValues byteVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public ByteVectorIdsKNNIterator(
@Nullable final BitSet filterIdsBitSet,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final float[] queryVector,
final KNNByteVectorValues byteVectorValues,
final SpaceType spaceType
) throws IOException {
this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryVectorValues = binaryVectorValues;
this.byteVectorValues = byteVectorValues;
this.spaceType = spaceType;
// This cannot be moved inside nextDoc() method since it will break when we have nested field, where
// nextDoc should already be referring to next knnVectorValues
this.docId = getNextDocId();
}

public ByteVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType)
public ByteVectorIdsKNNIterator(final float[] queryVector, final KNNByteVectorValues byteVectorValues, final SpaceType spaceType)
throws IOException {
this(null, queryVector, binaryVectorValues, spaceType);
this(null, queryVector, byteVectorValues, spaceType);
}

/**
Expand All @@ -72,20 +72,30 @@ public float score() {
}

protected float computeScore() throws IOException {
final byte[] vector = binaryVectorValues.getVector();
final byte[] vector = byteVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);

// The query vector of Faiss byte vector is a Float array because ScalarQuantizer accepts it as float array.
// To compute the score between this query vector and each vector in KNNByteVectorValues we are casting this query vector into byte
// array directly.
// This is safe to do so because float query vector already has validated byte values. Do not reuse this direct cast at any other
// place.
final byte[] byteQueryVector = new byte[queryVector.length];
for (int i = 0; i < queryVector.length; i++) {
byteQueryVector[i] = (byte) queryVector[i];
}
return spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector);
}

protected int getNextDocId() throws IOException {
if (bitSetIterator == null) {
return binaryVectorValues.nextDoc();
return byteVectorValues.nextDoc();
}
int nextDocID = this.bitSetIterator.nextDoc();
// For filter case, advance vector values to corresponding doc id from filter bit set
if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) {
binaryVectorValues.advance(nextDocID);
byteVectorValues.advance(nextDocID);
}
return nextDocID;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.iterators;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

/**
* This iterator iterates filterIdsArray to scoreif filter is provided else it iterates over all docs.
* However, it dedupe docs per each parent doc
* of which ID is set in parentBitSet and only return best child doc with the highest score.
*/
public class NestedBinaryVectorIdsKNNIterator extends BinaryVectorIdsKNNIterator {
private final BitSet parentBitSet;

public NestedBinaryVectorIdsKNNIterator(
@Nullable final BitSet filterIdsArray,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

public NestedBinaryVectorIdsKNNIterator(
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
super(null, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

/**
* Advance to the next best child doc per parent and update score with the best score among child docs from the parent.
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs
*
* @return next best child doc id
*/
@Override
public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}

currentScore = Float.NEGATIVE_INFINITY;
int currentParent = parentBitSet.nextSetBit(docId);
int bestChild = -1;

// In order to traverse all children for given parent, we have to use docId < parentId, because,
// kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child
// and for doc id 5, there are three children. In that case knnVectorValues iterator will have [0, 2, 3, 4]
// and parentBitSet will have [1,5]
// Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId
while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
currentScore = score;
}
docId = getNextDocId();
}

return bestChild;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import org.apache.lucene.util.BitSet;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;

import java.io.IOException;

Expand All @@ -23,18 +23,18 @@ public class NestedByteVectorIdsKNNIterator extends ByteVectorIdsKNNIterator {

public NestedByteVectorIdsKNNIterator(
@Nullable final BitSet filterIdsArray,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final float[] queryVector,
final KNNByteVectorValues byteVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
super(filterIdsArray, queryVector, byteVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

public NestedByteVectorIdsKNNIterator(
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final float[] queryVector,
final KNNByteVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
Expand Down
Loading

0 comments on commit 6d098cf

Please sign in to comment.