diff --git a/CHANGELOG.md b/CHANGELOG.md index 5615509de..6871074f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Optimize reduceToTopK in ResultUtil by removing pre-filling and reducing peek calls [#2146](https://github.com/opensearch-project/k-NN/pull/2146) * Update Default Rescore Context based on Dimension [#2149](https://github.com/opensearch-project/k-NN/pull/2149) * KNNIterators should support with and without filters [#2155](https://github.com/opensearch-project/k-NN/pull/2155) +* Adding Support to Enable/Disable Shard level Rescoring and Update Oversampling Factor [#2172](https://github.com/opensearch-project/k-NN/pull/2172) ### Bug Fixes * KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147) ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 5fcc51bb5..1753140e6 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -88,6 +88,7 @@ public class KNNSettings { public static final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit"; public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes"; public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled"; + public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; /** * Default setting values @@ -112,11 +113,31 @@ public class KNNSettings { public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed // 10% of the JVM heap public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; 
+ public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = true; /** * Settings Definition */ + /** + * This setting controls whether shard-level re-scoring for KNN disk-based vectors is turned off. + * It is registered with the {@code IndexScope} and {@code Dynamic} properties and defaults to + * {@code KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE} ({@code true}). + * + * @see Setting#boolSetting(String, boolean, Setting.Property...) + */ + public static final Setting KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING = Setting.boolSetting( + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE, + IndexScope, + Dynamic + ); + // This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default // 1% of the JVM heap public static final Setting KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting( @@ -454,6 +475,10 @@ private Setting getSetting(String key) { return QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING; } + if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) { + return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -475,7 +500,8 @@ public List> getSettings() { KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING, KNN_FAISS_AVX512_DISABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, - QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING + QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -528,6 +554,14 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } + public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) { + return 
KNNSettings.state().clusterService.state() + .getMetadata() + .index(indexName) + .getSettings() + .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE); + } + public void initialize(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index 3e1b47db7..c9a169efc 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -97,32 +97,35 @@ public static boolean isConfigured(CompressionLevel compressionLevel) { /** * Returns the appropriate {@link RescoreContext} based on the given {@code mode} and {@code dimension}. * - *

If the {@code mode} is present in the valid {@code modesForRescore} set, the method checks the value of - * {@code dimension}: + *

If the {@code mode} is present in the valid {@code modesForRescore} set, the method adjusts the oversample factor based on the + * {@code dimension} value: *

    - *
  • If {@code dimension} is less than or equal to 1000, it returns a {@link RescoreContext} with an - * oversample factor of 5.0f.
  • - *
  • If {@code dimension} is greater than 1000, it returns the default {@link RescoreContext} associated with - * the {@link CompressionLevel}. If no default is set, it falls back to {@link RescoreContext#getDefault()}.
  • + *
  • If {@code dimension} is greater than or equal to 1000, no oversampling is applied (oversample factor = 1.0).
  • + *
  • If {@code dimension} is greater than or equal to 768 but less than 1000, a 2x oversample factor is applied (oversample factor = 2.0).
  • + *
  • If {@code dimension} is less than 768, a 3x oversample factor is applied (oversample factor = 3.0).
  • *
- * If the {@code mode} is not valid, the method returns {@code null}. + * If the {@code mode} is not present in the {@code modesForRescore} set, the method returns {@code null}. * * @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}. * @param dimension The dimensional value that determines the {@link RescoreContext} behavior. - * @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than - * or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode - * is invalid. + * @return A {@link RescoreContext} with the appropriate oversample factor based on the dimension, or {@code null} if the mode + * is not valid. */ public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) { if (modesForRescore.contains(mode)) { // Adjust RescoreContext based on dimension - if (dimension <= RescoreContext.DIMENSION_THRESHOLD) { - // For dimensions <= 1000, return a RescoreContext with 5.0f oversample factor - return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD).build(); + if (dimension >= RescoreContext.DIMENSION_THRESHOLD_1000) { + // No oversampling for dimensions >= 1000 + return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_1000).build(); + } else if (dimension >= RescoreContext.DIMENSION_THRESHOLD_768) { + // 2x oversampling for dimensions >= 768 but < 1000 + return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_768).build(); } else { - return defaultRescoreContext; + // 3x oversampling for dimensions < 768 + return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_768).build(); } } return null; } + } diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 945da850a..adb2875d5 100644 --- 
a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -20,6 +20,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.opensearch.common.StopWatch; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.query.ExactSearcher; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; @@ -54,7 +55,6 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo final IndexReader reader = indexSearcher.getIndexReader(); final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); List leafReaderContexts = reader.leaves(); - List> perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); int finalK = knnQuery.getK(); @@ -63,7 +63,9 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo } else { int firstPassK = rescoreContext.getFirstPassK(finalK); perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); - ResultUtil.reduceToTopK(perLeafResults, firstPassK); + if (KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()) == false) { + ResultUtil.reduceToTopK(perLeafResults, firstPassK); + } StopWatch stopWatch = new StopWatch().start(); perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK); diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index 51d4e491c..a2563b2a6 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -24,6 +24,15 @@ public final class RescoreContext { public static final int DIMENSION_THRESHOLD = 1000; 
public static final float OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD = 5.0f; + // Dimension thresholds for adjusting oversample factor + public static final int DIMENSION_THRESHOLD_1000 = 1000; + public static final int DIMENSION_THRESHOLD_768 = 768; + + // Oversample factors based on dimension thresholds + public static final float OVERSAMPLE_FACTOR_1000 = 1.0f; // No oversampling for dimensions >= 1000 + public static final float OVERSAMPLE_FACTOR_768 = 2.0f; // 2x oversampling for dimensions >= 768 and < 1000 + public static final float OVERSAMPLE_FACTOR_BELOW_768 = 3.0f; // 3x oversampling for dimensions < 768 + // Todo:- We will improve this in upcoming releases public static final int MIN_FIRST_PASS_RESULTS = 100; diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index 75eb14713..fd25699cc 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -158,6 +158,41 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() { assertEquals(userProvidedEfSearch, efSearchValue); } + @SneakyThrows + public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { + Node mockNode = createMockNode(Collections.emptyMap()); + mockNode.start(); + ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); + mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet(); + mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); + KNNSettings.state().setClusterService(clusterService); + + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); + mockNode.close(); + assertTrue(shardLevelRescoringDisabled); + } + + @SneakyThrows + public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingApplied() { + 
boolean userDefinedRescoringDisabled = false; + Node mockNode = createMockNode(Collections.emptyMap()); + mockNode.start(); + ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); + mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet(); + mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); + KNNSettings.state().setClusterService(clusterService); + + final Settings rescoringDisabledSetting = Settings.builder() + .put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, userDefinedRescoringDisabled) + .build(); + + mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet(); + + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); + mockNode.close(); + assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled); + } + @SneakyThrows public void testGetFaissAVX2DisabledSettingValueFromConfig_enableSetting_thenValidateAndSucceed() { boolean expectedKNNFaissAVX2Disabled = true; diff --git a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java index cc70d4c2d..57372b11e 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java @@ -44,65 +44,84 @@ public void testIsConfigured() { public void testGetDefaultRescoreContext() { // Test rescore context for ON_DISK mode Mode mode = Mode.ON_DISK; - int belowThresholdDimension = 500; // A dimension below the threshold - int aboveThresholdDimension = 1500; // A dimension above the threshold - // x32 with dimension <= 1000 should have an oversample factor of 5.0f + // Test various dimensions based on the updated oversampling logic + int belowThresholdDimension = 500; // A dimension below 768 + 
int between768and1000Dimension = 800; // A dimension between 768 and 1000 + int above1000Dimension = 1500; // A dimension above 1000 + + // Compression level x32 with dimension < 768 should have an oversample factor of 3.0f RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x32 with dimension > 1000 should have an oversample factor of 3.0f - rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x32 with dimension between 768 and 1000 should have an oversample factor of 2.0f + rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, between768and1000Dimension); assertNotNull(rescoreContext); - assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x16 with dimension <= 1000 should have an oversample factor of 5.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); + // Compression level x32 with dimension > 1000 should have no oversampling (1.0f) + rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, above1000Dimension); assertNotNull(rescoreContext); - assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x16 with dimension > 1000 should have an oversample factor of 3.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x16 with dimension < 768 should have an oversample factor of 3.0f + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x8 
with dimension <= 1000 should have an oversample factor of 5.0f + // Compression level x16 with dimension between 768 and 1000 should have an oversample factor of 2.0f + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, between768and1000Dimension); + assertNotNull(rescoreContext); + assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); + + // Compression level x16 with dimension > 1000 should have no oversampling (1.0f) + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, above1000Dimension); + assertNotNull(rescoreContext); + assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); + + // Compression level x8 with dimension < 768 should have an oversample factor of 3.0f rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x8 with dimension > 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x8 with dimension between 768 and 1000 should have an oversample factor of 2.0f + rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, between768and1000Dimension); assertNotNull(rescoreContext); assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext) + // Compression level x8 with dimension > 1000 should have no oversampling (1.0f) + rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, above1000Dimension); + assertNotNull(rescoreContext); + assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); + + // Compression level x4 with dimension < 768 should return null (no RescoreContext) rescoreContext = 
CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - // x4 with dimension > 1000 should return null (no RescoreContext is configured for x4) - rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension); - assertNull(rescoreContext); - // Other compression levels should behave similarly with respect to dimension + // Compression level x4 with dimension > 1000 should return null (no RescoreContext) + rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, above1000Dimension); + assertNull(rescoreContext); + // Compression level x2 with dimension < 768 should return null rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - // x2 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x2 with dimension > 1000 should return null + rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, above1000Dimension); assertNull(rescoreContext); + // Compression level x1 with dimension < 768 should return null rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - // x1 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension); + // Compression level x1 with dimension > 1000 should return null + rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, above1000Dimension); assertNull(rescoreContext); - // NOT_CONFIGURED with dimension <= 1000 should return a RescoreContext with an oversample factor of 5.0f + // NOT_CONFIGURED mode should return null for any dimension rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - } + } diff --git 
a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 06350f39c..7fd96c6df 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -17,11 +17,16 @@ import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.Bits; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.invocation.InvocationOnMock; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.ResultUtil; @@ -35,12 +40,11 @@ import java.util.Map; import java.util.concurrent.Callable; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.times; import static org.mockito.MockitoAnnotations.openMocks; public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @@ -66,6 +70,9 @@ public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @Mock private LeafReader leafReader2; + @Mock + private ClusterService clusterService; + @InjectMocks private NativeEngineKnnVectorQuery objectUnderTest; @@ -91,6 +98,11 @@ public void setUp() throws Exception { 
}); when(reader.getContext()).thenReturn(indexReaderContext); + + when(clusterService.state()).thenReturn(mock(ClusterState.class)); // Mock ClusterState + + // Set ClusterService in KNNSettings + KNNSettings.state().setClusterService(clusterService); } @SneakyThrows @@ -127,6 +139,49 @@ public void testMultiLeaf() { assertEquals(expected, actual.getQuery()); } + @SneakyThrows + public void testRescoreWhenShardLevelRescoringEnabled() { + // Given + List leaves = List.of(leaf1, leaf2); + when(reader.leaves()).thenReturn(leaves); + + int k = 2; + int firstPassK = 3; + Map initialLeaf1Results = new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f)); + Map initialLeaf2Results = new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f)); + Map rescoredLeaf1Results = new HashMap<>(Map.of(0, 18f, 1, 20f)); + Map rescoredLeaf2Results = new HashMap<>(Map.of(0, 21f)); + + when(knnQuery.getRescoreContext()).thenReturn(RescoreContext.builder().oversampleFactor(1.5f).build()); + when(knnQuery.getK()).thenReturn(k); + when(knnWeight.getQuery()).thenReturn(knnQuery); + when(knnWeight.searchLeaf(leaf1, firstPassK)).thenReturn(initialLeaf1Results); + when(knnWeight.searchLeaf(leaf2, firstPassK)).thenReturn(initialLeaf2Results); + when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); + when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); + + try ( + MockedStatic mockedKnnSettings = mockStatic(KNNSettings.class); + MockedStatic mockedResultUtil = mockStatic(ResultUtil.class) + ) { + + // When shard-level re-scoring is enabled + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); + + // Mock ResultUtil to return valid TopDocs + mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt())) + .thenReturn(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0])); + mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenCallRealMethod(); + + // 
When + Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1); + + // Then + mockedResultUtil.verify(() -> ResultUtil.reduceToTopK(any(), anyInt()), times(2)); + assertNotNull(actual); + } + } + @SneakyThrows public void testSingleLeaf() { // Given @@ -188,7 +243,15 @@ public void testRescore() { when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); - try (MockedStatic mockedResultUtil = mockStatic(ResultUtil.class)) { + + try ( + MockedStatic mockedKnnSettings = mockStatic(KNNSettings.class); + MockedStatic mockedResultUtil = mockStatic(ResultUtil.class) + ) { + + // When shard-level re-scoring is disabled + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(true); + mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf1Results), anyInt())).thenAnswer(t -> topDocs1); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf2Results), anyInt())).thenAnswer(t -> topDocs2);