diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 23cd2a4de..56efcfc5e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -15,13 +15,10 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.StopWatch; @@ -90,7 +87,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getDocsWithField(), field.getVectors() ); - final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); + final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier); final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); @@ -111,21 +108,16 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState flatVectorsWriter.mergeOneField(fieldInfo, mergeState); final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - final Supplier> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge( - vectorDataType, - fieldInfo, - mergeState - ); - int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get()); - if (totalLiveDocs == 0) { + final Supplier> knnVectorValuesSupplier = () -> getVectorValues(vectorDataType, fieldInfo, mergeState); + final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier); + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + final int totalLiveDocs = Math.toIntExact(knnVectorValues.totalLiveDocs()); + if (totalLiveDocs <= 0) { log.debug("[Merge] No live docs for field {}", fieldInfo.getName()); return; } - final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - StopWatch stopWatch = new StopWatch().start(); writer.mergeIndex(knnVectorValues, totalLiveDocs); @@ -181,72 +173,24 @@ public long ramBytesUsed() { .sum(); } - /** - * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. - * - * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. - * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. - * @param mergeState The {@link MergeState} representing the state of the merge operation. - * @param The type of vectors being processed. - * @return The {@link KNNVectorValues} associated with the field during the merge. - * @throws IOException If an I/O error occurs during the retrieval. - */ - private KNNVectorValues getKNNVectorValuesForMerge( - final VectorDataType vectorDataType, - final FieldInfo fieldInfo, - final MergeState mergeState - ) { - try { - switch (fieldInfo.getVectorEncoding()) { - case FLOAT32: - FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - return getVectorValues(vectorDataType, mergedFloats); - case BYTE: - ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - return getVectorValues(vectorDataType, mergedBytes); - default: - throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); - } - } catch (final IOException e) { - log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e); - throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e); - } - } - - private QuantizationState train( - final FieldInfo fieldInfo, - final Supplier> knnVectorValuesSupplier, - final int totalLiveDocs - ) throws IOException { + private QuantizationState train(final FieldInfo fieldInfo, final Supplier> knnVectorValuesSupplier) + throws IOException { final QuantizationService quantizationService = QuantizationService.getInstance(); final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; - if (quantizationParams != null && totalLiveDocs > 0) { - initQuantizationStateWriterIfNecessary(); - KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); - quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); + if (quantizationParams != null) { + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + int totalLiveDocs = Math.toIntExact(knnVectorValues.totalLiveDocs()); + if (totalLiveDocs > 0) { + initQuantizationStateWriterIfNecessary(); + quantizationState = quantizationService.train(quantizationParams, knnVectorValues); + quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); + } } - return quantizationState; } - /** - * The {@link KNNVectorValues} will be exhausted after this function run. So make sure that you are not sending the - * vectorsValues object which you plan to use later - */ - private int getLiveDocs(KNNVectorValues vectorValues) throws IOException { - // Count all the live docs as there vectorValues.totalLiveDocs() just gives the cost for the FloatVectorValues, - // and doesn't tell the correct number of docs, if there are deleted docs in the segment. So we are counting - // the total live docs here. - int liveDocs = 0; - while (vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - liveDocs++; - } - return liveDocs; - } - private void initQuantizationStateWriterIfNecessary() throws IOException { if (quantizationStateWriter == null) { quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java index 771848730..4c8e2c211 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java @@ -57,15 +57,15 @@ public static QuantizationService getInstance() { * @return The {@link QuantizationState} containing the state of the trained quantizer. * @throws IOException If an I/O error occurs during the training process. */ - public QuantizationState train( - final QuantizationParams quantizationParams, - final KNNVectorValues knnVectorValues, - final long liveDocs - ) throws IOException { + public QuantizationState train(final QuantizationParams quantizationParams, final KNNVectorValues knnVectorValues) + throws IOException { Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); // Create the training request from the vector values - KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs); + KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>( + knnVectorValues, + knnVectorValues.totalLiveDocs() + ); // Train the quantizer and return the quantization state return quantizer.train(trainingRequest); diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java index 5da093fd5..3e044f733 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java @@ -34,7 +34,8 @@ public byte[] getVector() throws IOException { @Override public byte[] conditionalCloneVector() throws IOException { byte[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + if (vectorValuesIterator instanceof KNNVectorValuesIterator.MergeByteVectorValuesIterator + || vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { return Arrays.copyOf(vector, vector.length); } return vector; diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java index 1ebc50970..67efc2b9f 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java @@ -34,7 +34,8 @@ public byte[] getVector() throws IOException { @Override public byte[] conditionalCloneVector() throws IOException { byte[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + if (vectorValuesIterator instanceof KNNVectorValuesIterator.MergeByteVectorValuesIterator + || vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { return Arrays.copyOf(vector, vector.length); } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java index dffdd8f0d..156792424 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java @@ -32,7 +32,8 @@ public float[] getVector() throws IOException { @Override public float[] conditionalCloneVector() throws IOException { float[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) { + if (vectorValuesIterator instanceof KNNVectorValuesIterator.MergeFloat32VectorValuesIterator + || vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) { return Arrays.copyOf(vector, vector.length); } return vector; diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValues.java new file mode 100644 index 000000000..ca33b16ed --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNMergeVectorValues.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public final class KNNMergeVectorValues { + + public static List> mergeFloatVectorValues(FieldInfo fieldInfo, MergeState mergeState) + throws IOException { + assert fieldInfo != null && fieldInfo.hasVectorValues(); + if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { + throw new UnsupportedOperationException("Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32"); + } + final List> subs = new ArrayList<>(); + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader != null) { + FloatVectorValues values = knnVectorsReader.getFloatVectorValues(fieldInfo.name); + if (values != null) { + final Bits liveDocs = mergeState.liveDocs[i]; + final int liveDocsInt; + if (liveDocs != null) { + if (liveDocs instanceof FixedBitSet) { + liveDocsInt = ((FixedBitSet) liveDocs).cardinality(); + } else { + liveDocsInt = computeLiveDocs(values, liveDocs); + values = knnVectorsReader.getFloatVectorValues(fieldInfo.name); + } + } else { + liveDocsInt = Math.toIntExact(values.cost()); + } + subs.add(new KNNVectorValuesSub<>(mergeState.docMaps[i], values, liveDocsInt)); + } + } + } + return subs; + } + + public static List> mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState) + throws IOException { + assert fieldInfo != null && fieldInfo.hasVectorValues(); + if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) { + throw new UnsupportedOperationException("Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE"); + } + final List> subs = new ArrayList<>(); + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader != null) { + ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name); + if (values != null) { + final Bits liveDocs = mergeState.liveDocs[i]; + final int liveDocsInt; + if (liveDocs != null) { + if (liveDocs instanceof FixedBitSet) { + liveDocsInt = ((FixedBitSet) liveDocs).cardinality(); + } else { + liveDocsInt = computeLiveDocs(values, liveDocs); + values = knnVectorsReader.getByteVectorValues(fieldInfo.name); + } + } else { + liveDocsInt = Math.toIntExact(values.cost()); + } + subs.add(new KNNVectorValuesSub<>(mergeState.docMaps[i], values, liveDocsInt)); + } + } + } + return subs; + } + + private static int computeLiveDocs(final DocIdSetIterator values, Bits liveDocs) throws IOException { + int count = 0; + if (liveDocs != null) { + while (values.docID() != DocIdSetIterator.NO_MORE_DOCS) { + count++; + values.nextDoc(); + } + } + return count; + } + + static class KNNVectorValuesSub extends DocIDMerger.Sub { + final T values; + final int liveDocs; + + KNNVectorValuesSub(MergeState.DocMap docMap, T values, int liveDocs) { + super(docMap); + this.values = values; + this.liveDocs = liveDocs; + } + + @Override + public int nextDoc() throws IOException { + return values.nextDoc(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java index b12395185..a7cbc8ed7 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java @@ -82,7 +82,6 @@ public int bytesPerVector() { * * @return long */ - @Deprecated public long totalLiveDocs() { return vectorValuesIterator.liveDocs(); } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 41408e217..29f07741a 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,10 +5,12 @@ package org.opensearch.knn.index.vectorvalues; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.MergeState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.common.FieldInfoExtractor; @@ -17,9 +19,13 @@ import java.io.IOException; import java.util.Map; +import static org.opensearch.knn.index.vectorvalues.KNNMergeVectorValues.mergeByteVectorValues; +import static org.opensearch.knn.index.vectorvalues.KNNMergeVectorValues.mergeFloatVectorValues; + /** * A factory class that provides various methods to create the {@link KNNVectorValues}. */ +@Log4j2 public final class KNNVectorValuesFactory { /** @@ -45,7 +51,36 @@ public static KNNVectorValues getVectorValues( final DocsWithFieldSet docIdWithFieldSet, final Map vectors ) { - return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues(docIdWithFieldSet, vectors)); + return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues<>(docIdWithFieldSet, vectors)); + } + + public static KNNVectorValues getVectorValues( + final VectorDataType vectorDataType, + final FieldInfo fieldInfo, + final MergeState mergeState + ) { + try { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + return getVectorValues( + vectorDataType, + new KNNVectorValuesIterator.MergeFloat32VectorValuesIterator( + mergeFloatVectorValues(fieldInfo, mergeState), + mergeState + ) + ); + case BYTE: + return getVectorValues( + vectorDataType, + new KNNVectorValuesIterator.MergeByteVectorValuesIterator(mergeByteVectorValues(fieldInfo, mergeState), mergeState) + ); + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + } catch (final IOException e) { + log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e); + throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e); + } } /** diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java index 4f1445c1c..cb618ba72 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java @@ -9,8 +9,10 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.index.codec.util.KNNCodecUtil; @@ -19,6 +21,8 @@ import java.util.Map; import java.util.function.Function; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + /** * An abstract class that provides an iterator to iterate over KNNVectors, as KNNVectors are stored as different * representation like {@link BinaryDocValues}, {@link FloatVectorValues}, FieldWriter etc. How to iterate using this @@ -185,4 +189,96 @@ public VectorValueExtractorStrategy getVectorExtractorStrategy() { } } + abstract class MergeSegmentVectorValuesIterator implements KNNVectorValuesIterator { + + private DocIDMerger> docIdMerger; + private final int liveDocs; + private int docId; + protected KNNMergeVectorValues.KNNVectorValuesSub current; + + private static final VectorValueExtractorStrategy VECTOR_VALUES_STRATEGY = + new VectorValueExtractorStrategy.MergeSegmentValuesExtractor(); + + MergeSegmentVectorValuesIterator(final List> subs, final MergeState mergeState) + throws IOException { + this.docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); + int totalSize = 0; + for (KNNMergeVectorValues.KNNVectorValuesSub sub : subs) { + totalSize += sub.liveDocs; + } + this.liveDocs = totalSize; + this.docId = -1; + + } + + @Override + public int docId() { + return docId; + } + + @Override + public int advance(int docId) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + } + return docId; + } + + @Override + public DocIdSetIterator getDocIdSetIterator() { + // while we can get the values of current, this method is intended to be called once so it's much better to throw + // so Liskov-Substitution-principle is not violated unknowingly + throw new UnsupportedOperationException(); + } + + @Override + public long liveDocs() { + return liveDocs; + } + + @Override + public VectorValueExtractorStrategy getVectorExtractorStrategy() { + return VECTOR_VALUES_STRATEGY; + } + + public abstract U vectorValue() throws IOException; + } + + class MergeFloat32VectorValuesIterator extends MergeSegmentVectorValuesIterator { + + MergeFloat32VectorValuesIterator( + final List> subs, + final MergeState mergeState + ) throws IOException { + super(subs, mergeState); + } + + @Override + public float[] vectorValue() throws IOException { + return current.values.vectorValue(); + } + } + + class MergeByteVectorValuesIterator extends MergeSegmentVectorValuesIterator { + + MergeByteVectorValuesIterator( + final List> subs, + final MergeState mergeState + ) throws IOException { + super(subs, mergeState); + } + + @Override + public byte[] vectorValue() throws IOException { + return current.values.vectorValue(); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java index 07db4e7f6..cc70468ca 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java @@ -123,4 +123,23 @@ public T extract(final VectorDataType vectorDataType, final KNNVectorValuesI } } + /** + * Strategy to extract the vector from {@link KNNVectorValuesIterator.MergeSegmentVectorValuesIterator} + */ + class MergeSegmentValuesExtractor implements VectorValueExtractorStrategy { + @Override + public T extract(final VectorDataType vectorDataType, final KNNVectorValuesIterator vectorValuesIterator) throws IOException { + switch (vectorDataType) { + case FLOAT: + return (T) ((KNNVectorValuesIterator.MergeFloat32VectorValuesIterator) vectorValuesIterator).vectorValue(); + case BYTE: + case BINARY: + return (T) ((KNNVectorValuesIterator.MergeByteVectorValuesIterator) vectorValuesIterator).vectorValue(); + } + throw new IllegalArgumentException( + "Valid Vector data type not passed to extract vector from FieldWriterIteratorVectorExtractor strategy" + ); + } + } + } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index ad72f5b24..934e97ea5 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -44,6 +44,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -176,6 +177,10 @@ public void testFlush() { throw new RuntimeException(e); } }); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(expectedVectorValues.size()) + ); } } @@ -230,8 +235,7 @@ public void testFlush_WithQuantization() { when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i))).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } @@ -264,6 +268,11 @@ public void testFlush_WithQuantization() { throw new RuntimeException(e); } }); + + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(expectedVectorValues.size() * 2) + ); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index 440e8bbc5..286513420 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -8,7 +8,6 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; -import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -45,6 +44,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @@ -104,9 +104,6 @@ public void testMerge() { MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); - MockedStatic mergedVectorValuesMockedStatic = mockStatic( - KnnVectorsWriter.MergedVectorValues.class - ); MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( KNN990QuantizationStateWriter.class ); @@ -122,10 +119,9 @@ public void testMerge() { fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) .thenReturn(field); - mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) - .thenReturn(floatVectorValues); - knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) - .thenReturn(knnVectorValues); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, fieldInfo, mergeState) + ).thenReturn(knnVectorValues); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) @@ -144,6 +140,9 @@ public void testMerge() { if (!mergedVectors.isEmpty()) { verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, fieldInfo, mergeState) + ); } else { verifyNoInteractions(nativeIndexWriter); } @@ -166,9 +165,6 @@ public void testMerge_WithQuantization() { MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( KNN990QuantizationStateWriter.class ); - MockedStatic mergedVectorValuesMockedStatic = mockStatic( - KnnVectorsWriter.MergedVectorValues.class - ); ) { quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); @@ -181,15 +177,13 @@ public void testMerge_WithQuantization() { NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) .thenReturn(field); - - mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) - .thenReturn(floatVectorValues); - knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) - .thenReturn(knnVectorValues); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, fieldInfo, mergeState) + ).thenReturn(knnVectorValues); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState); + when(quantizationService.train(quantizationParams, knnVectorValues)).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } @@ -211,6 +205,10 @@ public void testMerge_WithQuantization() { verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState); verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, fieldInfo, mergeState), + times(2) + ); } else { assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); verifyNoInteractions(nativeIndexWriter); diff --git a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java index 690391dbd..720b67fd5 100644 --- a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java @@ -46,7 +46,7 @@ public void setUp() throws Exception { public void testTrain_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); assertTrue(quantizationState instanceof OneBitScalarQuantizationState); OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState; @@ -62,7 +62,7 @@ public void testTrain_oneBitQuantizer_success() throws IOException { public void testTrain_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -85,7 +85,7 @@ public void testTrain_twoBitQuantizer_success() throws IOException { public void testTrain_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -110,7 +110,7 @@ public void testTrain_fourBitQuantizer_success() throws IOException { public void testQuantize_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); @@ -125,7 +125,7 @@ public void testQuantize_oneBitQuantizer_success() throws IOException { public void testQuantize_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput); @@ -138,7 +138,7 @@ public void testQuantize_twoBitQuantizer_success() throws IOException { public void testQuantize_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput); @@ -152,7 +152,7 @@ public void testQuantize_fourBitQuantizer_success() throws IOException { public void testQuantize_whenInvalidInput_thenThrows() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput)); }