From 1e01657f9a81a06801375ef3d27bae4e45c088e2 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Tue, 19 Jul 2022 12:31:16 -0700 Subject: [PATCH] AD model performance benchmark This PR adds an AD model performance benchmark so that we can compare model performance across versions. For the single stream detector, we refactored tests in DetectionResultEvalutationIT and moved it to SingleStreamModelPerfIT. For the HCAD detector, we randomly generated synthetic data with known anomalies inserted throughout the signal. In particular, these are one/two/four dimensional data where each dimension is a noisy cosine wave. Anomalies are inserted into one dimension with 0.003 probability. Anomalies across each dimension can be independent or dependent. We have approximately 5000 observations per data set. The data set is generated using the same random seed so the result is comparable across versions. We also backported #600 so that we can capture the performance data in CI output. We also fixed #712 by revising the client setup code. Testing done: * added unit tests to run the benchmark. Signed-off-by: Kaituo Li --- .github/workflows/benchmark.yml | 33 ++ .../workflows/test_build_multi_platform.yml | 2 +- build.gradle | 20 +- .../org/opensearch/ad/ODFERestTestCase.java | 23 +- .../ad/e2e/AbstractSyntheticDataTest.java | 242 +++++++++++ .../ad/e2e/DetectionResultEvalutationIT.java | 394 +----------------- .../ad/e2e/SingleStreamModelPerfIT.java | 230 ++++++++++ .../ad/ml/AbstractCosineDataTest.java | 256 ++++++++++++ .../ad/ml/EntityColdStarterTests.java | 225 +--------- .../opensearch/ad/ml/HCADModelPerfTests.java | 342 +++++++++++++++ .../ad/rest/AnomalyDetectorRestApiIT.java | 3 +- .../ad/util/LabelledAnomalyGenerator.java | 40 +- 12 files changed, 1180 insertions(+), 630 deletions(-) create mode 100644 .github/workflows/benchmark.yml create mode 100644 src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java create mode 100644 src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java create mode 100644 src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java create mode 100644 src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 000000000..f01b60c9a --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,33 @@ +name: Build and Test Anomaly detection +on: + push: + branches: + - "*" + pull_request: + branches: + - "*" + +jobs: + Build-ad: + strategy: + matrix: + java: [17] + fail-fast: false + + name: Build and Test Anomaly detection Plugin + runs-on: ubuntu-latest + + steps: + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.java }} + + # anomaly-detection + - name: Checkout AD + uses: actions/checkout@v2 + + - name: Build and Run Tests + run: | + ./gradlew ':test' --tests "org.opensearch.ad.ml.HCADModelPerfTests" -Dtests.seed=2AEBDBBAE75AC5E0 -Dtests.security.manager=false -Dtests.locale=es-CU -Dtests.timezone=Chile/EasterIsland -Dtest.logs=true -Dmodel-benchmark=true + ./gradlew integTest --tests "org.opensearch.ad.e2e.SingleStreamModelPerfIT" -Dtests.seed=60CDDB34427ACD0C -Dtests.security.manager=false -Dtests.locale=kab-DZ -Dtests.timezone=Asia/Hebron -Dtest.logs=true -Dmodel-benchmark=true \ No newline at end of file diff --git a/.github/workflows/test_build_multi_platform.yml b/.github/workflows/test_build_multi_platform.yml index 50e86f785..ffc2aa8a3 100644 --- a/.github/workflows/test_build_multi_platform.yml +++ b/.github/workflows/test_build_multi_platform.yml @@ -60,7 +60,7 @@ jobs: ./gradlew assemble - name: Build and Run Tests run: | - ./gradlew build -Dtest.logs=true + ./gradlew build - name: Publish to Maven Local run: | ./gradlew publishToMavenLocal diff --git a/build.gradle b/build.gradle index cecc314ff..cf08bf7d5 100644 --- a/build.gradle +++ b/build.gradle @@ -139,7 +139,7 @@ configurations.all { if (it.state != Configuration.State.UNRESOLVED) return resolutionStrategy { force "joda-time:joda-time:${versions.joda}" - force "com.fasterxml.jackson.core:jackson-core:2.13.4" + force "com.fasterxml.jackson.core:jackson-core:2.14.0" force "commons-logging:commons-logging:${versions.commonslogging}" force "org.apache.httpcomponents:httpcore5:${versions.httpcore5}" force "commons-codec:commons-codec:${versions.commonscodec}" @@ -219,6 +219,12 @@ test { } include '**/*Tests.class' systemProperty 'tests.security.manager', 'false' + + if (System.getProperty("model-benchmark") == null || System.getProperty("model-benchmark") == "false") { + filter { + excludeTestsMatching "org.opensearch.ad.ml.HCADModelPerfTests" + } + } } task integTest(type: RestIntegTestTask) { @@ -264,6 +270,12 @@ integTest { } } + if (System.getProperty("model-benchmark") == null || System.getProperty("model-benchmark") == "false") { + filter { + excludeTestsMatching "org.opensearch.ad.e2e.SingleStreamModelPerfIT" + } + } + // The 'doFirst' delays till execution time. doFirst { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can @@ -664,9 +676,9 @@ dependencies { implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc3' // force Jackson version to avoid version conflict issue - implementation "com.fasterxml.jackson.core:jackson-core:2.13.4" - implementation "com.fasterxml.jackson.core:jackson-databind:2.13.4.2" - implementation "com.fasterxml.jackson.core:jackson-annotations:2.13.4" + implementation "com.fasterxml.jackson.core:jackson-core:2.14.0" + implementation "com.fasterxml.jackson.core:jackson-databind:2.14.0" + implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.0" // used for serializing/deserializing rcf models. implementation group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0' diff --git a/src/test/java/org/opensearch/ad/ODFERestTestCase.java b/src/test/java/org/opensearch/ad/ODFERestTestCase.java index cf89f5c85..44b3e1d2d 100644 --- a/src/test/java/org/opensearch/ad/ODFERestTestCase.java +++ b/src/test/java/org/opensearch/ad/ODFERestTestCase.java @@ -11,6 +11,8 @@ package org.opensearch.ad; +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; @@ -186,21 +188,18 @@ protected static void configureHttpsClient(RestClientBuilder builder, Settings s .ofNullable(System.getProperty("password")) .orElseThrow(() -> new RuntimeException("password is missing")); BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider - .setCredentials( - new AuthScope(new HttpHost("localhost", 9200)), - new UsernamePasswordCredentials(userName, password.toCharArray()) - ); + final AuthScope anyScope = new AuthScope(null, -1); + credentialsProvider.setCredentials(anyScope, new UsernamePasswordCredentials(userName, password.toCharArray())); try { final TlsStrategy tlsStrategy = ClientTlsStrategyBuilder .create() - .setSslContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()) - // disable the certificate since our testing cluster just uses the default security configuration .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) + .setSslContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()) .build(); - final PoolingAsyncClientConnectionManager connectionManager = PoolingAsyncClientConnectionManagerBuilder .create() + .setMaxConnPerRoute(DEFAULT_MAX_CONN_PER_ROUTE) + .setMaxConnTotal(DEFAULT_MAX_CONN_TOTAL) .setTlsStrategy(tlsStrategy) .build(); return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(connectionManager); @@ -212,8 +211,12 @@ protected static void configureHttpsClient(RestClientBuilder builder, Settings s final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT); final TimeValue socketTimeout = TimeValue .parseTimeValue(socketTimeoutString == null ? "60s" : socketTimeoutString, CLIENT_SOCKET_TIMEOUT); - builder - .setRequestConfigCallback(conf -> conf.setResponseTimeout(Timeout.ofMilliseconds(Math.toIntExact(socketTimeout.getMillis())))); + builder.setRequestConfigCallback(conf -> { + Timeout timeout = Timeout.ofMilliseconds(Math.toIntExact(socketTimeout.getMillis())); + conf.setConnectTimeout(timeout); + conf.setResponseTimeout(timeout); + return conf; + }); if (settings.hasValue(CLIENT_PATH_PREFIX)) { builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX)); } diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java new file mode 100644 index 000000000..9458bc740 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java @@ -0,0 +1,242 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.ad.TestHelpers.toHttpEntity; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.Charset; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.ad.ODFERestTestCase; +import org.opensearch.ad.TestHelpers; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.json.JsonXContent; + +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class AbstractSyntheticDataTest extends ODFERestTestCase { + /** + * In real time AD, we mute a node for a detector if that node keeps returning + * ResourceNotFoundException (5 times in a row). This is a problem for batch mode + * testing as we issue a large amount of requests quickly. Due to the speed, we + * won't be able to finish cold start before the ResourceNotFoundException mutes + * a node. Since our test case has only one node, there is no other nodes to fall + * back on. Here we disable such fault tolerance by setting max retries before + * muting to a large number and the actual wait time during muting to 0. + * + * @throws IOException when failing to create http request body + */ + protected void disableResourceNotFoundFaultTolerence() throws IOException { + XContentBuilder settingCommand = JsonXContent.contentBuilder(); + + settingCommand.startObject(); + settingCommand.startObject("persistent"); + settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); + settingCommand.field(BACKOFF_MINUTES.getKey(), 0); + settingCommand.endObject(); + settingCommand.endObject(); + Request request = new Request("PUT", "/_cluster/settings"); + request.setJsonEntity(Strings.toString(settingCommand)); + + adminClient().performRequest(request); + } + + protected List getData(String datasetFileName) throws Exception { + JsonArray jsonArray = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(datasetFileName).toURI()), Charset.defaultCharset())) + .getAsJsonArray(); + List list = new ArrayList<>(jsonArray.size()); + jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); + return list; + } + + protected Map getDetectionResult(String detectorId, Instant begin, Instant end, RestClient client) { + try { + Request request = new Request( + "POST", + String.format(Locale.ROOT, "/_opendistro/_anomaly_detection/detectors/%s/_run", detectorId) + ); + request + .setJsonEntity( + String.format(Locale.ROOT, "{ \"period_start\": %d, \"period_end\": %d }", begin.toEpochMilli(), end.toEpochMilli()) + ); + return entityAsMap(client.performRequest(request)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected void bulkIndexTrainData( + String datasetName, + List data, + int trainTestSplit, + RestClient client, + String categoryField + ) throws Exception { + Request request = new Request("PUT", datasetName); + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" }," + + "\"%s\": { \"type\": \"keyword\"} } } }", + categoryField + ); + } + + request.setJsonEntity(requestBody); + setWarningHandler(request, false); + client.performRequest(request); + Thread.sleep(1_000); + + StringBuilder bulkRequestBuilder = new StringBuilder(); + for (int i = 0; i < trainTestSplit; i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + } + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + waitAllSyncheticDataIngested(trainTestSplit, datasetName, client); + } + + protected String createDetector( + String datasetName, + int intervalMinutes, + RestClient client, + String categoryField, + long windowDelayInMins + ) throws Exception { + Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/"); + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + windowDelayInMins + ); + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"category_field\": [\"%s\"], " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + categoryField, + windowDelayInMins + ); + } + + request.setJsonEntity(requestBody); + Map response = entityAsMap(client.performRequest(request)); + String detectorId = (String) response.get("_id"); + Thread.sleep(1_000); + return detectorId; + } + + protected void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { + int maxWaitCycles = 3; + do { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{\"query\": {" + + " \"match_all\": {}" + + " }," + + " \"size\": 1," + + " \"sort\": [" + + " {" + + " \"timestamp\": {" + + " \"order\": \"desc\"" + + " }" + + " }" + + " ]}" + ) + ); + // Make sure all of the test data has been ingested + // Expected response: + // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} + Response response = client.performRequest(request); + JsonObject json = JsonParser + .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) + .getAsJsonObject(); + JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); + if (hits != null + && hits.size() == 1 + && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { + break; + } else { + request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); + client.performRequest(request); + } + Thread.sleep(1_000); + } while (maxWaitCycles-- >= 0); + } + + protected void setWarningHandler(Request request, boolean strictDeprecationMode) { + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); + request.setOptions(options.build()); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java index 145cf664b..0c2cce832 100644 --- a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java +++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java @@ -12,212 +12,38 @@ package org.opensearch.ad.e2e; import static org.opensearch.ad.TestHelpers.toHttpEntity; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; -import java.io.File; -import java.io.FileReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.charset.Charset; import java.text.SimpleDateFormat; import java.time.Clock; import java.time.Instant; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit; -import java.util.AbstractMap.SimpleEntry; import java.util.ArrayList; import java.util.Calendar; import java.util.Date; -import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; import java.util.TimeZone; import java.util.concurrent.TimeUnit; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.message.BasicHeader; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.Logger; -import org.opensearch.ad.ODFERestTestCase; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.client.Request; -import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; import org.opensearch.client.RestClient; -import org.opensearch.client.WarningsHandler; -import org.opensearch.common.Strings; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.common.xcontent.support.XContentMapValues; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonObject; -import com.google.gson.JsonParser; import com.google.gson.JsonPrimitive; -public class DetectionResultEvalutationIT extends ODFERestTestCase { +public class DetectionResultEvalutationIT extends AbstractSyntheticDataTest { protected static final Logger LOG = (Logger) LogManager.getLogger(DetectionResultEvalutationIT.class); - public void testDataset() throws Exception { - // TODO: this test case will run for a much longer time and timeout with security enabled - if (!isHttps()) { - disableResourceNotFoundFaultTolerence(); - verifyAnomaly("synthetic", 1, 1500, 8, .4, .9, 10); - } - } - - private void verifyAnomaly( - String datasetName, - int intervalMinutes, - int trainTestSplit, - int shingleSize, - double minPrecision, - double minRecall, - double maxError - ) throws Exception { - RestClient client = client(); - - String dataFileName = String.format(Locale.ROOT, "data/%s.data", datasetName); - String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); - - List data = getData(dataFileName); - List> anomalies = getAnomalyWindows(labelFileName); - - bulkIndexTrainData(datasetName, data, trainTestSplit, client, null); - // single-stream detector can use window delay 0 here because we give the run api the actual data time - String detectorId = createDetector(datasetName, intervalMinutes, client, null, 0); - simulateSingleStreamStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); - bulkIndexTestData(data, datasetName, trainTestSplit, client); - double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client); - verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError); - } - - private void verifyTestResults( - double[] testResults, - List> anomalies, - double minPrecision, - double minRecall, - double maxError - ) { - - double positives = testResults[0]; - double truePositives = testResults[1]; - double positiveAnomalies = testResults[2]; - double errors = testResults[3]; - - // precision = predicted anomaly points that are true / predicted anomaly points - double precision = positives > 0 ? truePositives / positives : 1; - assertTrue(precision >= minPrecision); - - // recall = windows containing predicted anomaly points / total anomaly windows - double recall = anomalies.size() > 0 ? positiveAnomalies / anomalies.size() : 1; - assertTrue(recall >= minRecall); - - assertTrue(errors <= maxError); - LOG.info("Precision: {}, Window recall: {}", precision, recall); - } - - private int isAnomaly(Instant time, List> labels) { - for (int i = 0; i < labels.size(); i++) { - Entry window = labels.get(i); - if (time.compareTo(window.getKey()) >= 0 && time.compareTo(window.getValue()) <= 0) { - return i; - } - } - return -1; - } - - private double[] getTestResults( - String detectorId, - List data, - int trainTestSplit, - int intervalMinutes, - List> anomalies, - RestClient client - ) throws Exception { - - double positives = 0; - double truePositives = 0; - Set positiveAnomalies = new HashSet<>(); - double errors = 0; - for (int i = trainTestSplit; i < data.size(); i++) { - Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); - Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - Map response = getDetectionResult(detectorId, begin, end, client); - double anomalyGrade = (double) response.get("anomalyGrade"); - if (anomalyGrade > 0) { - positives++; - int result = isAnomaly(begin, anomalies); - if (result != -1) { - truePositives++; - positiveAnomalies.add(result); - } - } - } catch (Exception e) { - errors++; - logger.error("failed to get detection results", e); - } - } - return new double[] { positives, truePositives, positiveAnomalies.size(), errors }; - } - - /** - * Simulate starting detector without waiting for job scheduler to run. Our build process is already very slow (takes 10 mins+) - * to finish integration tests. This method triggers run API to simulate job scheduler execution in a fast-paced way. - * @param detectorId Detector Id - * @param data Data in Json format - * @param trainTestSplit Training data size - * @param shingleSize Shingle size - * @param intervalMinutes Detector Interval - * @param client OpenSearch Client - * @throws Exception when failing to query/indexing from/to OpenSearch - */ - private void simulateSingleStreamStartDetector( - String detectorId, - List data, - int trainTestSplit, - int shingleSize, - int intervalMinutes, - RestClient client - ) throws Exception { - - Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString())); - - Instant begin = null; - Instant end = null; - for (int i = 0; i < shingleSize; i++) { - begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); - end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - getDetectionResult(detectorId, begin, end, client); - } catch (Exception e) {} - } - // It takes time to wait for model initialization - long startTime = System.currentTimeMillis(); - do { - try { - Thread.sleep(5_000); - getDetectionResult(detectorId, begin, end, client); - break; - } catch (Exception e) { - long duration = System.currentTimeMillis() - startTime; - // we wait at most 60 secs - if (duration > 60_000) { - throw new RuntimeException(e); - } - } - } while (true); - } - /** * Simulate starting the given HCAD detector. * @param detectorId Detector Id @@ -274,224 +100,6 @@ private void simulateHCADStartDetector( } while (duration <= 60_000); } - private String createDetector(String datasetName, int intervalMinutes, RestClient client, String categoryField, long windowDelayInMins) - throws Exception { - Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/"); - String requestBody = null; - if (Strings.isEmpty(categoryField)) { - requestBody = String - .format( - Locale.ROOT, - "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" - + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " - + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" - + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " - + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " - + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," - + "\"schema_version\": 0 }", - datasetName, - intervalMinutes, - windowDelayInMins - ); - } else { - requestBody = String - .format( - Locale.ROOT, - "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" - + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " - + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" - + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " - + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " - + "\"category_field\": [\"%s\"], " - + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," - + "\"schema_version\": 0 }", - datasetName, - intervalMinutes, - categoryField, - windowDelayInMins - ); - } - - request.setJsonEntity(requestBody); - Map response = entityAsMap(client.performRequest(request)); - String detectorId = (String) response.get("_id"); - Thread.sleep(1_000); - return detectorId; - } - - private List> getAnomalyWindows(String labalFileName) throws Exception { - JsonArray windows = JsonParser - .parseReader(new FileReader(new File(getClass().getResource(labalFileName).toURI()), Charset.defaultCharset())) - .getAsJsonArray(); - List> anomalies = new ArrayList<>(windows.size()); - for (int i = 0; i < windows.size(); i++) { - JsonArray window = windows.get(i).getAsJsonArray(); - Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(0).getAsString())); - Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(1).getAsString())); - anomalies.add(new SimpleEntry<>(begin, end)); - } - return anomalies; - } - - private void bulkIndexTrainData(String datasetName, List data, int trainTestSplit, RestClient client, String categoryField) - throws Exception { - Request request = new Request("PUT", datasetName); - String requestBody = null; - if (Strings.isEmpty(categoryField)) { - requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," - + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; - } else { - requestBody = String - .format( - Locale.ROOT, - "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," - + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" }," - + "\"%s\": { \"type\": \"keyword\"} } } }", - categoryField - ); - } - - request.setJsonEntity(requestBody); - setWarningHandler(request, false); - client.performRequest(request); - Thread.sleep(1_000); - - StringBuilder bulkRequestBuilder = new StringBuilder(); - for (int i = 0; i < trainTestSplit; i++) { - bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); - bulkRequestBuilder.append(data.get(i).toString()).append("\n"); - } - TestHelpers - .makeRequest( - client, - "POST", - "_bulk?refresh=true", - null, - toHttpEntity(bulkRequestBuilder.toString()), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) - ); - Thread.sleep(1_000); - waitAllSyncheticDataIngested(trainTestSplit, datasetName, client); - } - - private void bulkIndexTestData(List data, String datasetName, int trainTestSplit, RestClient client) throws Exception { - StringBuilder bulkRequestBuilder = new StringBuilder(); - for (int i = trainTestSplit; i < data.size(); i++) { - bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); - bulkRequestBuilder.append(data.get(i).toString()).append("\n"); - } - TestHelpers - .makeRequest( - client, - "POST", - "_bulk?refresh=true", - null, - toHttpEntity(bulkRequestBuilder.toString()), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) - ); - Thread.sleep(1_000); - waitAllSyncheticDataIngested(data.size(), datasetName, client); - } - - private void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { - int maxWaitCycles = 3; - do { - Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); - request - .setJsonEntity( - String - .format( - Locale.ROOT, - "{\"query\": {" - + " \"match_all\": {}" - + " }," - + " \"size\": 1," - + " \"sort\": [" - + " {" - + " \"timestamp\": {" - + " \"order\": \"desc\"" - + " }" - + " }" - + " ]}" - ) - ); - // Make sure all of the test data has been ingested - // Expected response: - // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} - Response response = client.performRequest(request); - JsonObject json = JsonParser - .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) - .getAsJsonObject(); - JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); - if (hits != null - && hits.size() == 1 - && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { - break; - } else { - request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); - client.performRequest(request); - } - Thread.sleep(1_000); - } while (maxWaitCycles-- >= 0); - } - - private void setWarningHandler(Request request, boolean strictDeprecationMode) { - RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); - options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); - request.setOptions(options.build()); - } - - private List getData(String datasetFileName) throws Exception { - JsonArray jsonArray = JsonParser - .parseReader(new FileReader(new File(getClass().getResource(datasetFileName).toURI()), Charset.defaultCharset())) - .getAsJsonArray(); - List list = new ArrayList<>(jsonArray.size()); - jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); - return list; - } - - private Map getDetectionResult(String detectorId, Instant begin, Instant end, RestClient client) { - try { - Request request = new Request( - "POST", - String.format(Locale.ROOT, "/_opendistro/_anomaly_detection/detectors/%s/_run", detectorId) - ); - request - .setJsonEntity( - String.format(Locale.ROOT, "{ \"period_start\": %d, \"period_end\": %d }", begin.toEpochMilli(), end.toEpochMilli()) - ); - return entityAsMap(client.performRequest(request)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - /** - * In real time AD, we mute a node for a detector if that node keeps returning - * ResourceNotFoundException (5 times in a row). This is a problem for batch mode - * testing as we issue a large amount of requests quickly. Due to the speed, we - * won't be able to finish cold start before the ResourceNotFoundException mutes - * a node. Since our test case has only one node, there is no other nodes to fall - * back on. Here we disable such fault tolerance by setting max retries before - * muting to a large number and the actual wait time during muting to 0. - * - * @throws IOException when failing to create http request body - */ - private void disableResourceNotFoundFaultTolerence() throws IOException { - XContentBuilder settingCommand = JsonXContent.contentBuilder(); - - settingCommand.startObject(); - settingCommand.startObject("persistent"); - settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); - settingCommand.field(BACKOFF_MINUTES.getKey(), 0); - settingCommand.endObject(); - settingCommand.endObject(); - Request request = new Request("PUT", "/_cluster/settings"); - request.setJsonEntity(Strings.toString(settingCommand)); - - adminClient().performRequest(request); - } - public void testValidationIntervalRecommendation() throws Exception { RestClient client = client(); long recDetectorIntervalMillis = 180000; diff --git a/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java new file mode 100644 index 000000000..710c8b6ec --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java @@ -0,0 +1,230 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.ad.TestHelpers.toHttpEntity; + +import java.io.File; +import java.io.FileReader; +import java.nio.charset.Charset; +import java.time.Instant; +import java.time.format.DateTimeFormatter; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.ad.TestHelpers; +import org.opensearch.client.RestClient; + +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class SingleStreamModelPerfIT extends AbstractSyntheticDataTest { + protected static final Logger LOG = (Logger) LogManager.getLogger(SingleStreamModelPerfIT.class); + + public void testDataset() throws Exception { + // TODO: this test case will run for a much longer time and timeout with security enabled + if (!isHttps()) { + disableResourceNotFoundFaultTolerence(); + verifyAnomaly("synthetic", 1, 1500, 8, .4, .9, 10); + } + } + + private void verifyAnomaly( + String datasetName, + int intervalMinutes, + int trainTestSplit, + int shingleSize, + double minPrecision, + double minRecall, + double maxError + ) throws Exception { + RestClient client = client(); + + String dataFileName = String.format(Locale.ROOT, "data/%s.data", datasetName); + String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); + + List data = getData(dataFileName); + List> anomalies = getAnomalyWindows(labelFileName); + + bulkIndexTrainData(datasetName, data, trainTestSplit, client, null); + // single-stream detector can use window delay 0 here because we give the run api the actual data time + String detectorId = createDetector(datasetName, intervalMinutes, client, null, 0); + simulateSingleStreamStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); + bulkIndexTestData(data, datasetName, trainTestSplit, client); + double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client); + verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError); + } + + private void verifyTestResults( + double[] testResults, + List> anomalies, + double minPrecision, + double minRecall, + double maxError + ) { + + double positives = testResults[0]; + double truePositives = testResults[1]; + double positiveAnomalies = testResults[2]; + double errors = testResults[3]; + + // precision = predicted anomaly points that are true / predicted anomaly points + double precision = positives > 0 ? truePositives / positives : 1; + assertTrue(precision >= minPrecision); + + // recall = windows containing predicted anomaly points / total anomaly windows + double recall = anomalies.size() > 0 ? positiveAnomalies / anomalies.size() : 1; + assertTrue(recall >= minRecall); + + assertTrue(errors <= maxError); + LOG.info("Precision: {}, Window recall: {}", precision, recall); + } + + private double[] getTestResults( + String detectorId, + List data, + int trainTestSplit, + int intervalMinutes, + List> anomalies, + RestClient client + ) throws Exception { + + double positives = 0; + double truePositives = 0; + Set positiveAnomalies = new HashSet<>(); + double errors = 0; + for (int i = trainTestSplit; i < data.size(); i++) { + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + Map response = getDetectionResult(detectorId, begin, end, client); + double anomalyGrade = (double) response.get("anomalyGrade"); + if (anomalyGrade > 0) { + positives++; + int result = isAnomaly(begin, anomalies); + if (result != -1) { + truePositives++; + positiveAnomalies.add(result); + } + } + } catch (Exception e) { + errors++; + logger.error("failed to get detection results", e); + } + } + return new double[] { positives, truePositives, positiveAnomalies.size(), errors }; + } + + private List> getAnomalyWindows(String labalFileName) throws Exception { + JsonArray windows = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(labalFileName).toURI()), Charset.defaultCharset())) + .getAsJsonArray(); + List> anomalies = new ArrayList<>(windows.size()); + for (int i = 0; i < windows.size(); i++) { + JsonArray window = windows.get(i).getAsJsonArray(); + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(0).getAsString())); + Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(1).getAsString())); + anomalies.add(new SimpleEntry<>(begin, end)); + } + return anomalies; + } + + /** + * Simulate starting detector without waiting for job scheduler to run. Our build process is already very slow (takes 10 mins+) + * to finish integration tests. This method triggers run API to simulate job scheduler execution in a fast-paced way. + * @param detectorId Detector Id + * @param data Data in Json format + * @param trainTestSplit Training data size + * @param shingleSize Shingle size + * @param intervalMinutes Detector Interval + * @param client OpenSearch Client + * @throws Exception when failing to query/indexing from/to OpenSearch + */ + private void simulateSingleStreamStartDetector( + String detectorId, + List data, + int trainTestSplit, + int shingleSize, + int intervalMinutes, + RestClient client + ) throws Exception { + + Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString())); + + Instant begin = null; + Instant end = null; + for (int i = 0; i < shingleSize; i++) { + begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); + end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + getDetectionResult(detectorId, begin, end, client); + } catch (Exception e) {} + } + // It takes time to wait for model initialization + long startTime = System.currentTimeMillis(); + do { + try { + Thread.sleep(5_000); + getDetectionResult(detectorId, begin, end, client); + break; + } catch (Exception e) { + long duration = System.currentTimeMillis() - startTime; + // we wait at most 60 secs + if (duration > 60_000) { + throw new RuntimeException(e); + } + } + } while (true); + } + + private void bulkIndexTestData(List data, String datasetName, int trainTestSplit, RestClient client) throws Exception { + StringBuilder bulkRequestBuilder = new StringBuilder(); + for (int i = trainTestSplit; i < data.size(); i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + } + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + waitAllSyncheticDataIngested(data.size(), datasetName, client); + } + + private int isAnomaly(Instant time, List> labels) { + for (int i = 0; i < labels.size(); i++) { + Entry window = labels.get(i); + if (time.compareTo(window.getKey()) >= 0 && time.compareTo(window.getValue()) <= 0) { + return i; + } + } + return -1; + } +} diff --git a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java new file mode 100644 index 000000000..5159fe8ba --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java @@ -0,0 +1,256 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; + +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.dataprocessor.IntegerSensitiveSingleFeatureLinearUniformInterpolator; +import org.opensearch.ad.dataprocessor.Interpolator; +import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; +import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.settings.EnabledSetting; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import com.google.common.collect.ImmutableList; + +public class AbstractCosineDataTest extends AbstractADTest { + int numMinSamples; + String modelId; + String entityName; + String detectorId; + ModelState modelState; + Clock clock; + float priority; + EntityColdStarter entityColdStarter; + NodeStateManager stateManager; + SearchFeatureDao searchFeatureDao; + Interpolator interpolator; + CheckpointDao checkpoint; + FeatureManager featureManager; + Settings settings; + ThreadPool threadPool; + AtomicBoolean released; + Runnable releaseSemaphore; + ActionListener listener; + CountDownLatch inProgressLatch; + CheckpointWriteWorker checkpointWriteQueue; + Entity entity; + AnomalyDetector detector; + long rcfSeed; + ModelManager modelManager; + ClientUtil clientUtil; + ClusterService clusterService; + ClusterSettings clusterSettings; + DiscoveryNode discoveryNode; + Set> nodestateSetting; + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; + + clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + settings = Settings.EMPTY; + + Client client = mock(Client.class); + clientUtil = mock(ClientUtil.class); + + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .build(); + when(clock.millis()).thenReturn(1602401500000L); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); + nodestateSetting.add(BACKOFF_MINUTES); + nodestateSetting.add(CHECKPOINT_SAVING_FREQ); + clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + + discoveryNode = new DiscoveryNode( + "node1", + OpenSearchTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.BUILT_IN_ROLES, + Version.CURRENT + ); + + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + + SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = + new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); + interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); + + searchFeatureDao = mock(SearchFeatureDao.class); + checkpoint = mock(CheckpointDao.class); + + featureManager = new FeatureManager( + searchFeatureDao, + interpolator, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + ); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + rcfSeed = 2051L; + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.TRUE); + + detectorId = "123"; + modelId = "123_entity_abc"; + entityName = "abc"; + priority = 0.3f; + entity = Entity.createSingleAttributeEntity("field", entityName); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + releaseSemaphore = () -> { + released.set(true); + inProgressLatch.countDown(); + }; + listener = ActionListener.wrap(releaseSemaphore); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); + } + + protected void checkSemaphoreRelease() throws InterruptedException { + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + assertTrue(released.get()); + } + + public int searchInsert(long[] timestamps, long target) { + int pivot, left = 0, right = timestamps.length - 1; + while (left <= right) { + pivot = left + (right - left) / 2; + if (timestamps[pivot] == target) + return pivot; + if (target < timestamps[pivot]) + right = pivot - 1; + else + left = pivot + 1; + } + return left; + } +} diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index ddea2510b..ebf0d2d8b 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -19,9 +19,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.io.IOException; import java.time.Clock; @@ -31,7 +28,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; -import java.util.HashSet; import java.util.List; import java.util.Map.Entry; import java.util.Optional; @@ -39,48 +35,29 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.junit.AfterClass; import org.junit.BeforeClass; -import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.AbstractADTest; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.AnomalyDetectionException; -import org.opensearch.ad.dataprocessor.IntegerSensitiveSingleFeatureLinearUniformInterpolator; -import org.opensearch.ad.dataprocessor.Interpolator; -import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; -import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.settings.EnabledSetting; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; -import org.opensearch.test.ClusterServiceUtils; -import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ThreadPool; import test.org.opensearch.ad.util.LabelledAnomalyGenerator; import test.org.opensearch.ad.util.MLUtil; @@ -91,33 +68,7 @@ import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; -public class EntityColdStarterTests extends AbstractADTest { - int numMinSamples; - String modelId; - String entityName; - String detectorId; - ModelState modelState; - Clock clock; - float priority; - EntityColdStarter entityColdStarter; - NodeStateManager stateManager; - SearchFeatureDao searchFeatureDao; - Interpolator interpolator; - CheckpointDao checkpoint; - FeatureManager featureManager; - Settings settings; - ThreadPool threadPool; - AtomicBoolean released; - Runnable releaseSemaphore; - ActionListener listener; - CountDownLatch inProgressLatch; - CheckpointWriteWorker checkpointWriteQueue; - Entity entity; - AnomalyDetector detector; - long rcfSeed; - ModelManager modelManager; - ClientUtil clientUtil; - ClusterService clusterService; +public class EntityColdStarterTests extends AbstractCosineDataTest { @BeforeClass public static void initOnce() { @@ -136,158 +87,12 @@ public static void clearOnce() { EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); } - @SuppressWarnings("unchecked") - @Override - public void setUp() throws Exception { - super.setUp(); - numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; - - clock = mock(Clock.class); - when(clock.instant()).thenReturn(Instant.now()); - - threadPool = mock(ThreadPool.class); - setUpADThreadPool(threadPool); - - settings = Settings.EMPTY; - - Client client = mock(Client.class); - clientUtil = mock(ClientUtil.class); - - detector = TestHelpers.AnomalyDetectorBuilder - .newInstance() - .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) - .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) - .build(); - when(clock.millis()).thenReturn(1602401500000L); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - - listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - - return null; - }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); - - Set> nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); - nodestateSetting.add(BACKOFF_MINUTES); - nodestateSetting.add(CHECKPOINT_SAVING_FREQ); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); - - DiscoveryNode discoveryNode = new DiscoveryNode( - "node1", - OpenSearchTestCase.buildNewFakeTransportAddress(), - Collections.emptyMap(), - DiscoveryNodeRole.BUILT_IN_ROLES, - Version.CURRENT - ); - - clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); - - stateManager = new NodeStateManager( - client, - xContentRegistry(), - settings, - clientUtil, - clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - clusterService - ); - - SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = - new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); - interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); - - searchFeatureDao = mock(SearchFeatureDao.class); - checkpoint = mock(CheckpointDao.class); - - featureManager = new FeatureManager( - searchFeatureDao, - interpolator, - clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, - AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, - AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, - AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, - AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME - ); - - checkpointWriteQueue = mock(CheckpointWriteWorker.class); - - rcfSeed = 2051L; - entityColdStarter = new EntityColdStarter( - clock, - threadPool, - stateManager, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.TIME_DECAY, - numMinSamples, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - interpolator, - searchFeatureDao, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, - featureManager, - settings, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - rcfSeed, - AnomalyDetectorSettings.MAX_COLD_START_ROUNDS - ); - EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.TRUE); - - detectorId = "123"; - modelId = "123_entity_abc"; - entityName = "abc"; - priority = 0.3f; - entity = Entity.createSingleAttributeEntity("field", entityName); - - released = new AtomicBoolean(); - - inProgressLatch = new CountDownLatch(1); - releaseSemaphore = () -> { - released.set(true); - inProgressLatch.countDown(); - }; - listener = ActionListener.wrap(releaseSemaphore); - - modelManager = new ModelManager( - mock(CheckpointDao.class), - mock(Clock.class), - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.TIME_DECAY, - AnomalyDetectorSettings.NUM_MIN_SAMPLES, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, - AnomalyDetectorSettings.MIN_PREVIEW_SIZE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, - entityColdStarter, - mock(FeatureManager.class), - mock(MemoryTracker.class), - settings, - clusterService - - ); - } - @Override public void tearDown() throws Exception { EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.FALSE); super.tearDown(); } - private void checkSemaphoreRelease() throws InterruptedException { - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); - assertTrue(released.get()); - } - // train using samples directly public void testTrainUsingSamples() throws InterruptedException { Queue samples = MLUtil.createQueueSamples(numMinSamples); @@ -757,7 +562,18 @@ private void accuracyTemplate(int detectorIntervalMins, float precisionThreshold LOG.info("seed = " + seed); // create labelled data MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator - .getMultiDimData(dataSize + detector.getShingleSize() - 1, 50, 100, 5, seed, baseDimension, false, trainTestSplit, delta); + .getMultiDimData( + dataSize + detector.getShingleSize() - 1, + 50, + 100, + 5, + seed, + baseDimension, + false, + trainTestSplit, + delta, + false + ); long[] timestamps = dataWithKeys.timestampsMs; double[][] data = dataWithKeys.data; when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); @@ -858,21 +674,6 @@ public int compare(Entry p1, Entry p2) { assertTrue("precision is " + prec, prec >= precisionThreshold); assertTrue("recall is " + recall, recall >= recallThreshold); - LOG.info("Interval {}, Precision: {}, recall: {}", detectorIntervalMins, prec, recall); - } - - public int searchInsert(long[] timestamps, long target) { - int pivot, left = 0, right = timestamps.length - 1; - while (left <= right) { - pivot = left + (right - left) / 2; - if (timestamps[pivot] == target) - return pivot; - if (target < timestamps[pivot]) - right = pivot - 1; - else - left = pivot + 1; - } - return left; } public void testAccuracyTenMinuteInterval() throws Exception { diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java new file mode 100644 index 000000000..5d2849401 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java @@ -0,0 +1,342 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.temporal.ChronoUnit; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.lucene.tests.util.TimeUnits; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.ClusterServiceUtils; + +import test.org.opensearch.ad.util.LabelledAnomalyGenerator; +import test.org.opensearch.ad.util.MultiDimDataWithTime; + +import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; +import com.google.common.collect.ImmutableList; + +@TimeoutSuite(millis = 60 * TimeUnits.MINUTE) // rcf may be slow due to bounding box cache disabled +public class HCADModelPerfTests extends AbstractCosineDataTest { + + /** + * A template to perform precision/recall test by simulating HCAD logic with only one entity. + * + * @param detectorIntervalMins Detector interval + * @param precisionThreshold precision threshold + * @param recallThreshold recall threshold + * @param baseDimension the number of dimensions + * @param anomalyIndependent whether anomalies in each dimension is generated independently + * @throws Exception when failing to create anomaly detector or creating training data + */ + @SuppressWarnings("unchecked") + private void averageAccuracyTemplate( + int detectorIntervalMins, + float precisionThreshold, + float recallThreshold, + int baseDimension, + boolean anomalyIndependent + ) throws Exception { + int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; + int trainTestSplit = 300; + // detector interval + int interval = detectorIntervalMins; + int delta = 60000 * interval; + + int numberOfTrials = 10; + double prec = 0; + double recall = 0; + double totalPrec = 0; + double totalRecall = 0; + + // training data ranges from timestamps[0] ~ timestamps[trainTestSplit-1] + // set up detector + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(interval, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .setShingleSize(AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE) + .build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + for (int z = 1; z <= numberOfTrials; z++) { + long seed = z; + LOG.info("seed = " + seed); + // recreate in each loop; otherwise, we will have heap overflow issue. + searchFeatureDao = mock(SearchFeatureDao.class); + clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + featureManager = new FeatureManager( + searchFeatureDao, + interpolator, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + ); + + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + seed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); + + // create labelled data + MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator + .getMultiDimData( + dataSize + detector.getShingleSize() - 1, + 50, + 100, + 5, + seed, + baseDimension, + false, + trainTestSplit, + delta, + anomalyIndependent + ); + + long[] timestamps = dataWithKeys.timestampsMs; + double[][] data = dataWithKeys.data; + when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(timestamps[0])); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + doAnswer(invocation -> { + List> ranges = invocation.getArgument(1); + List> coldStartSamples = new ArrayList<>(); + + Collections.sort(ranges, new Comparator>() { + @Override + public int compare(Entry p1, Entry p2) { + return Long.compare(p1.getKey(), p2.getKey()); + } + }); + for (int j = 0; j < ranges.size(); j++) { + Entry range = ranges.get(j); + Long start = range.getKey(); + int valueIndex = searchInsert(timestamps, start); + coldStartSamples.add(Optional.of(data[valueIndex])); + } + + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entity = Entity.createSingleAttributeEntity("field", entityName + z); + EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); + ModelState modelState = new ModelState<>( + model, + entity.getModelId(detectorId).get(), + detector.getDetectorId(), + ModelType.ENTITY.getName(), + clock, + priority + ); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + listener = ActionListener.wrap(() -> { + released.set(true); + inProgressLatch.countDown(); + }); + + entityColdStarter.trainModel(entity, detector.getDetectorId(), modelState, listener); + + checkSemaphoreRelease(); + assertTrue(model.getTrcf().isPresent()); + + int tp = 0; + int fp = 0; + int fn = 0; + long[] changeTimestamps = dataWithKeys.changeTimeStampsMs; + + for (int j = trainTestSplit; j < data.length; j++) { + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + if (result.getGrade() > 0) { + if (changeTimestamps[j] == 0) { + fp++; + } else { + tp++; + } + } else { + if (changeTimestamps[j] != 0) { + fn++; + } + // else ok + } + } + + if (tp + fp == 0) { + prec = 1; + } else { + prec = tp * 1.0 / (tp + fp); + } + + if (tp + fn == 0) { + recall = 1; + } else { + recall = tp * 1.0 / (tp + fn); + } + + totalPrec += prec; + totalRecall += recall; + modelState = null; + dataWithKeys = null; + reset(searchFeatureDao); + searchFeatureDao = null; + clusterService = null; + } + + double avgPrec = totalPrec / numberOfTrials; + double avgRecall = totalRecall / numberOfTrials; + LOG.info("{} features, Interval {}, Precision: {}, recall: {}", baseDimension, detectorIntervalMins, avgPrec, avgRecall); + assertTrue("average precision is " + avgPrec, avgPrec >= precisionThreshold); + assertTrue("average recall is " + avgRecall, avgRecall >= recallThreshold); + } + + /** + * Split average accuracy tests into two in case of time out per test. + * @throws Exception when failing to perform tests + */ + public void testAverageAccuracyDependent() throws Exception { + LOG.info("Anomalies are injected dependently"); + + // 10 minute interval, 4 features + averageAccuracyTemplate(10, 0.4f, 0.3f, 4, false); + + // 10 minute interval, 2 features + averageAccuracyTemplate(10, 0.4f, 0.4f, 2, false); + + // 10 minute interval, 1 features + averageAccuracyTemplate(10, 0.4f, 0.4f, 1, false); + + // 5 minute interval, 4 features + averageAccuracyTemplate(5, 0.4f, 0.3f, 4, false); + + // 5 minute interval, 2 features + averageAccuracyTemplate(5, 0.4f, 0.4f, 2, false); + + // 5 minute interval, 1 features + averageAccuracyTemplate(5, 0.4f, 0.4f, 1, false); + } + + /** + * Split average accuracy tests into two in case of time out per test. + * @throws Exception when failing to perform tests + */ + public void testAverageAccuracyIndependent() throws Exception { + LOG.info("Anomalies are injected independently"); + + // 10 minute interval, 4 features + averageAccuracyTemplate(10, 0.3f, 0.1f, 4, true); + + // 10 minute interval, 2 features + averageAccuracyTemplate(10, 0.4f, 0.4f, 2, true); + + // 10 minute interval, 1 features + averageAccuracyTemplate(10, 0.3f, 0.4f, 1, true); + + // 5 minute interval, 4 features + averageAccuracyTemplate(5, 0.2f, 0.1f, 4, true); + + // 5 minute interval, 2 features + averageAccuracyTemplate(5, 0.4f, 0.4f, 2, true); + + // 5 minute interval, 1 features + averageAccuracyTemplate(5, 0.3f, 0.4f, 1, true); + } +} diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java index 33ecc247c..b41d43083 100644 --- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java @@ -29,6 +29,7 @@ import org.apache.hc.core5.http.ContentType; import org.apache.hc.core5.http.io.entity.StringEntity; +import org.hamcrest.CoreMatchers; import org.junit.Assert; import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.AnomalyDetectorRestTestCase; @@ -1185,7 +1186,7 @@ public void testDeleteAnomalyDetectorWhileRunning() throws Exception { new DetectionDateRange(now.minus(10, ChronoUnit.DAYS), now), client() ); - Assert.assertEquals(response.getStatusLine().toString(), "HTTP/1.1 200 OK"); + Assert.assertThat(response.getStatusLine().toString(), CoreMatchers.containsString("200 OK")); // Deleting detector should fail while its running Exception exception = expectThrows(IOException.class, () -> { deleteAnomalyDetector(detector.getDetectorId(), client()); }); diff --git a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java index f2ef3cc2d..f77c135fb 100644 --- a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java +++ b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java @@ -29,6 +29,7 @@ public class LabelledAnomalyGenerator { * @param useSlope whether to use slope in cosine data * @param historicalData the number of historical points relative to now * @param delta point interval + * @param anomalyIndependent whether anomalies in each dimension is generated independently * @return the labelled data */ public static MultiDimDataWithTime getMultiDimData( @@ -40,7 +41,8 @@ public static MultiDimDataWithTime getMultiDimData( int baseDimension, boolean useSlope, int historicalData, - int delta + int delta, + boolean anomalyIndependent ) { double[][] data = new double[num][]; long[] timestamps = new long[num]; @@ -66,14 +68,34 @@ public static MultiDimDataWithTime getMultiDimData( startEpochMs += delta; data[i] = new double[baseDimension]; double[] newChange = new double[baseDimension]; - for (int j = 0; j < baseDimension; j++) { - data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble(); - if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) { - double factor = 5 * (1 + noiseprg.nextDouble()); - double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise; - data[i][j] += newChange[j] = change; - changedTimestamps[i] = timestamps[i]; - changes[i] = newChange; + // decide whether we should inject anomalies at this point + // If we do this for each dimension, each dimension's anomalies + // are independent and will make it harder for RCF to detect anomalies. + // Doing it in point level will make each dimension's anomalies + // correlated. + if (anomalyIndependent) { + for (int j = 0; j < baseDimension; j++) { + data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble(); + if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) { + double factor = 5 * (1 + noiseprg.nextDouble()); + double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise; + data[i][j] += newChange[j] = change; + changedTimestamps[i] = timestamps[i]; + changes[i] = newChange; + } + } + } else { + boolean flag = (noiseprg.nextDouble() < 0.01); + for (int j = 0; j < baseDimension; j++) { + data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble(); + // adding the condition < 0.3 so there is still some variance if all features have an anomaly or not + if (flag && noiseprg.nextDouble() < 0.3) { + double factor = 5 * (1 + noiseprg.nextDouble()); + double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise; + data[i][j] += newChange[j] = change; + changedTimestamps[i] = timestamps[i]; + changes[i] = newChange; + } } } }