diff --git a/server/src/test/java/org/opensearch/index/ShardIndexingPressureConcurrentExecutionTests.java b/server/src/test/java/org/opensearch/index/ShardIndexingPressureConcurrentExecutionTests.java index faab2f405010a..8757458e3317e 100644 --- a/server/src/test/java/org/opensearch/index/ShardIndexingPressureConcurrentExecutionTests.java +++ b/server/src/test/java/org/opensearch/index/ShardIndexingPressureConcurrentExecutionTests.java @@ -269,7 +269,13 @@ public void testCoordinatingPrimaryThreadedUpdateToShardLimitsAndRejections() th nodeStats = shardIndexingPressure.stats(); IndexingPressurePerShardStats shardStoreStats = shardIndexingPressure.shardStats().getIndexingPressureShardStats(shardId1); - assertNull(shardStoreStats); + // If rejection count equals NUM_THREADS that means rejections happened until the last request, then we'll get shardStoreStats which + // was updated on the last request. In other cases, the shardStoreStats simply moves to the cold store and null is returned. + if (rejectionCount.get() == NUM_THREADS) { + assertEquals(10, shardStoreStats.getCurrentPrimaryAndCoordinatingLimits()); + } else { + assertNull(shardStoreStats); + } shardStats = shardIndexingPressure.coldStats(); if (randomBoolean) { assertEquals(rejectionCount.get(), nodeStats.getCoordinatingRejections()); @@ -331,7 +337,13 @@ public void testReplicaThreadedUpdateToShardLimitsAndRejections() throws Excepti assertEquals(0, nodeStats.getCurrentReplicaBytes()); IndexingPressurePerShardStats shardStoreStats = shardIndexingPressure.shardStats().getIndexingPressureShardStats(shardId1); - assertNull(shardStoreStats); + // If rejection count equals NUM_THREADS that means rejections happened until the last request, then we'll get shardStoreStats which + // was updated on the last request. In other cases, the shardStoreStats simply moves to the cold store and null is returned. + if (rejectionCount.get() == NUM_THREADS) { + assertEquals(15, shardStoreStats.getCurrentReplicaLimits()); + } else { + assertNull(shardStoreStats); + } shardStats = shardIndexingPressure.coldStats(); assertEquals(rejectionCount.get(), shardStats.getIndexingPressureShardStats(shardId1).getReplicaNodeLimitsBreachedRejections());