Skip to content

Commit

Permalink
Optimizes totalLiveDocs calculation for merge in
Browse files Browse the repository at this point in the history
NativeEnginesKNNVectorsWriter

Currently vector values are iterated irrespective of whether there are
deleted docs in the segment. This change makes sure they are not iterated
unnecessarily when the segment contains no deleted docs.

Signed-off-by: Tejas Shah <shatejas@amazon.com>
  • Loading branch information
shatejas committed Sep 23, 2024
1 parent e348524 commit 61d3f71
Show file tree
Hide file tree
Showing 13 changed files with 325 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.common.StopWatch;
Expand Down Expand Up @@ -90,7 +87,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field.getDocsWithField(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

Expand All @@ -111,21 +108,16 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);

final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge(
vectorDataType,
fieldInfo,
mergeState
);
int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get());
if (totalLiveDocs == 0) {
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(vectorDataType, fieldInfo, mergeState);
final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
final int totalLiveDocs = Math.toIntExact(knnVectorValues.totalLiveDocs());
if (totalLiveDocs <= 0) {
log.debug("[Merge] No live docs for field {}", fieldInfo.getName());
return;
}

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();

writer.mergeIndex(knnVectorValues, totalLiveDocs);
Expand Down Expand Up @@ -181,72 +173,24 @@ public long ramBytesUsed() {
.sum();
}

/**
 * Builds the {@link KNNVectorValues} for a field that is being merged, dispatching on the
 * field's vector encoding.
 *
 * @param vectorDataType the {@link VectorDataType} of the stored vectors
 * @param fieldInfo metadata for the field being merged
 * @param mergeState state of the in-progress merge
 * @param <T> vector element type
 * @return merged {@link KNNVectorValues} for the field
 * @throws IllegalStateException if the encoding is unsupported or reading the segments fails
 */
private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
    final VectorDataType vectorDataType,
    final FieldInfo fieldInfo,
    final MergeState mergeState
) {
    try {
        switch (fieldInfo.getVectorEncoding()) {
            case FLOAT32:
                // Merge all per-segment float values into a single iterator view.
                return getVectorValues(vectorDataType, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
            case BYTE:
                return getVectorValues(vectorDataType, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
            default:
                throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
        }
    } catch (final IOException e) {
        // Wrap as unchecked: callers of this private helper do not declare IOException.
        log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e);
        throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e);
    }
}

