Skip to content

Commit

Permalink
Integrate KNNVectorValues with vector ANN Search flow
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Aug 13, 2024
1 parent f5ba771 commit 12917ba
Show file tree
Hide file tree
Showing 16 changed files with 246 additions and 77 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
* Disallow a vector field to have an invalid character for a physical file name. [#1936] (https://github.com/opensearch-project/k-NN/pull/1936)
* Integrate KNNVectorValues with vector ANN Search flow [#1952](https://github.com/opensearch-project/k-NN/pull/1952)
* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936)
### Infrastructure
### Documentation
### Maintenance
Expand Down
38 changes: 38 additions & 0 deletions src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

import lombok.experimental.UtilityClass;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

/**
* A utility class to extract information from FieldInfo.
*/
@UtilityClass
public class FieldInfoExtractor {

/**
* Extract vector data type from fieldInfo
* @param fieldInfo {@link FieldInfo}
* @return {@link VectorDataType}
*/
public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
String vectorDataTypeString = fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD);
if (StringUtils.isEmpty(vectorDataTypeString)) {
final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
if (modelMetadata != null) {
VectorDataType vectorDataType = modelMetadata.getVectorDataType();
vectorDataTypeString = vectorDataType == null ? null : vectorDataType.getValue();
}
}
return StringUtils.isNotEmpty(vectorDataTypeString) ? VectorDataType.get(vectorDataTypeString) :
VectorDataType.DEFAULT;
}
}
22 changes: 15 additions & 7 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
Expand Down Expand Up @@ -43,6 +41,10 @@
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNByteIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
Expand Down Expand Up @@ -412,25 +414,31 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont
private KNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) throws IOException {
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = getSpaceType(fieldInfo);
if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, leafReaderContext.reader());
return knnQuery.getParentsFilter() == null
? new FilteredIdsKNNByteIterator(filterIdsBitSet, knnQuery.getByteQueryVector(), values, spaceType)
? new FilteredIdsKNNByteIterator(
filterIdsBitSet,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
)
: new NestedFilteredIdsKNNByteIterator(
filterIdsBitSet,
knnQuery.getByteQueryVector(),
values,
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
} else {
final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, leafReaderContext.reader());
return knnQuery.getParentsFilter() == null
? new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType)
? new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), (KNNFloatVectorValues) vectorValues, spaceType)
: new NestedFilteredIdsKNNIterator(
filterIdsBitSet,
knnQuery.getQueryVector(),
values,
(KNNFloatVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.ByteArrayInputStream;
import java.io.IOException;

/**
Expand All @@ -26,21 +24,21 @@ public class FilteredIdsKNNByteIterator implements KNNIterator {
protected final BitSet filterIdsBitSet;
protected final BitSetIterator bitSetIterator;
protected final byte[] queryVector;
protected final BinaryDocValues binaryDocValues;
protected final KNNBinaryVectorValues binaryVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public FilteredIdsKNNByteIterator(
final BitSet filterIdsBitSet,
final byte[] queryVector,
final BinaryDocValues binaryDocValues,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType
) {
this.filterIdsBitSet = filterIdsBitSet;
this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryDocValues = binaryDocValues;
this.binaryVectorValues = binaryVectorValues;
this.spaceType = spaceType;
this.docId = bitSetIterator.nextDoc();
}
Expand All @@ -57,7 +55,7 @@ public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}
int doc = binaryDocValues.advance(docId);
int doc = binaryVectorValues.advance(docId);
currentScore = computeScore();
docId = bitSetIterator.nextDoc();
return doc;
Expand All @@ -69,9 +67,7 @@ public float score() {
}

protected float computeScore() throws IOException {
final BytesRef value = binaryDocValues.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final byte[] vector = byteStream.readAllBytes();
final byte[] vector = binaryVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;

import java.io.IOException;

Expand All @@ -27,21 +24,21 @@ public class FilteredIdsKNNIterator implements KNNIterator {
protected final BitSet filterIdsBitSet;
protected final BitSetIterator bitSetIterator;
protected final float[] queryVector;
protected final BinaryDocValues binaryDocValues;
protected final KNNFloatVectorValues knnFloatVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public FilteredIdsKNNIterator(
final BitSet filterIdsBitSet,
final float[] queryVector,
final BinaryDocValues binaryDocValues,
final KNNFloatVectorValues knnFloatVectorValues,
final SpaceType spaceType
) {
this.filterIdsBitSet = filterIdsBitSet;
this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryDocValues = binaryDocValues;
this.knnFloatVectorValues = knnFloatVectorValues;
this.spaceType = spaceType;
this.docId = bitSetIterator.nextDoc();
}
Expand All @@ -58,7 +55,7 @@ public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}
int doc = binaryDocValues.advance(docId);
int doc = knnFloatVectorValues.advance(docId);
currentScore = computeScore();
docId = bitSetIterator.nextDoc();
return doc;
Expand All @@ -70,9 +67,7 @@ public float score() {
}

protected float computeScore() throws IOException {
final BytesRef value = binaryDocValues.binaryValue();
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(value);
final float[] vector = vectorSerializer.byteToFloatArray(value);
final float[] vector = knnFloatVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

Expand All @@ -22,11 +22,11 @@ public class NestedFilteredIdsKNNByteIterator extends FilteredIdsKNNByteIterator
public NestedFilteredIdsKNNByteIterator(
final BitSet filterIdsArray,
final byte[] queryVector,
final BinaryDocValues values,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) {
super(filterIdsArray, queryVector, values, spaceType);
super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

Expand All @@ -47,7 +47,7 @@ public int nextDoc() throws IOException {
int bestChild = -1;

while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
binaryDocValues.advance(docId);
binaryVectorValues.advance(docId);
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;

import java.io.IOException;

Expand All @@ -22,11 +22,11 @@ public class NestedFilteredIdsKNNIterator extends FilteredIdsKNNIterator {
public NestedFilteredIdsKNNIterator(
final BitSet filterIdsArray,
final float[] queryVector,
final BinaryDocValues values,
final KNNFloatVectorValues knnFloatVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) {
super(filterIdsArray, queryVector, values, spaceType);
super(filterIdsArray, queryVector, knnFloatVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

Expand All @@ -47,7 +47,7 @@ public int nextDoc() throws IOException {
int bestChild = -1;

while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
binaryDocValues.advance(docId);
knnFloatVectorValues.advance(docId);
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

package org.opensearch.knn.index.vectorvalues;

import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.VectorDataType;

import java.io.IOException;
import java.util.Map;

/**
Expand All @@ -21,7 +27,7 @@ public final class KNNVectorValuesFactory {
*
* @param vectorDataType {@link VectorDataType}
* @param docIdSetIterator {@link DocIdSetIterator}
* @return {@link KNNVectorValues} of type float[]
* @return {@link KNNVectorValues}
*/
public static <T> KNNVectorValues<T> getVectorValues(final VectorDataType vectorDataType, final DocIdSetIterator docIdSetIterator) {
return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator));
Expand All @@ -32,7 +38,7 @@ public static <T> KNNVectorValues<T> getVectorValues(final VectorDataType vector
*
* @param vectorDataType {@link VectorDataType}
* @param docIdWithFieldSet {@link DocsWithFieldSet}
* @return {@link KNNVectorValues} of type float[]
* @return {@link KNNVectorValues}
*/
public static <T> KNNVectorValues<T> getVectorValues(
final VectorDataType vectorDataType,
Expand All @@ -42,6 +48,30 @@ public static <T> KNNVectorValues<T> getVectorValues(
return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues<T>(docIdWithFieldSet, vectors));
}

/**
* Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader}
*
* @param fieldInfo {@link FieldInfo}
* @param leafReader {@link LeafReader}
* @return {@link KNNVectorValues}
*/
public static <T> KNNVectorValues<T> getVectorValues(final FieldInfo fieldInfo, final LeafReader leafReader) throws IOException {
DocIdSetIterator docIdSetIterator;
if (fieldInfo.getVectorDimension() > 0) {
if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) {
docIdSetIterator = leafReader.getByteVectorValues(fieldInfo.getName());
} else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) {
docIdSetIterator = leafReader.getFloatVectorValues(fieldInfo.getName());
} else {
throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues");
}
} else {
docIdSetIterator = DocValues.getBinary(leafReader, fieldInfo.getName());
}
KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator);
return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator);
}

@SuppressWarnings("unchecked")
private static <T> KNNVectorValues<T> getVectorValues(
final VectorDataType vectorDataType,
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

package org.opensearch.knn.indices;

import org.apache.commons.lang.StringUtils;

import java.util.Locale;

/**
* A utility class for models.
*/
Expand All @@ -33,4 +37,21 @@ public static boolean isModelCreated(ModelMetadata modelMetadata) {
return modelMetadata.getState().equals(ModelState.CREATED);
}

/**
* Gets Model Metadata from a given model id.
* @param modelId {@link String}
* @return {@link ModelMetadata}
*/
public static ModelMetadata getModelMetadata(final String modelId) {
if (StringUtils.isEmpty(modelId)) {
return null;
}
final Model model = ModelCache.getInstance().get(modelId);
final ModelMetadata modelMetadata = model.getModelMetadata();
if (ModelUtil.isModelCreated(modelMetadata) == false) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}

}
Loading

0 comments on commit 12917ba

Please sign in to comment.