From 6d098cf5266160033eccc7ee2e90bd4e4c1c071c Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Mon, 30 Sep 2024 17:24:28 -0500 Subject: [PATCH] Fix Faiss efficient filter exact search using byte vector datatype (#2165) * Fix Faiss efficient filter exact search using byte vector datatype Signed-off-by: Naveen Tatikonda * Address Review Comments Signed-off-by: Naveen Tatikonda --------- Signed-off-by: Naveen Tatikonda --- .../knn/index/query/ExactSearcher.java | 21 ++- .../iterators/BinaryVectorIdsKNNIterator.java | 92 +++++++++++ .../iterators/ByteVectorIdsKNNIterator.java | 34 ++-- .../NestedBinaryVectorIdsKNNIterator.java | 77 +++++++++ .../NestedByteVectorIdsKNNIterator.java | 12 +- .../BinaryVectorIdsKNNIteratorTests.java | 97 +++++++++++ .../ByteVectorIdsKNNIteratorTests.java | 24 +-- ...NestedBinaryVectorIdsKNNIteratorTests.java | 91 ++++++++++ .../NestedByteVectorIdsKNNIteratorTests.java | 24 +-- .../knn/integ/FilteredSearchByteIT.java | 104 ++++++++++++ .../knn/integ/NestedSearchByteIT.java | 156 ++++++++++++++++++ 11 files changed, 690 insertions(+), 42 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java create mode 100644 src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java create mode 100644 src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java create mode 100644 src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java create mode 100644 src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 193cba8c1..8e5849abb 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ 
b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -20,12 +20,15 @@ import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; +import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.KNNIterator; import org.opensearch.knn.index.query.iterators.NestedByteVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.NestedVectorIdsKNNIterator; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; @@ -111,7 +114,7 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea if (VectorDataType.BINARY == knnQuery.getVectorDataType()) { final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); if (isNestedRequired) { - return new NestedByteVectorIdsKNNIterator( + return new NestedBinaryVectorIdsKNNIterator( matchedDocs, knnQuery.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, @@ -119,13 +122,27 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea knnQuery.getParentsFilter().getBitSet(leafReaderContext) ); } - return new ByteVectorIdsKNNIterator( + return new BinaryVectorIdsKNNIterator( matchedDocs, knnQuery.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, spaceType ); } + + if (VectorDataType.BYTE == knnQuery.getVectorDataType()) { + final KNNVectorValues vectorValues = 
KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + if (isNestedRequired) { + return new NestedByteVectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNByteVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } + return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType); + } final byte[] quantizedQueryVector; final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; if (exactSearcherContext.isUseQuantizedVectorsForSearch()) { diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java new file mode 100644 index 000000000..5bab5b573 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.opensearch.common.Nullable; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; + +/** + * Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene + * https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 + * + * The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided + */ +public class BinaryVectorIdsKNNIterator implements KNNIterator { + protected final BitSetIterator bitSetIterator; + protected final byte[] queryVector; + protected final KNNBinaryVectorValues binaryVectorValues; 
+ protected final SpaceType spaceType; + protected float currentScore = Float.NEGATIVE_INFINITY; + protected int docId; + + public BinaryVectorIdsKNNIterator( + @Nullable final BitSet filterIdsBitSet, + final byte[] queryVector, + final KNNBinaryVectorValues binaryVectorValues, + final SpaceType spaceType + ) throws IOException { + this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); + this.queryVector = queryVector; + this.binaryVectorValues = binaryVectorValues; + this.spaceType = spaceType; + // This cannot be moved inside nextDoc() method since it will break when we have nested field, where + // nextDoc should already be referring to next knnVectorValues + this.docId = getNextDocId(); + } + + public BinaryVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType) + throws IOException { + this(null, queryVector, binaryVectorValues, spaceType); + } + + /** + * Advance to the next doc and update score value with score of the next doc. + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next doc id + */ + @Override + public int nextDoc() throws IOException { + + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + currentScore = computeScore(); + int currentDocId = docId; + docId = getNextDocId(); + return currentDocId; + } + + @Override + public float score() { + return currentScore; + } + + protected float computeScore() throws IOException { + final byte[] vector = binaryVectorValues.getVector(); + // Calculates a similarity score between the two vectors with a specified function. Higher similarity + // scores correspond to closer vectors. 
+ return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + } + + protected int getNextDocId() throws IOException { + if (bitSetIterator == null) { + return binaryVectorValues.nextDoc(); + } + int nextDocID = this.bitSetIterator.nextDoc(); + // For filter case, advance vector values to corresponding doc id from filter bit set + if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) { + binaryVectorValues.advance(nextDocID); + } + return nextDocID; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java index b1aea4284..0e8005163 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java @@ -10,7 +10,7 @@ import org.apache.lucene.util.BitSetIterator; import org.opensearch.common.Nullable; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.io.IOException; @@ -22,30 +22,30 @@ */ public class ByteVectorIdsKNNIterator implements KNNIterator { protected final BitSetIterator bitSetIterator; - protected final byte[] queryVector; - protected final KNNBinaryVectorValues binaryVectorValues; + protected final float[] queryVector; + protected final KNNByteVectorValues byteVectorValues; protected final SpaceType spaceType; protected float currentScore = Float.NEGATIVE_INFINITY; protected int docId; public ByteVectorIdsKNNIterator( @Nullable final BitSet filterIdsBitSet, - final byte[] queryVector, - final KNNBinaryVectorValues binaryVectorValues, + final float[] queryVector, + final KNNByteVectorValues byteVectorValues, final SpaceType spaceType ) throws IOException { this.bitSetIterator = filterIdsBitSet == null ? 
null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); this.queryVector = queryVector; - this.binaryVectorValues = binaryVectorValues; + this.byteVectorValues = byteVectorValues; this.spaceType = spaceType; // This cannot be moved inside nextDoc() method since it will break when we have nested field, where // nextDoc should already be referring to next knnVectorValues this.docId = getNextDocId(); } - public ByteVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType) + public ByteVectorIdsKNNIterator(final float[] queryVector, final KNNByteVectorValues byteVectorValues, final SpaceType spaceType) throws IOException { - this(null, queryVector, binaryVectorValues, spaceType); + this(null, queryVector, byteVectorValues, spaceType); } /** @@ -72,20 +72,30 @@ public float score() { } protected float computeScore() throws IOException { - final byte[] vector = binaryVectorValues.getVector(); + final byte[] vector = byteVectorValues.getVector(); // Calculates a similarity score between the two vectors with a specified function. Higher similarity // scores correspond to closer vectors. - return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + + // The query vector of Faiss byte vector is a Float array because ScalarQuantizer accepts it as float array. + // To compute the score between this query vector and each vector in KNNByteVectorValues we are casting this query vector into byte + // array directly. + // This is safe to do so because float query vector already has validated byte values. Do not reuse this direct cast at any other + // place. 
+ final byte[] byteQueryVector = new byte[queryVector.length]; + for (int i = 0; i < queryVector.length; i++) { + byteQueryVector[i] = (byte) queryVector[i]; + } + return spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector); } protected int getNextDocId() throws IOException { if (bitSetIterator == null) { - return binaryVectorValues.nextDoc(); + return byteVectorValues.nextDoc(); } int nextDocID = this.bitSetIterator.nextDoc(); // For filter case, advance vector values to corresponding doc id from filter bit set if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) { - binaryVectorValues.advance(nextDocID); + byteVectorValues.advance(nextDocID); } return nextDocID; } diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java new file mode 100644 index 000000000..97bf3517e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.opensearch.common.Nullable; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; + +/** + * This iterator iterates filterIdsArray to score if a filter is provided; otherwise it iterates over all docs. + * However, it dedupes docs per parent doc + * whose ID is set in parentBitSet and only returns the best child doc with the highest score. 
+ */ +public class NestedBinaryVectorIdsKNNIterator extends BinaryVectorIdsKNNIterator { + private final BitSet parentBitSet; + + public NestedBinaryVectorIdsKNNIterator( + @Nullable final BitSet filterIdsArray, + final byte[] queryVector, + final KNNBinaryVectorValues binaryVectorValues, + final SpaceType spaceType, + final BitSet parentBitSet + ) throws IOException { + super(filterIdsArray, queryVector, binaryVectorValues, spaceType); + this.parentBitSet = parentBitSet; + } + + public NestedBinaryVectorIdsKNNIterator( + final byte[] queryVector, + final KNNBinaryVectorValues binaryVectorValues, + final SpaceType spaceType, + final BitSet parentBitSet + ) throws IOException { + super(null, queryVector, binaryVectorValues, spaceType); + this.parentBitSet = parentBitSet; + } + + /** + * Advance to the next best child doc per parent and update score with the best score among child docs from the parent. + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next best child doc id + */ + @Override + public int nextDoc() throws IOException { + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + + currentScore = Float.NEGATIVE_INFINITY; + int currentParent = parentBitSet.nextSetBit(docId); + int bestChild = -1; + + // In order to traverse all children for given parent, we have to use docId < parentId, because, + // kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child + // and for doc id 5, there are three children. 
In that case knnVectorValues iterator will have [0, 2, 3, 4] + // and parentBitSet will have [1,5] + // Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId + while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + float score = computeScore(); + if (score > currentScore) { + bestChild = docId; + currentScore = score; + } + docId = getNextDocId(); + } + + return bestChild; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java index 3c93ec888..9644b620f 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java @@ -9,7 +9,7 @@ import org.apache.lucene.util.BitSet; import org.opensearch.common.Nullable; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.io.IOException; @@ -23,18 +23,18 @@ public class NestedByteVectorIdsKNNIterator extends ByteVectorIdsKNNIterator { public NestedByteVectorIdsKNNIterator( @Nullable final BitSet filterIdsArray, - final byte[] queryVector, - final KNNBinaryVectorValues binaryVectorValues, + final float[] queryVector, + final KNNByteVectorValues byteVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) throws IOException { - super(filterIdsArray, queryVector, binaryVectorValues, spaceType); + super(filterIdsArray, queryVector, byteVectorValues, spaceType); this.parentBitSet = parentBitSet; } public NestedByteVectorIdsKNNIterator( - final byte[] queryVector, - final KNNBinaryVectorValues binaryVectorValues, + final float[] queryVector, + final KNNByteVectorValues binaryVectorValues, final SpaceType spaceType, final 
BitSet parentBitSet ) throws IOException { diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java new file mode 100644 index 000000000..6d5dffa98 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.FixedBitSet; +import org.mockito.stubbing.OngoingStubbing; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class BinaryVectorIdsKNNIteratorTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenCalled_IterateAllDocs() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 1, 2, 3 }; + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + FixedBitSet filterBitSet = new 
FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + BinaryVectorIdsKNNIterator iterator = new BinaryVectorIdsKNNIterator(filterBitSet, queryVector, values, spaceType); + for (int i = 0; i < filterIds.length; i++) { + assertEquals(filterIds[i], iterator.nextDoc()); + assertEquals(expectedScores.get(i), (Float) iterator.score()); + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } + + @SneakyThrows + public void testNextDoc_whenCalled_thenIterateAllDocsWithoutFilter() throws IOException { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final List dataVectors = Arrays.asList( + new byte[] { 11, 12, 13 }, + new byte[] { 14, 15, 16 }, + new byte[] { 17, 18, 19 }, + new byte[] { 20, 21, 22 }, + new byte[] { 23, 24, 25 } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn( + dataVectors.get(0), + dataVectors.get(1), + dataVectors.get(2), + dataVectors.get(3), + dataVectors.get(4) + ); + + // stub return value when nextDoc is called + OngoingStubbing stubbing = when(values.nextDoc()); + for (int i = 0; i < dataVectors.size(); i++) { + stubbing = stubbing.thenReturn(i); + } + // set last return to be Integer.MAX_VALUE to represent no more docs + stubbing.thenReturn(Integer.MAX_VALUE); + + // Execute and verify + BinaryVectorIdsKNNIterator iterator = new BinaryVectorIdsKNNIterator(queryVector, values, spaceType); + for (int i = 0; i < dataVectors.size(); i++) { + assertEquals(i, iterator.nextDoc()); + assertEquals(expectedScores.get(i), iterator.score()); + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + verify(values, never()).advance(anyInt()); + } +} diff --git 
a/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java index 0b1b71286..60169b95f 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java @@ -11,7 +11,7 @@ import org.apache.lucene.util.FixedBitSet; import org.mockito.stubbing.OngoingStubbing; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.io.IOException; import java.util.Arrays; @@ -26,16 +26,17 @@ public class ByteVectorIdsKNNIteratorTests extends TestCase { @SneakyThrows - public void testNextDoc_whenCalled_thenIterateAllDocs() { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + public void testNextDoc_whenCalled_IterateAllDocs() { + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1f, 2f, 3f }; final int[] filterIds = { 1, 2, 3 }; final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); @@ -48,15 +49,16 @@ public void testNextDoc_whenCalled_thenIterateAllDocs() { ByteVectorIdsKNNIterator iterator = 
new ByteVectorIdsKNNIterator(filterBitSet, queryVector, values, spaceType); for (int i = 0; i < filterIds.length; i++) { assertEquals(filterIds[i], iterator.nextDoc()); - assertEquals(expectedScores.get(i), iterator.score()); + assertEquals(expectedScores.get(i), (Float) iterator.score()); } assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); } @SneakyThrows public void testNextDoc_whenCalled_thenIterateAllDocsWithoutFilter() throws IOException { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; final List dataVectors = Arrays.asList( new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, @@ -65,10 +67,10 @@ public void testNextDoc_whenCalled_thenIterateAllDocsWithoutFilter() throws IOEx new byte[] { 23, 24, 25 } ); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn( dataVectors.get(0), dataVectors.get(1), diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java new file mode 100644 index 000000000..a39a3b2e9 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import junit.framework.TestCase; +import lombok.SneakyThrows; 
+import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class NestedBinaryVectorIdsKNNIteratorTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 0, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + NestedBinaryVectorIdsKNNIterator iterator = new NestedBinaryVectorIdsKNNIterator( + filterBitSet, + queryVector, + values, + spaceType, + parentBitSet + ); + assertEquals(filterIds[0], iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(filterIds[2], iterator.nextDoc()); + 
assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } + + @SneakyThrows + public void testNextDoc_whenIterateWithoutFilters_thenReturnBestChildDocsPerParent() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + when(values.nextDoc()).thenReturn(0, 2, 3, Integer.MAX_VALUE); + + // Execute and verify + NestedBinaryVectorIdsKNNIterator iterator = new NestedBinaryVectorIdsKNNIterator(queryVector, values, spaceType, parentBitSet); + assertEquals(0, iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(3, iterator.nextDoc()); + assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + verify(values, never()).advance(anyInt()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java index eff021234..08c859779 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java @@ -11,7 +11,7 @@ import 
org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.util.Arrays; import java.util.List; @@ -26,19 +26,20 @@ public class NestedByteVectorIdsKNNIteratorTests extends TestCase { @SneakyThrows public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; final int[] filterIds = { 0, 2, 3 }; // Parent id for 0 -> 1 // Parent id for 2, 3 -> 4 // In bit representation, it is 10010. In long, it is 18. final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); - final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 17, 18, 19 }, new byte[] { 14, 15, 16 }); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); @@ -64,18 +65,19 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { @SneakyThrows public void testNextDoc_whenIterateWithoutFilters_thenReturnBestChildDocsPerParent() { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + final 
SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; // Parent id for 0 -> 1 // Parent id for 2, 3 -> 4 // In bit representation, it is 10010. In long, it is 18. final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); - final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 17, 18, 19 }, new byte[] { 14, 15, 16 }); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); when(values.nextDoc()).thenReturn(0, 2, 3, Integer.MAX_VALUE); diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java new file mode 100644 index 000000000..fe4dc7db9 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNJsonIndexMappingsBuilder; +import org.opensearch.knn.KNNJsonQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import 
org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +@Log4j2 +public class FilteredSearchByteIT extends KNNRestTestCase { + @After + public void cleanUp() { + try { + deleteKNNIndex(INDEX_NAME); + } catch (Exception e) { + log.error(e); + } + } + + @SneakyThrows + public void testFilteredSearchWithFaissHnswByte_whenDoingApproximateSearch_thenReturnCorrectResults() { + validateFilteredSearchWithFaissHnswByte(INDEX_NAME, false); + } + + @SneakyThrows + public void testFilteredSearchWithFaissHnswByte_whenDoingExactSearch_thenReturnCorrectResults() { + validateFilteredSearchWithFaissHnswByte(INDEX_NAME, true); + } + + private void validateFilteredSearchWithFaissHnswByte(final String indexName, final boolean doExactSearch) throws Exception { + String filterFieldName = "parking"; + createKnnByteIndex(indexName, FIELD_NAME, 3, KNNEngine.FAISS); + + for (byte i = 1; i < 4; i++) { + addKnnDocWithAttributes( + indexName, + Integer.toString(i), + FIELD_NAME, + new float[] { i, i, i }, + ImmutableMap.of(filterFieldName, i % 2 == 1 ? "true" : "false") + ); + } + refreshIndex(indexName); + forceMergeKnnIndex(indexName); + + // Set it as 0 for approximate search and 100(larger than number of filtered id) for exact search + updateIndexSettings( + indexName, + Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, doExactSearch ? 
100 : 0) + ); + + Float[] queryVector = { 3f, 3f, 3f }; + String query = KNNJsonQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(3) + .filterFieldName(filterFieldName) + .filterValue("true") + .build() + .getQueryString(); + Response response = searchKNNIndex(indexName, query, 3); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(2, docIds.size()); + assertEquals("3", docIds.get(0)); + assertEquals("1", docIds.get(1)); + assertEquals(2, parseTotalSearchHits(entity)); + } + + private void createKnnByteIndex(final String indexName, final String fieldName, final int dimension, final KNNEngine knnEngine) + throws Exception { + KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder() + .methodName(METHOD_HNSW) + .engine(knnEngine.getName()) + .build(); + + String knnIndexMapping = KNNJsonIndexMappingsBuilder.builder() + .fieldName(fieldName) + .dimension(dimension) + .vectorDataType(VectorDataType.BYTE.getValue()) + .method(method) + .build() + .getIndexMapping(); + + createKnnIndex(indexName, knnIndexMapping); + } +} diff --git a/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java b/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java new file mode 100644 index 000000000..7985d08a7 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNJsonIndexMappingsBuilder; +import org.opensearch.knn.KNNJsonQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.NestedKnnDocBuilder; 
+import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +@Log4j2 +public class NestedSearchByteIT extends KNNRestTestCase { + @After + public void cleanUp() { + try { + deleteKNNIndex(INDEX_NAME); + } catch (Exception e) { + log.error(e); + } + } + + @SneakyThrows + public void testNestedSearchWithFaissHnswByte_whenKIsTwo_thenReturnTwoResults() { + String nestedFieldName = "nested"; + createKnnByteIndexWithNestedField(INDEX_NAME, nestedFieldName, FIELD_NAME, 2, KNNEngine.FAISS); + + int totalDocCount = 15; + for (byte i = 0; i < totalDocCount; i++) { + String doc = NestedKnnDocBuilder.create(nestedFieldName) + .addVectors(FIELD_NAME, new Byte[] { i, i }, new Byte[] { i, i }) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + Byte[] queryVector = { 14, 14 }; + String query = KNNJsonQueryBuilder.builder() + .nestedFieldName(nestedFieldName) + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(2) + .build() + .getQueryString(); + Response response = searchKNNIndex(INDEX_NAME, query, 2); + String entity = EntityUtils.toString(response.getEntity()); + + assertEquals(2, parseHits(entity)); + assertEquals(2, parseTotalSearchHits(entity)); + assertEquals("14", parseIds(entity).get(0)); + assertNotEquals("14", parseIds(entity).get(1)); + } + + /** + * { + * "query": { + * "nested": { + * "path": "nested", + * "query": { + * "knn": { + * "nested.test_vector": { + * "vector": [ + * 3, 3, 3 + * ], + * "k": 3, + * "filter": { + * "term": { + * "parking": "true" + * } + * } + * } + * } + * } + * } + * } + * } + * + */ + @SneakyThrows + public void testNestedSearchWithFaissHnswByte_whenDoingExactSearch_thenReturnCorrectResults() { + String nestedFieldName = "nested"; + String filterFieldName = "parking"; + 
createKnnByteIndexWithNestedField(INDEX_NAME, nestedFieldName, FIELD_NAME, 3, KNNEngine.FAISS); + + for (byte i = 1; i < 4; i++) { + String doc = NestedKnnDocBuilder.create(nestedFieldName) + .addVectors(FIELD_NAME, new Byte[] { i, i, i }, new Byte[] { i, i, i }, new Byte[] { i, i, i }) + .addTopLevelField(filterFieldName, i % 2 == 1 ? "true" : "false") + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + // Make it as an exact search by setting the threshold larger than size of filteredIds(6) + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 100)); + + Byte[] queryVector = { 3, 3, 3 }; + String query = KNNJsonQueryBuilder.builder() + .nestedFieldName(nestedFieldName) + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(3) + .filterFieldName(filterFieldName) + .filterValue("true") + .build() + .getQueryString(); + Response response = searchKNNIndex(INDEX_NAME, query, 3); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(2, docIds.size()); + assertEquals("3", docIds.get(0)); + assertEquals("1", docIds.get(1)); + assertEquals(2, parseTotalSearchHits(entity)); + } + + private void createKnnByteIndexWithNestedField( + final String indexName, + final String nestedFieldName, + final String fieldName, + final int dimension, + final KNNEngine knnEngine + ) throws Exception { + KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder() + .methodName(METHOD_HNSW) + .engine(knnEngine.getName()) + .build(); + + String knnIndexMapping = KNNJsonIndexMappingsBuilder.builder() + .nestedFieldName(nestedFieldName) + .fieldName(fieldName) + .dimension(dimension) + .vectorDataType(VectorDataType.BYTE.getValue()) + .method(method) + .build() + .getIndexMapping(); + + createKnnIndex(indexName, knnIndexMapping); + } +}