diff --git a/CHANGELOG.md b/CHANGELOG.md index 349125beb..398c579e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) * k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984) * Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823) +* Add support for byte vector with Faiss Engine IVF algorithm [#2002](https://github.com/opensearch-project/k-NN/pull/2002) ### Enhancements * Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950) ### Bug Fixes diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 574efb6fd..d6375653d 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -36,6 +36,12 @@ namespace knn_jni { jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); + // Create a index with ids and byte vectors. Instead of creating a new index, this function creates the index + // based off of the template index passed in. The index is serialized to indexPathJ. + void CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, + jobject parametersJ); + // Load an index from indexPathJ into memory. // // Return a pointer to the loaded index @@ -110,6 +116,13 @@ namespace knn_jni { jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + // Create an empty byte index defined by the values in the Java map, parametersJ. Train the index with + // the byte vectors located at trainVectorsPointerJ. + // + // Return the serialized representation + jbyteArray TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, + jlong trainVectorsPointerJ); + /* * Perform a range search with filter against the index located in memory at indexPointerJ. * diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 09f3ec8b7..d42ce197c 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -112,6 +112,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createByteIndexFromTemplate + * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + */ + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: loadIndex @@ -216,6 +224,14 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex (JNIEnv *, jclass, jobject, jint, jlong); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: trainByteIndex + * Signature: (Ljava/util/Map;IJ)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainByteIndex + (JNIEnv *, jclass, jobject, jint, jlong); + /* * Class: org_opensearch_knn_jni_FaissService * Method: transferVectors diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 0e1029ecf..ba15c3ce7 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -320,6 +320,96 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter faiss::write_index_binary(&idMap, indexPathCpp.c_str()); } +void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (templateIndexJ == nullptr) { + throw std::runtime_error("Template index cannot be null"); + } + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Read data set + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(faiss::read_index(&vectorIoReader, 0)); + + auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); + + // Add vectors in batches by casting int8 vectors into float with a batch size of 1000 to avoid additional memory spike. + // Refer to this github issue for more details https://github.com/opensearch-project/k-NN/issues/1659#issuecomment-2307390255 + int batchSize = 1000; + std::vector inputFloatVectors(batchSize * dim); + std::vector floatVectorsIds(batchSize); + int id = 0; + auto iter = inputVectors->begin(); + + for (int id = 0; id < numVectors; id += batchSize) { + if (numVectors - id < batchSize) { + batchSize = numVectors - id; + } + + for (int i = 0; i < batchSize; ++i) { + floatVectorsIds[i] = ids[id + i]; + for (int j = 0; j < dim; ++j, ++iter) { + inputFloatVectors[i * dim + j] = static_cast(*iter); + } + } + idMap.add_with_ids(batchSize, inputFloatVectors.data(), floatVectorsIds.data()); + } + + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index(&idMap, indexPathCpp.c_str()); +} + jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { if (indexPathJ == nullptr) { throw std::runtime_error("Index path cannot be null"); @@ -782,6 +872,73 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * return ret; } +jbyteArray knn_jni::faiss_wrapper::TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, + jint dimensionJ, jlong trainVectorsPointerJ) { + // First, we need to build the index + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); + SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get()); + jniUtil->DeleteLocalRef(env, subParametersJ); + } + + // Train index if needed + auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); + int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; + + auto iter = trainingVectorsPointerCpp->begin(); + std::vector trainingFloatVectors(numVectors * dimensionJ); + for(int i=0; i < numVectors * dimensionJ; ++i, ++iter) { + trainingFloatVectors[i] = static_cast(*iter); + } + + if(!indexWriter->is_trained) { + InternalTrainIndex(indexWriter.get(), numVectors, trainingFloatVectors.data()); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Now that indexWriter is trained, we just load the bytes into an array and return + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index(indexWriter.get(), &vectorIoWriter); + + // Wrap in smart pointer + std::unique_ptr jbytesBuffer; + jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]); + int c = 0; + for (auto b : vectorIoWriter.data) { + jbytesBuffer[c++] = (jbyte) b; + } + + jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size()); + jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get()); + return ret; +} + + faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) { if (spaceType == knn_jni::L2) { return faiss::METRIC_L2; diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index bcdc4f18b..70c986b7d 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -192,6 +192,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate(JNIEnv * env, jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jstring indexPathJ, + jbyteArray templateIndexJ, + jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ) { try { @@ -335,6 +350,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinar return nullptr; } +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainByteIndex(JNIEnv * env, jclass cls, + jobject parametersJ, + jint dimensionJ, + jlong trainVectorsPointerJ) +{ + try { + return knn_jni::faiss_wrapper::TrainByteIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls, jlong vectorsPointerJ, jobjectArray vectorsJ) diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index a1839c6ce..5f6f83c46 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -230,6 +230,55 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { std::remove(indexPath.c_str()); } +TEST(FaissCreateByteIndexFromTemplateTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 100; + std::vector ids; + auto *vectors = new std::vector(); + int dim = 8; + vectors->reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { + vectors->push_back(test_util::RandomInt(-128, 127)); + } + } + + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,SQ8_direct_signed"; + + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get()); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + EXPECT_CALL(mockJNIUtil, + GetJavaObjectArrayLength( + jniEnv, reinterpret_cast(&vectors))) + .WillRepeatedly(Return(vectors->size())); + + std::string spaceType = knn_jni::L2; + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + + knn_jni::faiss_wrapper::CreateByteIndexFromTemplate( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong)vectors, dim, (jstring)&indexPath, + reinterpret_cast(&(vectorIoWriter.data)), + (jobject) ¶metersMap + ); + + // Make sure index can be loaded + std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); + + // Clean up + std::remove(indexPath.c_str()); +} + TEST(FaissLoadIndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 100; @@ -717,6 +766,38 @@ TEST(FaissTrainIndexTest, BasicAssertions) { ASSERT_TRUE(trainedIndex->is_trained); } +TEST(FaissTrainByteIndexTest, BasicAssertions) { + // Define the index configuration + int dim = 2; + std::string spaceType = knn_jni::L2; + std::string index_description = "IVF4,SQ8_direct_signed"; + + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject) &index_description; + + // Define training data + int numTrainingVectors = 256; + std::vector trainingVectors = test_util::RandomByteVectors(dim, numTrainingVectors, -128, 127); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + // Perform training + std::unique_ptr> trainedIndexSerialization( + reinterpret_cast *>( + knn_jni::faiss_wrapper::TrainByteIndex( + &mockJNIUtil, jniEnv, (jobject) ¶metersMap, dim, + reinterpret_cast(&trainingVectors)))); + + std::unique_ptr trainedIndex( + test_util::FaissLoadFromSerializedIndex(trainedIndexSerialization.get())); + + // Confirm that training succeeded + ASSERT_TRUE(trainedIndex->is_trained); +} + TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index 4f8bd2c34..47d1a7c8e 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -447,6 +447,14 @@ std::vector test_util::RandomVectors(int dim, int64_t numVectors, float m return vectors; } +std::vector test_util::RandomByteVectors(int dim, int64_t numVectors, int min, int max) { + std::vector vectors(dim*numVectors); + for (int64_t i = 0; i < dim*numVectors; i++) { + vectors[i] = test_util::RandomInt(min, max); + } + return vectors; +} + std::vector test_util::Range(int64_t numElements) { std::vector rangeVector(numElements); for (int64_t i = 0; i < numElements; i++) { diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index a90d45dd9..ea02da6f2 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -173,6 +173,8 @@ namespace test_util { std::vector RandomVectors(int dim, int64_t numVectors, float min, float max); + std::vector RandomByteVectors(int dim, int64_t numVectors, int min, int max); + std::vector Range(int64_t numElements); // returns the number of 64 bit words it would take to hold numBits diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 9283e5ee6..4827a4582 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -14,6 +14,12 @@ import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.knn.jni.JNICommons; +import org.opensearch.knn.training.BinaryTrainingDataConsumer; +import org.opensearch.knn.training.ByteTrainingDataConsumer; +import org.opensearch.knn.training.FloatTrainingDataConsumer; +import org.opensearch.knn.training.TrainingDataConsumer; import java.util.Arrays; import java.util.Locale; @@ -48,6 +54,16 @@ public float[] getVectorFromBytesRef(BytesRef binaryValue) { } return vector; } + + @Override + public TrainingDataConsumer getTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) { + return new BinaryTrainingDataConsumer(trainingDataAllocation); + } + + @Override + public void freeNativeMemory(long memoryAddress) { + JNICommons.freeBinaryVectorData(memoryAddress); + } }, BYTE("byte") { @@ -67,6 +83,16 @@ public float[] getVectorFromBytesRef(BytesRef binaryValue) { } return vector; } + + @Override + public TrainingDataConsumer getTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) { + return new ByteTrainingDataConsumer(trainingDataAllocation); + } + + @Override + public void freeNativeMemory(long memoryAddress) { + JNICommons.freeByteVectorData(memoryAddress); + } }, FLOAT("float") { @@ -81,6 +107,16 @@ public float[] getVectorFromBytesRef(BytesRef binaryValue) { return vectorSerializer.byteToFloatArray(binaryValue); } + @Override + public TrainingDataConsumer getTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) { + return new FloatTrainingDataConsumer(trainingDataAllocation); + } + + @Override + public void freeNativeMemory(long memoryAddress) { + JNICommons.freeVectorData(memoryAddress); + } + }; public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values()) @@ -107,6 +143,17 @@ public float[] getVectorFromBytesRef(BytesRef binaryValue) { */ public abstract float[] getVectorFromBytesRef(BytesRef binaryValue); + /** + * @param trainingDataAllocation training data that has been allocated in native memory + * @return TrainingDataConsumer which consumes training data + */ + public abstract TrainingDataConsumer getTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation); + + /** + * @param memoryAddress address to be freed + */ + public abstract void freeNativeMemory(long memoryAddress); + /** * Validates if given VectorDataType is in the list of supported data types. * @param vectorDataType VectorDataType diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java index 7c3860223..ffa12a231 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java @@ -28,14 +28,11 @@ public void deallocate() { @Override protected long transfer(List batch, boolean append) throws IOException { - if (!batch.isEmpty()) { - return JNICommons.storeBinaryVectorData( - getVectorAddress(), - batch.toArray(new byte[][] {}), - (long) batch.get(0).length * transferLimit, - append - ); - } - return 0; + return JNICommons.storeBinaryVectorData( + getVectorAddress(), + batch.toArray(new byte[][] {}), + (long) batch.get(0).length * transferLimit, + append + ); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java index 91dd36ba5..83ebf2fa3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java @@ -23,15 +23,12 @@ public OffHeapByteVectorTransfer(int transferLimit) { @Override protected long transfer(List batch, boolean append) throws IOException { - if (!batch.isEmpty()) { - return JNICommons.storeByteVectorData( - getVectorAddress(), - batch.toArray(new byte[][] {}), - (long) batch.get(0).length * transferLimit, - append - ); - } - return 0; + return JNICommons.storeByteVectorData( + getVectorAddress(), + batch.toArray(new byte[][] {}), + (long) batch.get(0).length * transferLimit, + append + ); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java index 99d8c99a5..0eb28d791 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java @@ -21,15 +21,12 @@ public OffHeapFloatVectorTransfer(int transferLimit) { @Override protected long transfer(final List vectorsToTransfer, boolean append) throws IOException { - if (!vectorsToTransfer.isEmpty()) { - return JNICommons.storeVectorData( - getVectorAddress(), - vectorsToTransfer.toArray(new float[][] {}), - (long) vectorsToTransfer.get(0).length * this.transferLimit, - append - ); - } - return 0; + return JNICommons.storeVectorData( + getVectorAddress(), + vectorsToTransfer.toArray(new float[][] {}), + (long) vectorsToTransfer.get(0).length * this.transferLimit, + append + ); } @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index b3dd12c92..70ab4222b 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -38,7 +38,11 @@ */ public class FaissIVFMethod extends AbstractFaissMethod { - private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BINARY); + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of( + VectorDataType.FLOAT, + VectorDataType.BINARY, + VectorDataType.BYTE + ); public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 755b6b925..c711f3342 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -13,10 +13,8 @@ import lombok.Getter; import org.apache.lucene.index.LeafReaderContext; -import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.KNNWeight; -import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.watcher.FileWatcher; @@ -299,11 +297,7 @@ private void cleanup() { closed = true; if (this.memoryAddress != 0) { - if (IndexUtil.isBinaryIndex(vectorDataType)) { - JNICommons.freeBinaryVectorData(this.memoryAddress); - } else { - JNICommons.freeVectorData(this.memoryAddress); - } + vectorDataType.freeNativeMemory(this.memoryAddress); } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java index 6723c2ed0..8324f2340 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java @@ -14,11 +14,8 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.core.action.ActionListener; import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.training.ByteTrainingDataConsumer; -import org.opensearch.knn.training.FloatTrainingDataConsumer; import org.opensearch.knn.training.TrainingDataConsumer; import org.opensearch.knn.training.VectorReader; import org.opensearch.watcher.FileChangesListener; @@ -174,9 +171,8 @@ public NativeMemoryAllocation.TrainingDataAllocation load( nativeMemoryEntryContext.getVectorDataType() ); - TrainingDataConsumer vectorDataConsumer = nativeMemoryEntryContext.getVectorDataType() == VectorDataType.FLOAT - ? new FloatTrainingDataConsumer(trainingDataAllocation) - : new ByteTrainingDataConsumer(trainingDataAllocation); + TrainingDataConsumer vectorDataConsumer = nativeMemoryEntryContext.getVectorDataType() + .getTrainingDataConsumer(trainingDataAllocation); trainingDataAllocation.writeLock(); diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java index 431579fae..e808c4963 100644 --- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java @@ -31,6 +31,7 @@ import java.util.HashMap; import java.util.Locale; import java.util.Map; +import java.util.Set; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; @@ -52,6 +53,7 @@ public class IndexUtil { private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0; // public so neural search can access it public static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); + public static final Set VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS = Set.of(VectorDataType.BINARY, VectorDataType.BYTE); /** * Determines the size of a file on disk in kilobytes @@ -145,16 +147,6 @@ public static ValidationException validateKnnField( } if (trainRequestVectorDataType != null) { - if (VectorDataType.BYTE == trainRequestVectorDataType) { - exception.addValidationError( - String.format( - Locale.ROOT, - "vector data type \"%s\" is not supported for training.", - trainRequestVectorDataType.getValue() - ) - ); - return exception; - } VectorDataType trainIndexDataType = getVectorDataTypeFromFieldMapping(fieldMap); if (trainIndexDataType != trainRequestVectorDataType) { @@ -170,20 +162,18 @@ public static ValidationException validateKnnField( return exception; } - // Block binary vector data type for pq encoder + // Block binary and byte vector data type for any encoder if (trainRequestKnnMethodContext != null) { MethodComponentContext methodComponentContext = trainRequestKnnMethodContext.getMethodComponentContext(); Map parameters = methodComponentContext.getParameters(); if (parameters != null && parameters.containsKey(KNNConstants.METHOD_ENCODER_PARAMETER)) { MethodComponentContext encoder = (MethodComponentContext) parameters.get(KNNConstants.METHOD_ENCODER_PARAMETER); - if (encoder != null - && KNNConstants.ENCODER_PQ.equals(encoder.getName()) - && VectorDataType.BINARY == trainRequestVectorDataType) { + if (encoder != null && VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS.contains(trainRequestVectorDataType)) { exception.addValidationError( String.format( Locale.ROOT, - "vector data type \"%s\" is not supported for pq encoder.", + "encoder is not supported for vector data type [%s]", trainRequestVectorDataType.getValue() ) ); @@ -325,16 +315,6 @@ public static boolean isBinaryIndex(KNNEngine knnEngine, Map par && parameters.get(VECTOR_DATA_TYPE_FIELD).toString().equals(VectorDataType.BINARY.getValue()); } - /** - * Tell if it is binary index or not - * - * @param vectorDataType vector data type - * @return true if it is binary index - */ - public static boolean isBinaryIndex(VectorDataType vectorDataType) { - return VectorDataType.BINARY == vectorDataType; - } - /** * Update vector data type into parameters * @@ -345,6 +325,9 @@ public static void updateVectorDataTypeToParameters(Map paramete if (VectorDataType.BINARY == vectorDataType) { parameters.put(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); } + if (VectorDataType.BYTE == vectorDataType) { + parameters.put(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + } } /** diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 26c703eeb..037171b98 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -186,6 +186,25 @@ public static native void createBinaryIndexFromTemplate( Map parameters ); + /** + * Create a byte index for the native library with a provided template index + * + * @param ids array of ids mapping to the data passed in + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters additional build time parameters + */ + public static native void createByteIndexFromTemplate( + int[] ids, + long vectorsAddress, + int dim, + String indexPath, + byte[] templateIndex, + Map parameters + ); + /** * Load an index into memory * @@ -349,6 +368,16 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( */ public static native byte[] trainBinaryIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** + * Train an empty byte index + * + * @param indexParameters parameters used to build index + * @param dimension dimension for the index + * @param trainVectorsPointer pointer to where training vectors are stored in native memory + * @return bytes array of trained template index + */ + public static native byte[] trainByteIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** *

* The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java index df1024db4..df3e551cd 100644 --- a/src/main/java/org/opensearch/knn/jni/JNICommons.java +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -178,12 +178,12 @@ public static long storeByteVectorData(long memoryAddress, byte[][] data, long i public static native void freeBinaryVectorData(long memoryAddress); /** - * Free up the memory allocated for the binary data stored in memory address. This function should be used with the memory - * address returned by {@link JNICommons#storeBinaryVectorData(long, byte[][], long)} + * Free up the memory allocated for the byte data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long)} * *

- * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can - * lead to errors. + * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. *

* * @param memoryAddress address to be freed. diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 1177d635e..94c1ec48e 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -169,10 +169,15 @@ public static void createIndexFromTemplate( if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); return; - } else { - FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + } + if (IndexUtil.isByteIndex(parameters)) { + FaissService.createByteIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); return; } + + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } throw new IllegalArgumentException( @@ -405,6 +410,9 @@ public static byte[] trainIndex(Map indexParameters, int dimensi if (IndexUtil.isBinaryIndex(knnEngine, indexParameters)) { return FaissService.trainBinaryIndex(indexParameters, dimension, trainVectorsPointer); } + if (IndexUtil.isByteIndex(indexParameters)) { + return FaissService.trainByteIndex(indexParameters, dimension, trainVectorsPointer); + } return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); } diff --git a/src/main/java/org/opensearch/knn/training/BinaryTrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/BinaryTrainingDataConsumer.java new file mode 100644 index 000000000..db953b1c5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/training/BinaryTrainingDataConsumer.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.training; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.knn.jni.JNICommons; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.search.SearchHit; + +import java.util.ArrayList; +import java.util.List; + +/** + * Transfers binary vectors from JVM to native memory. + */ +public class BinaryTrainingDataConsumer extends TrainingDataConsumer { + private static final Logger logger = LogManager.getLogger(TrainingDataConsumer.class); + + /** + * Constructor + * + * @param trainingDataAllocation NativeMemoryAllocation that contains information about native memory allocation. + */ + public BinaryTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) { + super(trainingDataAllocation); + } + + @Override + public void accept(List byteVectors) { + long memoryAddress = trainingDataAllocation.getMemoryAddress(); + memoryAddress = JNICommons.storeBinaryVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size()); + trainingDataAllocation.setMemoryAddress(memoryAddress); + } + + @Override + public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) { + SearchHit[] hits = searchResponse.getHits().getHits(); + List vectors = new ArrayList<>(); + String[] fieldPath = fieldName.split("\\."); + int nullVectorCount = 0; + + for (int vector = 0; vector < vectorsToAdd; vector++) { + Object fieldValue = extractFieldValue(hits[vector], fieldPath); + if (fieldValue == null) { + nullVectorCount++; + continue; + } + + byte[] byteArray; + if (!(fieldValue instanceof List)) { + continue; + } + List fieldList = (List) fieldValue; + byteArray = new byte[fieldList.size()]; + for (int i = 0; i < fieldList.size(); i++) { + byteArray[i] = fieldList.get(i).byteValue(); + } + + vectors.add(byteArray); + } + + if (nullVectorCount > 0) { + logger.warn("Found {} documents with null byte vectors in field {}", nullVectorCount, fieldName); + } + + setTotalVectorsCountAdded(getTotalVectorsCountAdded() + vectors.size()); + + accept(vectors); + } +} diff --git a/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java index e838b5214..c51a96533 100644 --- a/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java +++ b/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java @@ -11,11 +11,9 @@ package org.opensearch.knn.training; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.search.SearchResponse; -import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.search.SearchHit; import java.util.ArrayList; @@ -25,7 +23,6 @@ * Transfers byte vectors from JVM to native memory. */ public class ByteTrainingDataConsumer extends TrainingDataConsumer { - private static final Logger logger = LogManager.getLogger(TrainingDataConsumer.class); /** * Constructor @@ -39,7 +36,7 @@ public ByteTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation tr @Override public void accept(List byteVectors) { long memoryAddress = trainingDataAllocation.getMemoryAddress(); - memoryAddress = JNICommons.storeBinaryVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size()); + memoryAddress = JNICommons.storeByteVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size()); trainingDataAllocation.setMemoryAddress(memoryAddress); } @@ -48,14 +45,9 @@ public void processTrainingVectors(SearchResponse searchResponse, int vectorsToA SearchHit[] hits = searchResponse.getHits().getHits(); List vectors = new ArrayList<>(); String[] fieldPath = fieldName.split("\\."); - int nullVectorCount = 0; for (int vector = 0; vector < vectorsToAdd; vector++) { Object fieldValue = extractFieldValue(hits[vector], fieldPath); - if (fieldValue == null) { - nullVectorCount++; - continue; - } byte[] byteArray; if (!(fieldValue instanceof List)) { @@ -70,10 +62,6 @@ public void processTrainingVectors(SearchResponse searchResponse, int vectorsToA vectors.add(byteArray); } - if (nullVectorCount > 0) { - logger.warn("Found {} documents with null byte vectors in field {}", nullVectorCount, fieldName); - } - setTotalVectorsCountAdded(getTotalVectorsCountAdded() + vectors.size()); accept(vectors); diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index fe621d7d4..959d94e3e 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -33,12 +33,23 @@ import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -591,6 +602,93 @@ public void testDocValuesWithByteVectorDataTypeFaissEngine() throws Exception { validateL2SearchResults(response); } + @SneakyThrows + public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() { + + String modelId = "test-model-ivf-byte"; + int dimension = 2; + + // Add training data + String trainIndexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .field("data_type", VectorDataType.BYTE.getValue()) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject() + .toString(); + createKnnIndex(INDEX_NAME, trainIndexMapping); + + int trainingDataCount = 100; + bulkIngestRandomByteVectors(INDEX_NAME, FIELD_NAME, trainingDataCount, dimension); + + XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, INDEX_NAME) + .field(TRAIN_FIELD_PARAMETER, FIELD_NAME) + .field(DIMENSION, dimension) + .field(MODEL_DESCRIPTION, "My model description") + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue()) + .field( + KNN_METHOD, + Map.of( + NAME, + METHOD_IVF, + KNN_ENGINE, + FAISS_NAME, + METHOD_PARAMETER_SPACE_TYPE, + SpaceType.L2.getValue(), + PARAMETERS, + Map.of(METHOD_PARAMETER_NLIST, 4, METHOD_PARAMETER_NPROBES, 4) + ) + ) + .endObject(); + + trainModel(modelId, trainModelXContentBuilder); + + // Make sure training succeeds after 30 seconds + assertTrainingSucceeds(modelId, 30, 1000); + + // Create knn index from model + String indexName = "test-index-name-ivf-byte"; + String indexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping); + + Byte[] b1 = { 6, 6 }; + addKnnDoc(indexName, "1", FIELD_NAME, b1); + Byte[] b2 = { 2, 2 }; + addKnnDoc(indexName, "2", FIELD_NAME, b2); + Byte[] b3 = { 4, 4 }; + addKnnDoc(indexName, "3", FIELD_NAME, b3); + Byte[] b4 = { 3, 3 }; + addKnnDoc(indexName, "4", FIELD_NAME, b4); + + Byte[] queryVector = { 1, 1 }; + Response response = searchKNNIndex(indexName, new KNNQueryBuilder(FIELD_NAME, convertByteToFloatArray(queryVector), 4), 4); + + validateL2SearchResults(response); + deleteKNNIndex(indexName); + Thread.sleep(45 * 1000); + deleteModel(modelId); + } + @SneakyThrows private void ingestL2ByteTestData() { Byte[] b1 = { 6, 6 }; diff --git a/src/test/java/org/opensearch/knn/index/util/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/util/IndexUtilTests.java index f2e85b1ad..867028dd8 100644 --- a/src/test/java/org/opensearch/knn/index/util/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/util/IndexUtilTests.java @@ -29,6 +29,7 @@ import java.util.Collections; import java.util.HashMap; +import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -39,6 +40,7 @@ import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; @@ -288,34 +290,33 @@ public void testValidateKnnField_whenTrainModelUseDifferentVectorDataTypeFromTra ); } - public void testValidateKnnField_whenPassByteVectorDataType_thenThrowException() { - Map fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "byte"); - Map top_level_field = Map.of("top_level_field", fieldValues); - Map properties = Map.of("properties", top_level_field); - String field = "top_level_field"; - int dimension = 8; - - MappingMetadata mappingMetadata = mock(MappingMetadata.class); - when(mappingMetadata.getSourceAsMap()).thenReturn(properties); - IndexMetadata indexMetadata = mock(IndexMetadata.class); - when(indexMetadata.mapping()).thenReturn(mappingMetadata); - ModelDao modelDao = mock(ModelDao.class); - - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BYTE, null); - - assert Objects.requireNonNull(e) - .getMessage() - .matches("Validation Failed: 1: vector data type \"" + VectorDataType.BYTE.getValue() + "\" is not supported for training.;"); - } - public void testUpdateVectorDataTypeToParameters_whenVectorDataTypeIsBinary() { Map indexParams = new HashMap<>(); IndexUtil.updateVectorDataTypeToParameters(indexParams, VectorDataType.BINARY); assertEquals(VectorDataType.BINARY.getValue(), indexParams.get(VECTOR_DATA_TYPE_FIELD)); } - public void testValidateKnnField_whenPassBinaryVectorDataTypeAndPQEncoder_thenThrowException() { - Map fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "binary", "encoder", "pq"); + public void testValidateKnnField_whenPassBinaryVectorDataTypeAndEncoder_thenThrowException() { + validateKnnField_whenPassVectorDataTypeAndEncoder_thenThrowException(ENCODER_SQ, VectorDataType.BINARY); + validateKnnField_whenPassVectorDataTypeAndEncoder_thenThrowException(ENCODER_PQ, VectorDataType.BINARY); + } + + public void testValidateKnnField_whenPassByteVectorDataTypeAndEncoder_thenThrowException() { + validateKnnField_whenPassVectorDataTypeAndEncoder_thenThrowException(ENCODER_SQ, VectorDataType.BYTE); + validateKnnField_whenPassVectorDataTypeAndEncoder_thenThrowException(ENCODER_PQ, VectorDataType.BYTE); + } + + public void validateKnnField_whenPassVectorDataTypeAndEncoder_thenThrowException(String encoder, VectorDataType vectorDataType) { + Map fieldValues = Map.of( + "type", + "knn_vector", + "dimension", + 8, + "data_type", + vectorDataType.getValue(), + "encoder", + encoder + ); Map top_level_field = Map.of("top_level_field", fieldValues); Map properties = Map.of("properties", top_level_field); String field = "top_level_field"; @@ -326,24 +327,19 @@ public void testValidateKnnField_whenPassBinaryVectorDataTypeAndPQEncoder_thenTh IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.mapping()).thenReturn(mappingMetadata); ModelDao modelDao = mock(ModelDao.class); - MethodComponentContext pq = new MethodComponentContext(ENCODER_PQ, Collections.emptyMap()); KNNMethodContext knnMethodContext = new KNNMethodContext( KNNEngine.FAISS, SpaceType.INNER_PRODUCT, - new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_ENCODER_PARAMETER, pq)) + new MethodComponentContext( + METHOD_IVF, + ImmutableMap.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(encoder, Collections.emptyMap())) + ) ); - ValidationException e = IndexUtil.validateKnnField( - indexMetadata, - field, - dimension, - modelDao, - VectorDataType.BINARY, - knnMethodContext - ); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, vectorDataType, knnMethodContext); assert Objects.requireNonNull(e) .getMessage() - .matches("Validation Failed: 1: vector data type \"binary\" is not supported for pq encoder.;"); + .contains(String.format(Locale.ROOT, "encoder is not supported for vector data type [%s]", vectorDataType.getValue())); } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 22389ccdc..fb974b6e1 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -5,6 +5,7 @@ package org.opensearch.knn; +import com.google.common.primitives.Bytes; import com.google.common.primitives.Floats; import com.google.common.primitives.Ints; import lombok.SneakyThrows; @@ -1167,6 +1168,16 @@ public void bulkIngestRandomBinaryVectors(String indexName, String fieldName, in } } + public void bulkIngestRandomByteVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException { + for (int i = 0; i < numVectors; i++) { + byte[] vector = new byte[dimension]; + for (int j = 0; j < dimension; j++) { + vector[j] = randomByte(); + } + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Bytes.asList(vector).toArray()); + } + } + /** * Bulk ingest random vectors with nested field *