private QuantizationState train(
final FieldInfo fieldInfo,
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
final int totalLiveDocs
) throws IOException {
private QuantizationState train(final FieldInfo fieldInfo, final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier)
throws IOException {

final QuantizationService quantizationService = QuantizationService.getInstance();
final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
if (quantizationParams != null) {
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
int totalLiveDocs = Math.toIntExact(knnVectorValues.totalLiveDocs());
if (totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}
}

return quantizationState;
}

/**
 * Counts the live documents by fully consuming the given {@link KNNVectorValues} iterator.
 * The iterator is exhausted when this method returns, so do not pass a vectorValues object
 * you intend to use afterwards.
 */
private int getLiveDocs(KNNVectorValues<?> vectorValues) throws IOException {
    // vectorValues.totalLiveDocs() only reports the underlying iterator cost, which does not
    // reflect deletions in the segment, so we walk every document and count explicitly.
    int count = 0;
    for (int doc = vectorValues.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = vectorValues.nextDoc()) {
        count++;
    }
    return count;
}

private void initQuantizationStateWriterIfNecessary() throws IOException {
if (quantizationStateWriter == null) {
quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ public static <T, R> QuantizationService<T, R> getInstance() {
* @return The {@link QuantizationState} containing the state of the trained quantizer.
* @throws IOException If an I/O error occurs during the training process.
*/
public QuantizationState train(
final QuantizationParams quantizationParams,
final KNNVectorValues<T> knnVectorValues,
final long liveDocs
) throws IOException {
public QuantizationState train(final QuantizationParams quantizationParams, final KNNVectorValues<T> knnVectorValues)
throws IOException {
Quantizer<T, R> quantizer = QuantizerFactory.getQuantizer(quantizationParams);

// Create the training request from the vector values
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs);
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(
knnVectorValues,
knnVectorValues.totalLiveDocs()
);

// Train the quantizer and return the quantization state
return quantizer.train(trainingRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public byte[] getVector() throws IOException {
@Override
public byte[] conditionalCloneVector() throws IOException {
byte[] vector = getVector();
if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) {
if (vectorValuesIterator instanceof KNNVectorValuesIterator.MergeByteVectorValuesIterator
|| vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) {
return Arrays.copyOf(vector, vector.length);
}
return vector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public byte[] getVector() throws IOException {
@Override
public byte[] conditionalCloneVector() throws IOException {
byte[] vector = getVector();
if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) {
if (vectorValuesIterator instanceof KNNVectorValuesIterator.MergeByteVectorValuesIterator
|| vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) {
return Arrays.copyOf(vector, vector.length);

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ public float[] getVector() throws IOException {
@Override
public float[] conditionalCloneVector() throws IOException {
float[] vector = getVector();
if (vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) {
if (vectorValuesIterator instanceof KNNVectorValuesIterator.MergeFloat32VectorValuesIterator
|| vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) {
return Arrays.copyOf(vector, vector.length);
}
return vector;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.vectorvalues;

import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Utility that gathers the per-segment vector values participating in a merge, paired with the
 * number of live (non-deleted) documents each segment contributes. This lets merge code size its
 * work without blindly iterating every vector when a segment has deletions.
 */
public final class KNNMergeVectorValues {

    private KNNMergeVectorValues() {
        // Utility class; no instances.
    }

    /**
     * Collects the {@link FloatVectorValues} for the given field from every non-null segment
     * reader in the merge, together with that segment's live-doc count.
     *
     * @param fieldInfo field being merged; must be FLOAT32 encoded and have vector values
     * @param mergeState state of the in-progress merge
     * @return one {@link KNNVectorValuesSub} per segment that has values for the field
     * @throws IOException if reading vector values from a segment fails
     * @throws UnsupportedOperationException if the field is not FLOAT32 encoded
     */
    public static List<KNNVectorValuesSub<FloatVectorValues>> mergeFloatVectorValues(FieldInfo fieldInfo, MergeState mergeState)
        throws IOException {
        assert fieldInfo != null && fieldInfo.hasVectorValues();
        if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
            throw new UnsupportedOperationException("Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
        }
        final List<KNNVectorValuesSub<FloatVectorValues>> subs = new ArrayList<>();
        for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
            KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
            if (knnVectorsReader != null) {
                FloatVectorValues values = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
                if (values != null) {
                    final Bits liveDocs = mergeState.liveDocs[i];
                    final int liveDocsInt;
                    if (liveDocs != null) {
                        if (liveDocs instanceof FixedBitSet) {
                            // Fast path: the bit set's cardinality is the segment's live-doc count.
                            // NOTE(review): for a sparse field this may overcount docs that lack the
                            // field — confirm against how callers consume liveDocs.
                            liveDocsInt = ((FixedBitSet) liveDocs).cardinality();
                        } else {
                            // Counting exhausts the iterator, so re-open the vector values after.
                            liveDocsInt = computeLiveDocs(values, liveDocs);
                            values = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
                        }
                    } else {
                        // No deletions in this segment: cost() equals the document count.
                        liveDocsInt = Math.toIntExact(values.cost());
                    }
                    subs.add(new KNNVectorValuesSub<>(mergeState.docMaps[i], values, liveDocsInt));
                }
            }
        }
        return subs;
    }

    /**
     * Collects the {@link ByteVectorValues} for the given field from every non-null segment
     * reader in the merge, together with that segment's live-doc count.
     *
     * @param fieldInfo field being merged; must be BYTE encoded and have vector values
     * @param mergeState state of the in-progress merge
     * @return one {@link KNNVectorValuesSub} per segment that has values for the field
     * @throws IOException if reading vector values from a segment fails
     * @throws UnsupportedOperationException if the field is not BYTE encoded
     */
    public static List<KNNVectorValuesSub<ByteVectorValues>> mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
        throws IOException {
        assert fieldInfo != null && fieldInfo.hasVectorValues();
        if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
            throw new UnsupportedOperationException("Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
        }
        final List<KNNVectorValuesSub<ByteVectorValues>> subs = new ArrayList<>();
        for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
            KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
            if (knnVectorsReader != null) {
                ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
                if (values != null) {
                    final Bits liveDocs = mergeState.liveDocs[i];
                    final int liveDocsInt;
                    if (liveDocs != null) {
                        if (liveDocs instanceof FixedBitSet) {
                            // Fast path: see note in mergeFloatVectorValues.
                            liveDocsInt = ((FixedBitSet) liveDocs).cardinality();
                        } else {
                            // Counting exhausts the iterator, so re-open the vector values after.
                            liveDocsInt = computeLiveDocs(values, liveDocs);
                            values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
                        }
                    } else {
                        // No deletions in this segment: cost() equals the document count.
                        liveDocsInt = Math.toIntExact(values.cost());
                    }
                    subs.add(new KNNVectorValuesSub<>(mergeState.docMaps[i], values, liveDocsInt));
                }
            }
        }
        return subs;
    }

    /**
     * Walks {@code values} and counts only the documents that are live according to
     * {@code liveDocs}. The iterator is exhausted on return, so callers must re-acquire the
     * vector values afterwards.
     *
     * BUGFIX: the previous version looped on {@code values.docID()} without advancing first —
     * counting the initial {@code -1} position — and never consulted {@code liveDocs}, so it
     * returned docCount + 1 including deleted documents.
     */
    private static int computeLiveDocs(final DocIdSetIterator values, final Bits liveDocs) throws IOException {
        int count = 0;
        if (liveDocs != null) {
            // Advance first: before the initial nextDoc() the iterator sits at -1, which must
            // not be counted. Only documents whose live bit is set contribute to the total.
            for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
                if (liveDocs.get(doc)) {
                    count++;
                }
            }
        }
        return count;
    }

    /**
     * Pairs a segment's vector values with its doc-map for {@link DocIDMerger} and carries the
     * segment's live-doc count alongside.
     */
    static class KNNVectorValuesSub<T extends DocIdSetIterator> extends DocIDMerger.Sub {
        final T values;
        final int liveDocs;

        KNNVectorValuesSub(MergeState.DocMap docMap, T values, int liveDocs) {
            super(docMap);
            this.values = values;
            this.liveDocs = liveDocs;
        }

        @Override
        public int nextDoc() throws IOException {
            return values.nextDoc();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ public int bytesPerVector() {
* </pre>
* @return long
*/
@Deprecated
public long totalLiveDocs() {
return vectorValuesIterator.liveDocs();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

package org.opensearch.knn.index.vectorvalues;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.common.FieldInfoExtractor;
Expand All @@ -17,9 +19,13 @@
import java.io.IOException;
import java.util.Map;

import static org.opensearch.knn.index.vectorvalues.KNNMergeVectorValues.mergeByteVectorValues;
import static org.opensearch.knn.index.vectorvalues.KNNMergeVectorValues.mergeFloatVectorValues;

/**
* A factory class that provides various methods to create the {@link KNNVectorValues}.
*/
@Log4j2
public final class KNNVectorValuesFactory {

/**
Expand All @@ -45,7 +51,36 @@ public static <T> KNNVectorValues<T> getVectorValues(
final DocsWithFieldSet docIdWithFieldSet,
final Map<Integer, T> vectors
) {
return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues<T>(docIdWithFieldSet, vectors));
return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues<>(docIdWithFieldSet, vectors));
}

/**
 * Returns the {@link KNNVectorValues} for a field that is being merged, dispatching on the
 * field's vector encoding to build the appropriate merge iterator.
 *
 * @param vectorDataType the {@link VectorDataType} of the stored vectors
 * @param fieldInfo metadata for the field being merged
 * @param mergeState state of the in-progress merge
 * @param <T> vector element type
 * @return merged {@link KNNVectorValues} for the field
 * @throws IllegalStateException if the encoding is unsupported, or wrapping an
 *         {@link IOException} raised while reading the segments
 */
public static <T> KNNVectorValues<T> getVectorValues(
    final VectorDataType vectorDataType,
    final FieldInfo fieldInfo,
    final MergeState mergeState
) {
    try {
        switch (fieldInfo.getVectorEncoding()) {
            case FLOAT32:
                return getVectorValues(
                    vectorDataType,
                    // Merge iterator over the per-segment float values collected by KNNMergeVectorValues.
                    new KNNVectorValuesIterator.MergeFloat32VectorValuesIterator(
                        mergeFloatVectorValues(fieldInfo, mergeState),
                        mergeState
                    )
                );
            case BYTE:
                return getVectorValues(
                    vectorDataType,
                    new KNNVectorValuesIterator.MergeByteVectorValuesIterator(mergeByteVectorValues(fieldInfo, mergeState), mergeState)
                );
            default:
                throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
        }
    } catch (final IOException e) {
        // Rethrow unchecked so callers need not declare IOException; original cause is preserved.
        log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e);
        throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e);
    }
}

/**
Expand Down
Loading

0 comments on commit 61d3f71

Please sign in to comment.