From 739667f710b273861684ed03364227578c6c9604 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Tue, 19 Jul 2022 12:31:16 -0700 Subject: [PATCH] AD model performance benchmark This PR adds an HCAD model performance benchmark so that we can compare model performance across versions. For benchmark data, we randomly generated synthetic data with known anomalies inserted throughout the signal. In particular, these are one-, two-, and four-dimensional data sets where each dimension is a noisy cosine wave. Anomalies are inserted into one dimension with probability 0.003. Anomalies across dimensions can be either independent or correlated. We have approximately 5000 observations per data set. Each data set is generated with the same random seed so that results are comparable across versions. We also backported #600 so that we can capture the performance data in CI output. Testing done: * added unit tests to run the benchmark. Signed-off-by: Kaituo Li --- .github/workflows/benchmark.yml | 43 +++ build.gradle | 27 ++ .../ad/e2e/DetectionResultEvalutationIT.java | 1 + .../ad/ml/AbstractModelPerfTest.java | 247 +++++++++++++ .../ad/ml/EntityColdStarterTests.java | 227 +----------- .../opensearch/ad/ml/HCADModelPerfTests.java | 334 ++++++++++++++++++ .../ad/util/LabelledAnomalyGenerator.java | 40 ++- 7 files changed, 696 insertions(+), 223 deletions(-) create mode 100644 .github/workflows/benchmark.yml create mode 100644 src/test/java/org/opensearch/ad/ml/AbstractModelPerfTest.java create mode 100644 src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 000000000..3bcea30c0 --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,43 @@ +name: Run AD benchmark +on: + push: + branches: + - "*" + pull_request: + branches: + - "*" + +jobs: + Build-ad: + strategy: + matrix: + java: [8, 11, 14] + fail-fast: false + + name: Run Anomaly detection model performance benchmark + runs-on: ubuntu-latest + + steps: + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.java }} + + # anomaly-detection + - name: Checkout AD + uses: actions/checkout@v2 + + - name: Assemble anomaly-detection + run: | + ./gradlew assemble -Dopensearch.version=1.3.4-SNAPSHOT + echo "Creating ./src/test/resources/org/opensearch/ad/bwc/anomaly-detection/1.3.4.0-SNAPSHOT ..." + mkdir -p ./src/test/resources/org/opensearch/ad/bwc/anomaly-detection/1.3.4.0-SNAPSHOT + echo "Copying ./build/distributions/*.zip to ./src/test/resources/org/opensearch/ad/bwc/anomaly-detection/1.3.4.0-SNAPSHOT ..." + ls ./build/distributions/ + cp ./build/distributions/*.zip ./src/test/resources/org/opensearch/ad/bwc/anomaly-detection/1.3.4.0-SNAPSHOT + echo "Copied ./build/distributions/*.zip to ./src/test/resources/org/opensearch/ad/bwc/anomaly-detection/1.3.4.0-SNAPSHOT ..."
+ ls ./src/test/resources/org/opensearch/ad/bwc/anomaly-detection/1.3.4.0-SNAPSHOT + - name: Build and Run Tests + run: | + ./gradlew ':test' --tests "org.opensearch.ad.ml.HCADModelPerfTests" -Dtests.seed=2AEBDBBAE75AC5E0 -Dtests.security.manager=false -Dtests.locale=es-CU -Dtests.timezone=Chile/EasterIsland -Dtest.logs=true -Dhcad-benchmark=true + ./gradlew integTest --tests "org.opensearch.ad.e2e.DetectionResultEvalutationIT.testDataset" -Dtests.seed=60CDDB34427ACD0C -Dtests.security.manager=false -Dtests.locale=kab-DZ -Dtests.timezone=Asia/Hebron -Dtest.logs=true \ No newline at end of file diff --git a/build.gradle b/build.gradle index 6f8fc8ff8..092f62042 100644 --- a/build.gradle +++ b/build.gradle @@ -36,6 +36,10 @@ buildscript { 'opensearch-anomaly-detection-1.1.0.0.zip' bwcOpenSearchJSDownload = 'https://ci.opensearch.org/ci/dbc/bundle-build/1.1.0/20210930/linux/x64/builds/opensearch/plugins/' + 'opensearch-job-scheduler-1.1.0.0.zip' + // gradle build won't print logs during tests by default unless there is a failure. + // It is useful to record intermediate information like prediction precision and recall. + // This option turns on log printing during tests. + printLogs = "true" == System.getProperty("test.logs", "false") } repositories { @@ -175,6 +179,12 @@ test { } include '**/*Tests.class' systemProperty 'tests.security.manager', 'false' + + if (System.getProperty("hcad-benchmark") == null) { + filter { + excludeTestsMatching "org.opensearch.ad.ml.HCADModelPerfTests" + } + } } task integTest(type: RestIntegTestTask) { @@ -240,6 +250,12 @@ integTest { jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005' } + if (printLogs) { + testLogging { + showStandardStreams = true + outputs.upToDateWhen {false} + } + } } testClusters.integTest { @@ -670,6 +686,7 @@ dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.7.2' testRuntimeOnly 'org.junit.vintage:junit-vintage-engine:5.7.2' testCompileOnly 'junit:junit:4.13.2' + implementation group: 'org.javassist', name: 'javassist', version:'3.28.0-GA' } compileJava.options.compilerArgs << "-Xlint:-deprecation,-rawtypes,-serial,-try,-unchecked" @@ -775,3 +792,13 @@ task updateVersion { ant.replaceregexp(file:'build.gradle', match: '"opensearch.version", "\\d.*"', replace: '"opensearch.version", "' + newVersion.tokenize('-')[0] + '-SNAPSHOT"', flags:'g', byline:true) } } + +// Show test results so that we can record information like precision/recall results of correctness testing.
+if (printLogs) { + test { + testLogging { + showStandardStreams = true + outputs.upToDateWhen {false} + } + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java index ecfb2eee5..d123afa32 100644 --- a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java +++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java @@ -121,6 +121,7 @@ private void verifyTestResults( assertTrue(recall >= minRecall); assertTrue(errors <= maxError); + LOG.info("Precision: {}, Window recall: {}", precision, recall); } private int isAnomaly(Instant time, List> labels) { diff --git a/src/test/java/org/opensearch/ad/ml/AbstractModelPerfTest.java b/src/test/java/org/opensearch/ad/ml/AbstractModelPerfTest.java new file mode 100644 index 000000000..5bc0c0cab --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/AbstractModelPerfTest.java @@ -0,0 +1,247 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; + +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.dataprocessor.IntegerSensitiveSingleFeatureLinearUniformInterpolator; +import org.opensearch.ad.dataprocessor.Interpolator; +import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; +import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.ClusterServiceUtils; +import 
org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import com.google.common.collect.ImmutableList; + +public class AbstractModelPerfTest extends AbstractADTest { + protected int numMinSamples; + protected String modelId; + protected String entityName; + protected String detectorId; + protected ModelState modelState; + protected Clock clock; + protected float priority; + protected EntityColdStarter entityColdStarter; + protected NodeStateManager stateManager; + protected SearchFeatureDao searchFeatureDao; + protected Interpolator interpolator; + protected CheckpointDao checkpoint; + protected FeatureManager featureManager; + protected Settings settings; + protected ThreadPool threadPool; + protected AtomicBoolean released; + protected Runnable releaseSemaphore; + protected ActionListener listener; + protected CountDownLatch inProgressLatch; + protected CheckpointWriteWorker checkpointWriteQueue; + protected Entity entity; + protected AnomalyDetector detector; + protected long rcfSeed; + protected ClientUtil clientUtil; + protected ModelManager modelManager; + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; + + clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + settings = Settings.EMPTY; + + Client client = mock(Client.class); + clientUtil = mock(ClientUtil.class); + + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .build(); + when(clock.millis()).thenReturn(1602401500000L); + doAnswer(invocation -> { + GetRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + Set> nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); + nodestateSetting.add(BACKOFF_MINUTES); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + + DiscoveryNode discoveryNode = new DiscoveryNode( + "node1", + OpenSearchTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.BUILT_IN_ROLES, + Version.CURRENT + ); + + ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + + SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = + new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); + interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); + + searchFeatureDao = mock(SearchFeatureDao.class); + checkpoint = mock(CheckpointDao.class); + + featureManager = new FeatureManager( + searchFeatureDao, + interpolator, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + 
AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + ); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + rcfSeed = 2051L; + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + detectorId = "123"; + modelId = "123_entity_abc"; + entityName = "abc"; + priority = 0.3f; + entity = Entity.createSingleAttributeEntity("field", entityName); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + releaseSemaphore = () -> { + released.set(true); + inProgressLatch.countDown(); + }; + listener = ActionListener.wrap(releaseSemaphore); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class) + ); + } + + protected void checkSemaphoreRelease() throws InterruptedException { + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + assertTrue(released.get()); + } + + public int searchInsert(long[] timestamps, long target) { + int pivot, left = 0, right = timestamps.length - 1; + while (left <= right) { + pivot = left + (right - left) / 2; + if (timestamps[pivot] == target) + return pivot; + if (target < timestamps[pivot]) + right = pivot - 1; + else + left = pivot + 1; + } + return left; + } +} diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index 68882aeb8..d375b433d 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -14,68 +14,37 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.io.IOException; -import java.time.Clock; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; -import java.util.HashSet; import 
java.util.List; import java.util.Map.Entry; import java.util.Optional; import java.util.Queue; import java.util.Random; -import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.AbstractADTest; -import org.opensearch.ad.AnomalyDetectorPlugin; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.AnomalyDetectionException; -import org.opensearch.ad.dataprocessor.IntegerSensitiveSingleFeatureLinearUniformInterpolator; -import org.opensearch.ad.dataprocessor.Interpolator; -import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; -import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.node.DiscoveryNodeRole; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.collect.Tuple; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; -import org.opensearch.test.ClusterServiceUtils; -import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ThreadPool; import test.org.opensearch.ad.util.LabelledAnomalyGenerator; import test.org.opensearch.ad.util.MLUtil; @@ -86,174 +55,7 @@ import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; -public class EntityColdStarterTests extends AbstractADTest { - int numMinSamples; - String modelId; - String entityName; - String detectorId; - ModelState modelState; - Clock clock; - float priority; - EntityColdStarter entityColdStarter; - NodeStateManager stateManager; - SearchFeatureDao searchFeatureDao; - Interpolator interpolator; - CheckpointDao checkpoint; - FeatureManager featureManager; - Settings settings; - ThreadPool threadPool; - AtomicBoolean released; - Runnable releaseSemaphore; - ActionListener listener; - CountDownLatch inProgressLatch; - CheckpointWriteWorker checkpointWriteQueue; - Entity entity; - AnomalyDetector detector; - long rcfSeed; - ModelManager modelManager; - ClientUtil clientUtil; - - @SuppressWarnings("unchecked") - @Override - public void setUp() throws Exception { - super.setUp(); - numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; - - clock = mock(Clock.class); - when(clock.instant()).thenReturn(Instant.now()); - - threadPool = mock(ThreadPool.class); - setUpADThreadPool(threadPool); - - settings = Settings.EMPTY; - - Client client = mock(Client.class); - clientUtil = mock(ClientUtil.class); - - detector = 
TestHelpers.AnomalyDetectorBuilder - .newInstance() - .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) - .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) - .build(); - when(clock.millis()).thenReturn(1602401500000L); - doAnswer(invocation -> { - GetRequest request = invocation.getArgument(0); - ActionListener listener = invocation.getArgument(2); - - listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - - return null; - }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); - - Set> nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); - nodestateSetting.add(BACKOFF_MINUTES); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); - - DiscoveryNode discoveryNode = new DiscoveryNode( - "node1", - OpenSearchTestCase.buildNewFakeTransportAddress(), - Collections.emptyMap(), - DiscoveryNodeRole.BUILT_IN_ROLES, - Version.CURRENT - ); - - ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); - - stateManager = new NodeStateManager( - client, - xContentRegistry(), - settings, - clientUtil, - clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - clusterService - ); - - SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = - new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); - interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); - - searchFeatureDao = mock(SearchFeatureDao.class); - checkpoint = mock(CheckpointDao.class); - - featureManager = new FeatureManager( - searchFeatureDao, - interpolator, - clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, - AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, - AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, - AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, - AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME - ); - - checkpointWriteQueue = mock(CheckpointWriteWorker.class); - - rcfSeed = 2051L; - entityColdStarter = new EntityColdStarter( - clock, - threadPool, - stateManager, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.TIME_DECAY, - numMinSamples, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - interpolator, - searchFeatureDao, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, - featureManager, - settings, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - rcfSeed, - AnomalyDetectorSettings.MAX_COLD_START_ROUNDS - ); - - detectorId = "123"; - modelId = "123_entity_abc"; - entityName = "abc"; - priority = 0.3f; - entity = Entity.createSingleAttributeEntity("field", entityName); - - released = new AtomicBoolean(); - - inProgressLatch = new CountDownLatch(1); - releaseSemaphore = () -> { - released.set(true); - inProgressLatch.countDown(); - }; - listener = ActionListener.wrap(releaseSemaphore); - - modelManager = new ModelManager( - mock(CheckpointDao.class), - mock(Clock.class), - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - 
AnomalyDetectorSettings.TIME_DECAY, - AnomalyDetectorSettings.NUM_MIN_SAMPLES, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, - AnomalyDetectorSettings.MIN_PREVIEW_SIZE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - entityColdStarter, - mock(FeatureManager.class), - mock(MemoryTracker.class) - ); - } - - private void checkSemaphoreRelease() throws InterruptedException { - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); - assertTrue(released.get()); - } +public class EntityColdStarterTests extends AbstractModelPerfTest { // train using samples directly public void testTrainUsingSamples() throws InterruptedException { @@ -724,7 +526,18 @@ private void accuracyTemplate(int detectorIntervalMins) throws Exception { System.out.println("seed = " + seed); // create labelled data MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator - .getMultiDimData(dataSize + detector.getShingleSize() - 1, 50, 100, 5, seed, baseDimension, false, trainTestSplit, delta); + .getMultiDimData( + dataSize + detector.getShingleSize() - 1, + 50, + 100, + 5, + seed, + baseDimension, + false, + trainTestSplit, + delta, + false + ); long[] timestamps = dataWithKeys.timestampsMs; double[][] data = dataWithKeys.data; when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); @@ -827,20 +640,6 @@ public int compare(Entry p1, Entry p2) { assertTrue("recall is " + recall, recall >= 0.5); } - public int searchInsert(long[] timestamps, long target) { - int pivot, left = 0, right = timestamps.length - 1; - while (left <= right) { - pivot = left + (right - left) / 2; - if (timestamps[pivot] == target) - return pivot; - if (target < timestamps[pivot]) - right = pivot - 1; - else - left = pivot + 1; - } - return left; - } - public void testAccuracyTenMinuteInterval() throws Exception { accuracyTemplate(10); } diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java new file mode 100644 index 000000000..126e4215d --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java @@ -0,0 +1,334 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.temporal.ChronoUnit; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.lucene.util.TimeUnits; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.settings.AnomalyDetectorSettings; + +import test.org.opensearch.ad.util.LabelledAnomalyGenerator; +import test.org.opensearch.ad.util.MultiDimDataWithTime; + +import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; +import com.google.common.collect.ImmutableList; + +@TimeoutSuite(millis = 60 * TimeUnits.MINUTE) // RCF may be slow because the bounding box cache is disabled +public class HCADModelPerfTests extends AbstractModelPerfTest { + + /** + * A template to perform a precision/recall test by simulating HCAD logic with only one entity.
+ * + * @param detectorIntervalMins Detector interval + * @param precisionThreshold precision threshold + * @param recallThreshold recall threshold + * @param baseDimension the number of dimensions + * @param anomalyIndependent whether anomalies in each dimension are generated independently + * @throws Exception when failing to create the anomaly detector or the training data + */ + @SuppressWarnings("unchecked") + private void averageAccuracyTemplate( + int detectorIntervalMins, + float precisionThreshold, + float recallThreshold, + int baseDimension, + boolean anomalyIndependent + ) throws Exception { + int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; + int trainTestSplit = 300; + // detector interval + int interval = detectorIntervalMins; + int delta = 60000 * interval; + + int numberOfTrials = 10; + double prec = 0; + double recall = 0; + double totalPrec = 0; + double totalRecall = 0; + + // training data ranges from timestamps[0] ~ timestamps[trainTestSplit-1] + // set up detector + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(interval, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .setShingleSize(AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE) + .build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + for (int z = 1; z <= numberOfTrials; z++) { + long seed = z; + LOG.info("seed = " + seed); + + searchFeatureDao = mock(SearchFeatureDao.class); + + featureManager = new FeatureManager( + searchFeatureDao, + interpolator, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + ); + + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + seed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class) + ); + + // create labelled data + MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator
+ .getMultiDimData( + dataSize + detector.getShingleSize() - 1, + 50, + 100, + 5, + seed, + baseDimension, + false, + trainTestSplit, + delta, + anomalyIndependent + ); + + long[] timestamps = dataWithKeys.timestampsMs; + double[][] data = dataWithKeys.data; + when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(timestamps[0])); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + doAnswer(invocation -> { + List> ranges = invocation.getArgument(1); + List> coldStartSamples = new ArrayList<>(); + + Collections.sort(ranges, new Comparator>() { + @Override + public int compare(Entry p1, Entry p2) { + return Long.compare(p1.getKey(), p2.getKey()); + } + }); + for (int j = 0; j < ranges.size(); j++) { + Entry range = ranges.get(j); + Long start = range.getKey(); + int valueIndex = searchInsert(timestamps, start); + coldStartSamples.add(Optional.of(data[valueIndex])); + } + + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entity = Entity.createSingleAttributeEntity("field", entityName + z); + EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); + ModelState modelState = new ModelState<>( + model, + entity.getModelId(detectorId).get(), + detector.getDetectorId(), + ModelType.ENTITY.getName(), + clock, + priority + ); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + listener = ActionListener.wrap(() -> { + released.set(true); + inProgressLatch.countDown(); + }); + + entityColdStarter.trainModel(entity, detector.getDetectorId(), modelState, listener); + + checkSemaphoreRelease(); + assertTrue(model.getTrcf().isPresent()); + + int tp = 0; + int fp = 0; + int fn = 0; + long[] changeTimestamps = dataWithKeys.changeTimeStampsMs; + + for (int j = trainTestSplit; j < data.length; j++) { + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + if (result.getGrade() > 0) { + if (changeTimestamps[j] == 0) { + fp++; + } else { + tp++; + } + } else { + if (changeTimestamps[j] != 0) { + fn++; + } + // else ok + } + } + + if (tp + fp == 0) { + prec = 1; + } else { + prec = tp * 1.0 / (tp + fp); + } + + if (tp + fn == 0) { + recall = 1; + } else { + recall = tp * 1.0 / (tp + fn); + } + + totalPrec += prec; + totalRecall += recall; + modelState = null; + dataWithKeys = null; + reset(searchFeatureDao); + searchFeatureDao = null; + } + + double avgPrec = totalPrec / numberOfTrials; + double avgRecall = totalRecall / numberOfTrials; + LOG.info("{} features, Interval {}, Precision: {}, recall: {}", baseDimension, detectorIntervalMins, avgPrec, avgRecall); + assertTrue("average precision is " + avgPrec, avgPrec >= precisionThreshold); + assertTrue("average recall is " + avgRecall, avgRecall >= recallThreshold); + } + + /** + * Split average accuracy tests into two in case of time out per test. 
+ * @throws Exception when failing to perform tests + */ + public void testAverageAccuracyDependent() throws Exception { + LOG.info("Anomalies are injected dependently"); + + // 10 minute interval, 4 features + averageAccuracyTemplate(10, 0.4f, 0.3f, 4, false); + + // 10 minute interval, 2 features + averageAccuracyTemplate(10, 0.4f, 0.4f, 2, false); + + // 10 minute interval, 1 feature + averageAccuracyTemplate(10, 0.4f, 0.4f, 1, false); + + // 5 minute interval, 4 features + averageAccuracyTemplate(5, 0.4f, 0.3f, 4, false); + + // 5 minute interval, 2 features + averageAccuracyTemplate(5, 0.4f, 0.4f, 2, false); + + // 5 minute interval, 1 feature + averageAccuracyTemplate(5, 0.4f, 0.4f, 1, false); + } + + /** + * Split average accuracy tests into two in case of time out per test. + * @throws Exception when failing to perform tests + */ + public void testAverageAccuracyIndependent() throws Exception { + LOG.info("Anomalies are injected independently"); + + // 10 minute interval, 4 features + averageAccuracyTemplate(10, 0.3f, 0.1f, 4, true); + + // 10 minute interval, 2 features + averageAccuracyTemplate(10, 0.4f, 0.4f, 2, true); + + // 10 minute interval, 1 feature + averageAccuracyTemplate(10, 0.3f, 0.4f, 1, true); + + // 5 minute interval, 4 features + averageAccuracyTemplate(5, 0.2f, 0.1f, 4, true); + + // 5 minute interval, 2 features + averageAccuracyTemplate(5, 0.4f, 0.4f, 2, true); + + // 5 minute interval, 1 feature + averageAccuracyTemplate(5, 0.3f, 0.4f, 1, true); + } +} diff --git a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java index f2ef3cc2d..f77c135fb 100644 --- a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java +++ b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java @@ -29,6 +29,7 @@ public class LabelledAnomalyGenerator { * @param useSlope whether to use slope in cosine data * @param historicalData the number of historical points relative to now * @param delta point interval + * @param anomalyIndependent whether anomalies in each dimension are generated independently * @return the labelled data */ public static MultiDimDataWithTime getMultiDimData( @@ -40,7 +41,8 @@ public static MultiDimDataWithTime getMultiDimData( int baseDimension, boolean useSlope, int historicalData, - int delta + int delta, + boolean anomalyIndependent ) { double[][] data = new double[num][]; long[] timestamps = new long[num]; @@ -66,14 +68,34 @@ public static MultiDimDataWithTime getMultiDimData( startEpochMs += delta; data[i] = new double[baseDimension]; double[] newChange = new double[baseDimension]; - for (int j = 0; j < baseDimension; j++) { - data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble(); - if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) { - double factor = 5 * (1 + noiseprg.nextDouble()); - double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise; - data[i][j] += newChange[j] = change; - changedTimestamps[i] = timestamps[i]; - changes[i] = newChange; + // decide whether we should inject anomalies at this point + // If we decide this for each dimension separately, each dimension's anomalies + // are independent, which makes it harder for RCF to detect them. + // Deciding at the point level makes the anomalies across dimensions + // correlated.
+ if (anomalyIndependent) { + for (int j = 0; j < baseDimension; j++) { + data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble(); + if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) { + double factor = 5 * (1 + noiseprg.nextDouble()); + double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise; + data[i][j] += newChange[j] = change; + changedTimestamps[i] = timestamps[i]; + changes[i] = newChange; + } + } + } else { + boolean flag = (noiseprg.nextDouble() < 0.01); + for (int j = 0; j < baseDimension; j++) { + data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble(); + // add the < 0.3 condition so there is still some variance in whether all features receive the anomaly + if (flag && noiseprg.nextDouble() < 0.3) { + double factor = 5 * (1 + noiseprg.nextDouble()); + double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise; + data[i][j] += newChange[j] = change; + changedTimestamps[i] = timestamps[i]; + changes[i] = newChange; + } } } }
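
For illustration, here is a minimal, dependency-free Java sketch of the evaluation idea the benchmark above implements: generate a noisy cosine signal, inject rare level shifts as labelled anomalies (roughly the 0.003 probability described in the commit message), score every point, and compute precision/recall against the injected labels. The thresholding rule is only a stand-in for the RCF model exercised by HCADModelPerfTests, and the class name and constants are illustrative assumptions, not part of this PR.

import java.util.Random;

// Standalone sketch: noisy cosine data with injected level shifts, scored by a
// simple residual threshold (a stand-in for the RCF model used in the real tests).
public class SyntheticBenchmarkSketch {
    public static void main(String[] args) {
        Random rng = new Random(2051L);            // fixed seed keeps runs comparable across versions
        int num = 5000;                            // roughly the data set size described in the PR
        double amplitude = 100, noise = 5, period = 50;
        double[] data = new double[num];
        boolean[] labelled = new boolean[num];

        for (int i = 0; i < num; i++) {
            data[i] = amplitude * Math.cos(2 * Math.PI * i / period) + noise * rng.nextDouble();
            // inject a large level shift with probability 0.01 * 0.3 = 0.003
            if (rng.nextDouble() < 0.01 && rng.nextDouble() < 0.3) {
                double factor = 5 * (1 + rng.nextDouble());
                data[i] += rng.nextDouble() < 0.5 ? factor * noise : -factor * noise;
                labelled[i] = true;
            }
        }

        // stand-in detector: flag points whose residual from the clean cosine exceeds 4x the noise level
        int tp = 0, fp = 0, fn = 0;
        for (int i = 0; i < num; i++) {
            double residual = Math.abs(data[i] - amplitude * Math.cos(2 * Math.PI * i / period));
            boolean flagged = residual > 4 * noise;
            if (flagged && labelled[i]) {
                tp++;
            } else if (flagged) {
                fp++;
            } else if (labelled[i]) {
                fn++;
            }
        }

        // same precision/recall definitions the benchmark asserts on
        double precision = (tp + fp == 0) ? 1 : tp / (double) (tp + fp);
        double recall = (tp + fn == 0) ? 1 : tp / (double) (tp + fn);
        System.out.printf("precision=%.3f recall=%.3f%n", precision, recall);
    }
}

The real benchmark is run the same way the workflow step above does it, e.g. ./gradlew ':test' --tests "org.opensearch.ad.ml.HCADModelPerfTests" -Dhcad-benchmark=true -Dtest.logs=true (seed, locale, and timezone flags omitted here), since build.gradle excludes HCADModelPerfTests unless -Dhcad-benchmark is set.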