Skip to content

Commit

Permalink
Revert "Makes sure KNNVectorValues aren't recreated unnecessarily whe…
Browse files Browse the repository at this point in the history
…n quantization isn't needed (opensearch-project#2137)"

This reverts commit 33c05ea.
  • Loading branch information
shatejas committed Sep 25, 2024
1 parent 5bad05c commit d6e713a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 59 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
### Refactoring
* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133)

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.17...2.x)
### Features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues;
Expand Down Expand Up @@ -83,19 +82,19 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
for (final NativeEngineFieldVectorsWriter<?> field : fields) {
final FieldInfo fieldInfo = field.getFieldInfo();
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
int totalLiveDocs = field.getVectors().size();
int totalLiveDocs = getLiveDocs(getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()));
if (totalLiveDocs > 0) {
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
KNNVectorValues<?> knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());

final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValues, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());

StopWatch stopWatch = new StopWatch().start();

writer.flushIndex(knnVectorValues, totalLiveDocs);

long time_in_millis = stopWatch.stop().totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
Expand All @@ -111,20 +110,17 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);

final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge(
vectorDataType,
fieldInfo,
mergeState
);
int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get());
int totalLiveDocs = getLiveDocs(getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState));
if (totalLiveDocs == 0) {
log.debug("[Merge] No live docs for field {}", fieldInfo.getName());
return;
}

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
KNNVectorValues<?> knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState);
final QuantizationState quantizationState = train(fieldInfo, knnVectorValues, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState);

StopWatch stopWatch = new StopWatch().start();

Expand Down Expand Up @@ -195,36 +191,27 @@ private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
final VectorDataType vectorDataType,
final FieldInfo fieldInfo,
final MergeState mergeState
) {
try {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
}
} catch (final IOException e) {
log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e);
throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e);
) throws IOException {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
}
}

private QuantizationState train(
final FieldInfo fieldInfo,
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
final int totalLiveDocs
) throws IOException {
private QuantizationState train(final FieldInfo fieldInfo, final KNNVectorValues<?> knnVectorValues, final int totalLiveDocs)
throws IOException {

final QuantizationService quantizationService = QuantizationService.getInstance();
final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -177,11 +176,6 @@ public void testFlush() {
throw new RuntimeException(e);
}
});

knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
times(expectedVectorValues.size())
);
}
}

Expand Down Expand Up @@ -270,11 +264,6 @@ public void testFlush_WithQuantization() {
throw new RuntimeException(e);
}
});

knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
times(expectedVectorValues.size() * 2)
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -145,10 +144,6 @@ public void testMerge() {
if (!mergedVectors.isEmpty()) {
verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size());
assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L);
knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues),
times(2)
);
} else {
verifyNoInteractions(nativeIndexWriter);
}
Expand Down Expand Up @@ -216,10 +211,6 @@ public void testMerge_WithQuantization() {
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState);
verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size());
assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L);
knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues),
times(3)
);
} else {
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
verifyNoInteractions(nativeIndexWriter);
Expand Down

0 comments on commit d6e713a

Please sign in to comment.