Skip to content

Commit

Permalink
Add integration test for binary vector values
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>
  • Loading branch information
VijayanB committed Sep 24, 2024
1 parent a0eacec commit e7cddc9
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ public KNN80DocValuesProducer(DocValuesProducer delegate, SegmentReadState state
continue;
}
List<String> engineFiles = KNNCodecUtil.getEngineFiles(knnEngine.getExtension(), field.name, state.segmentInfo);
if (engineFiles.isEmpty()) {
continue;
}
Path indexPath = PathUtils.get(directoryPath, engineFiles.get(0));
indexPathMap.putIfAbsent(field.getName(), indexPath.toString());
}
Expand Down
75 changes: 70 additions & 5 deletions src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.integ;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
Expand All @@ -13,11 +14,13 @@
import org.junit.After;
import org.junit.BeforeClass;
import org.opensearch.client.Response;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.KNNJsonIndexMappingsBuilder;
import org.opensearch.knn.KNNJsonQueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;

Expand All @@ -36,6 +39,8 @@
@Log4j2
public class BinaryIndexIT extends KNNRestTestCase {
private static TestUtils.TestData testData;
private static final int NEVER_BUILD_GRAPH = -1;
private static final int ALWAYS_BUILD_GRAPH = 0;

@BeforeClass
public static void setUpClass() throws IOException {
Expand Down Expand Up @@ -104,6 +109,52 @@ public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() {
}
}

@SneakyThrows
public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_thenBuildGraphBasedOnSetting() {
// Create Index
createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH);
ingestTestData(INDEX_NAME, FIELD_NAME);

assertEquals(0, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size());

// update build vector data structure setting
updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0));
forceMergeKnnIndex(INDEX_NAME, 1);

int k = 100;
for (int i = 0; i < testData.queries.length; i++) {
List<KNNResult> knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k);
float recall = getRecall(
Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)),
knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet())
);
assertTrue("Recall: " + recall, recall > 0.1);
}
}

@SneakyThrows
public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception {
// Create Index
createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length);
ingestTestData(INDEX_NAME, FIELD_NAME, false);

assertEquals(0, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size());

// update build vector data structure setting
updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0));
forceMergeKnnIndex(INDEX_NAME, 1);

int k = 100;
for (int i = 0; i < testData.queries.length; i++) {
List<KNNResult> knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k);
float recall = getRecall(
Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)),
knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet())
);
assertTrue("Recall: " + recall, recall > 0.1);
}
}

@SneakyThrows
public void testFaissHnswBinary_whenRadialSearch_thenThrowException() {
// Create Index
Expand Down Expand Up @@ -157,13 +208,18 @@ private List<KNNResult> runKnnQuery(final String indexName, final String fieldNa
}

private void ingestTestData(final String indexName, final String fieldName) throws Exception {
ingestTestData(indexName, fieldName, true);
}

private void ingestTestData(final String indexName, final String fieldName, boolean refresh) throws Exception {
// Index the test data
for (int i = 0; i < testData.indexData.docs.length; i++) {
addKnnDoc(
indexName,
Integer.toString(testData.indexData.docs[i]),
fieldName,
Floats.asList(testData.indexData.vectors[i]).toArray()
ImmutableList.of(fieldName),
ImmutableList.of(Floats.asList(testData.indexData.vectors[i]).toArray()),
refresh
);
}

Expand All @@ -172,8 +228,13 @@ private void ingestTestData(final String indexName, final String fieldName) thro
assertEquals(testData.indexData.docs.length, getDocCount(indexName));
}

private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension)
throws IOException {
private void createKnnHnswBinaryIndex(
final KNNEngine knnEngine,
final String indexName,
final String fieldName,
final int dimension,
final int threshold
) throws IOException {
KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder()
.methodName(METHOD_HNSW)
.engine(knnEngine.getName())
Expand All @@ -186,7 +247,11 @@ private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String in
.method(method)
.build()
.getIndexMapping();
createKnnIndex(indexName, buildKNNIndexSettings(threshold), knnIndexMapping);
}

createKnnIndex(indexName, knnIndexMapping);
private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension)
throws IOException {
createKnnHnswBinaryIndex(knnEngine, indexName, fieldName, dimension, ALWAYS_BUILD_GRAPH);
}
}

0 comments on commit e7cddc9

Please sign in to comment.