diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
new file mode 100644
index 000000000..f01b60c9a
--- /dev/null
+++ b/.github/workflows/benchmark.yml
@@ -0,0 +1,33 @@
+name: Build and Test Anomaly Detection
+on:
+  push:
+    branches:
+      - "*"
+  pull_request:
+    branches:
+      - "*"
+
+jobs:
+  Build-ad:
+    strategy:
+      matrix:
+        java: [17]
+      fail-fast: false
+
+    name: Build and Test Anomaly Detection Plugin
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Setup Java ${{ matrix.java }}
+        uses: actions/setup-java@v1
+        with:
+          java-version: ${{ matrix.java }}
+
+      # anomaly-detection
+      - name: Checkout AD
+        uses: actions/checkout@v2
+
+      - name: Build and Run Tests
+        run: |
+          ./gradlew ':test' --tests "org.opensearch.ad.ml.HCADModelPerfTests" -Dtests.seed=2AEBDBBAE75AC5E0 -Dtests.security.manager=false -Dtests.locale=es-CU -Dtests.timezone=Chile/EasterIsland -Dtest.logs=true -Dmodel-benchmark=true
+          ./gradlew integTest --tests "org.opensearch.ad.e2e.SingleStreamModelPerfIT" -Dtests.seed=60CDDB34427ACD0C -Dtests.security.manager=false -Dtests.locale=kab-DZ -Dtests.timezone=Asia/Hebron -Dtest.logs=true -Dmodel-benchmark=true
\ No newline at end of file
diff --git a/.github/workflows/test_build_multi_platform.yml b/.github/workflows/test_build_multi_platform.yml
index 50e86f785..ffc2aa8a3 100644
--- a/.github/workflows/test_build_multi_platform.yml
+++ b/.github/workflows/test_build_multi_platform.yml
@@ -60,7 +60,7 @@ jobs:
           ./gradlew assemble
       - name: Build and Run Tests
         run: |
-          ./gradlew build -Dtest.logs=true
+          ./gradlew build
       - name: Publish to Maven Local
         run: |
           ./gradlew publishToMavenLocal
diff --git a/build.gradle b/build.gradle
index cecc314ff..cf08bf7d5 100644
--- a/build.gradle
+++ b/build.gradle
@@ -139,7 +139,7 @@ configurations.all {
     if (it.state != Configuration.State.UNRESOLVED) return
     resolutionStrategy {
         force "joda-time:joda-time:${versions.joda}"
-        force "com.fasterxml.jackson.core:jackson-core:2.13.4"
+        force "com.fasterxml.jackson.core:jackson-core:2.14.0"
         force "commons-logging:commons-logging:${versions.commonslogging}"
         force "org.apache.httpcomponents:httpcore5:${versions.httpcore5}"
         force "commons-codec:commons-codec:${versions.commonscodec}"
@@ -219,6 +219,12 @@ test {
     }
     include '**/*Tests.class'
     systemProperty 'tests.security.manager', 'false'
+
+    if (System.getProperty("model-benchmark") == null || System.getProperty("model-benchmark") == "false") {
+        filter {
+            excludeTestsMatching "org.opensearch.ad.ml.HCADModelPerfTests"
+        }
+    }
 }
 
 task integTest(type: RestIntegTestTask) {
@@ -264,6 +270,12 @@ integTest {
         }
     }
 
+    if (System.getProperty("model-benchmark") == null || System.getProperty("model-benchmark") == "false") {
+        filter {
+            excludeTestsMatching "org.opensearch.ad.e2e.SingleStreamModelPerfIT"
+        }
+    }
+
     // The 'doFirst' delays till execution time.
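Both perf suites above are excluded from regular runs unless the JVM property model-benchmark is set to true. As a hypothetical illustration only (not part of this PR), the same gate could be expressed inside a test class with JUnit 4's Assume, which this repo's tests already use:

    @Before
    public void skipUnlessBenchmarking() {
        // Boolean.getBoolean(...) is true only when the JVM was started with -Dmodel-benchmark=true
        org.junit.Assume.assumeTrue(Boolean.getBoolean("model-benchmark"));
    }

The Gradle filters remain the authoritative gate; the integTest 'doFirst' block introduced by the comment above continues below.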
doFirst { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can @@ -664,9 +676,9 @@ dependencies { implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc3' // force Jackson version to avoid version conflict issue - implementation "com.fasterxml.jackson.core:jackson-core:2.13.4" - implementation "com.fasterxml.jackson.core:jackson-databind:2.13.4.2" - implementation "com.fasterxml.jackson.core:jackson-annotations:2.13.4" + implementation "com.fasterxml.jackson.core:jackson-core:2.14.0" + implementation "com.fasterxml.jackson.core:jackson-databind:2.14.0" + implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.0" // used for serializing/deserializing rcf models. implementation group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0' diff --git a/src/test/java/org/opensearch/ad/ODFERestTestCase.java b/src/test/java/org/opensearch/ad/ODFERestTestCase.java index cf89f5c85..44b3e1d2d 100644 --- a/src/test/java/org/opensearch/ad/ODFERestTestCase.java +++ b/src/test/java/org/opensearch/ad/ODFERestTestCase.java @@ -11,6 +11,8 @@ package org.opensearch.ad; +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; @@ -186,21 +188,18 @@ protected static void configureHttpsClient(RestClientBuilder builder, Settings s .ofNullable(System.getProperty("password")) .orElseThrow(() -> new RuntimeException("password is missing")); BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider - .setCredentials( - new AuthScope(new HttpHost("localhost", 9200)), - new UsernamePasswordCredentials(userName, password.toCharArray()) - ); + final AuthScope anyScope = new AuthScope(null, -1); + credentialsProvider.setCredentials(anyScope, new UsernamePasswordCredentials(userName, password.toCharArray())); try { final TlsStrategy tlsStrategy = ClientTlsStrategyBuilder .create() - .setSslContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()) - // disable the certificate since our testing cluster just uses the default security configuration .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) + .setSslContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()) .build(); - final PoolingAsyncClientConnectionManager connectionManager = PoolingAsyncClientConnectionManagerBuilder .create() + .setMaxConnPerRoute(DEFAULT_MAX_CONN_PER_ROUTE) + .setMaxConnTotal(DEFAULT_MAX_CONN_TOTAL) .setTlsStrategy(tlsStrategy) .build(); return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(connectionManager); @@ -212,8 +211,12 @@ protected static void configureHttpsClient(RestClientBuilder builder, Settings s final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT); final TimeValue socketTimeout = TimeValue .parseTimeValue(socketTimeoutString == null ? 
"60s" : socketTimeoutString, CLIENT_SOCKET_TIMEOUT); - builder - .setRequestConfigCallback(conf -> conf.setResponseTimeout(Timeout.ofMilliseconds(Math.toIntExact(socketTimeout.getMillis())))); + builder.setRequestConfigCallback(conf -> { + Timeout timeout = Timeout.ofMilliseconds(Math.toIntExact(socketTimeout.getMillis())); + conf.setConnectTimeout(timeout); + conf.setResponseTimeout(timeout); + return conf; + }); if (settings.hasValue(CLIENT_PATH_PREFIX)) { builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX)); } diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java new file mode 100644 index 000000000..9458bc740 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java @@ -0,0 +1,242 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.ad.TestHelpers.toHttpEntity; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.Charset; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.ad.ODFERestTestCase; +import org.opensearch.ad.TestHelpers; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.json.JsonXContent; + +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class AbstractSyntheticDataTest extends ODFERestTestCase { + /** + * In real time AD, we mute a node for a detector if that node keeps returning + * ResourceNotFoundException (5 times in a row). This is a problem for batch mode + * testing as we issue a large amount of requests quickly. Due to the speed, we + * won't be able to finish cold start before the ResourceNotFoundException mutes + * a node. Since our test case has only one node, there is no other nodes to fall + * back on. Here we disable such fault tolerance by setting max retries before + * muting to a large number and the actual wait time during muting to 0. 
+ * + * @throws IOException when failing to create http request body + */ + protected void disableResourceNotFoundFaultTolerence() throws IOException { + XContentBuilder settingCommand = JsonXContent.contentBuilder(); + + settingCommand.startObject(); + settingCommand.startObject("persistent"); + settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); + settingCommand.field(BACKOFF_MINUTES.getKey(), 0); + settingCommand.endObject(); + settingCommand.endObject(); + Request request = new Request("PUT", "/_cluster/settings"); + request.setJsonEntity(Strings.toString(settingCommand)); + + adminClient().performRequest(request); + } + + protected List getData(String datasetFileName) throws Exception { + JsonArray jsonArray = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(datasetFileName).toURI()), Charset.defaultCharset())) + .getAsJsonArray(); + List list = new ArrayList<>(jsonArray.size()); + jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); + return list; + } + + protected Map getDetectionResult(String detectorId, Instant begin, Instant end, RestClient client) { + try { + Request request = new Request( + "POST", + String.format(Locale.ROOT, "/_opendistro/_anomaly_detection/detectors/%s/_run", detectorId) + ); + request + .setJsonEntity( + String.format(Locale.ROOT, "{ \"period_start\": %d, \"period_end\": %d }", begin.toEpochMilli(), end.toEpochMilli()) + ); + return entityAsMap(client.performRequest(request)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected void bulkIndexTrainData( + String datasetName, + List data, + int trainTestSplit, + RestClient client, + String categoryField + ) throws Exception { + Request request = new Request("PUT", datasetName); + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" }," + + "\"%s\": { \"type\": \"keyword\"} } } }", + categoryField + ); + } + + request.setJsonEntity(requestBody); + setWarningHandler(request, false); + client.performRequest(request); + Thread.sleep(1_000); + + StringBuilder bulkRequestBuilder = new StringBuilder(); + for (int i = 0; i < trainTestSplit; i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + } + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + waitAllSyncheticDataIngested(trainTestSplit, datasetName, client); + } + + protected String createDetector( + String datasetName, + int intervalMinutes, + RestClient client, + String categoryField, + long windowDelayInMins + ) throws Exception { + Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/"); + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], 
\"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + windowDelayInMins + ); + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"category_field\": [\"%s\"], " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + categoryField, + windowDelayInMins + ); + } + + request.setJsonEntity(requestBody); + Map response = entityAsMap(client.performRequest(request)); + String detectorId = (String) response.get("_id"); + Thread.sleep(1_000); + return detectorId; + } + + protected void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { + int maxWaitCycles = 3; + do { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{\"query\": {" + + " \"match_all\": {}" + + " }," + + " \"size\": 1," + + " \"sort\": [" + + " {" + + " \"timestamp\": {" + + " \"order\": \"desc\"" + + " }" + + " }" + + " ]}" + ) + ); + // Make sure all of the test data has been ingested + // Expected response: + // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} + Response response = client.performRequest(request); + JsonObject json = JsonParser + .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) + .getAsJsonObject(); + JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); + if (hits != null + && hits.size() == 1 + && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { + break; + } else { + request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); + client.performRequest(request); + } + Thread.sleep(1_000); + } while (maxWaitCycles-- >= 0); + } + + protected void setWarningHandler(Request request, boolean strictDeprecationMode) { + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(strictDeprecationMode ? 
WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); + request.setOptions(options.build()); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java index 145cf664b..0c2cce832 100644 --- a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java +++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java @@ -12,212 +12,38 @@ package org.opensearch.ad.e2e; import static org.opensearch.ad.TestHelpers.toHttpEntity; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; -import java.io.File; -import java.io.FileReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.charset.Charset; import java.text.SimpleDateFormat; import java.time.Clock; import java.time.Instant; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit; -import java.util.AbstractMap.SimpleEntry; import java.util.ArrayList; import java.util.Calendar; import java.util.Date; -import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; import java.util.TimeZone; import java.util.concurrent.TimeUnit; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.message.BasicHeader; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.Logger; -import org.opensearch.ad.ODFERestTestCase; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.client.Request; -import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; import org.opensearch.client.RestClient; -import org.opensearch.client.WarningsHandler; -import org.opensearch.common.Strings; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.common.xcontent.support.XContentMapValues; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonObject; -import com.google.gson.JsonParser; import com.google.gson.JsonPrimitive; -public class DetectionResultEvalutationIT extends ODFERestTestCase { +public class DetectionResultEvalutationIT extends AbstractSyntheticDataTest { protected static final Logger LOG = (Logger) LogManager.getLogger(DetectionResultEvalutationIT.class); - public void testDataset() throws Exception { - // TODO: this test case will run for a much longer time and timeout with security enabled - if (!isHttps()) { - disableResourceNotFoundFaultTolerence(); - verifyAnomaly("synthetic", 1, 1500, 8, .4, .9, 10); - } - } - - private void verifyAnomaly( - String datasetName, - int intervalMinutes, - int trainTestSplit, - int shingleSize, - double minPrecision, - double minRecall, - double maxError - ) throws Exception { - RestClient client = client(); - - String dataFileName = String.format(Locale.ROOT, "data/%s.data", datasetName); - String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); - - List data = getData(dataFileName); - List> anomalies = getAnomalyWindows(labelFileName); - - bulkIndexTrainData(datasetName, data, trainTestSplit, client, null); - // single-stream detector can use window delay 0 here 
because we give the run api the actual data time - String detectorId = createDetector(datasetName, intervalMinutes, client, null, 0); - simulateSingleStreamStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); - bulkIndexTestData(data, datasetName, trainTestSplit, client); - double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client); - verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError); - } - - private void verifyTestResults( - double[] testResults, - List> anomalies, - double minPrecision, - double minRecall, - double maxError - ) { - - double positives = testResults[0]; - double truePositives = testResults[1]; - double positiveAnomalies = testResults[2]; - double errors = testResults[3]; - - // precision = predicted anomaly points that are true / predicted anomaly points - double precision = positives > 0 ? truePositives / positives : 1; - assertTrue(precision >= minPrecision); - - // recall = windows containing predicted anomaly points / total anomaly windows - double recall = anomalies.size() > 0 ? positiveAnomalies / anomalies.size() : 1; - assertTrue(recall >= minRecall); - - assertTrue(errors <= maxError); - LOG.info("Precision: {}, Window recall: {}", precision, recall); - } - - private int isAnomaly(Instant time, List> labels) { - for (int i = 0; i < labels.size(); i++) { - Entry window = labels.get(i); - if (time.compareTo(window.getKey()) >= 0 && time.compareTo(window.getValue()) <= 0) { - return i; - } - } - return -1; - } - - private double[] getTestResults( - String detectorId, - List data, - int trainTestSplit, - int intervalMinutes, - List> anomalies, - RestClient client - ) throws Exception { - - double positives = 0; - double truePositives = 0; - Set positiveAnomalies = new HashSet<>(); - double errors = 0; - for (int i = trainTestSplit; i < data.size(); i++) { - Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); - Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - Map response = getDetectionResult(detectorId, begin, end, client); - double anomalyGrade = (double) response.get("anomalyGrade"); - if (anomalyGrade > 0) { - positives++; - int result = isAnomaly(begin, anomalies); - if (result != -1) { - truePositives++; - positiveAnomalies.add(result); - } - } - } catch (Exception e) { - errors++; - logger.error("failed to get detection results", e); - } - } - return new double[] { positives, truePositives, positiveAnomalies.size(), errors }; - } - - /** - * Simulate starting detector without waiting for job scheduler to run. Our build process is already very slow (takes 10 mins+) - * to finish integration tests. This method triggers run API to simulate job scheduler execution in a fast-paced way. 
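The helpers removed above move essentially verbatim into SingleStreamModelPerfIT later in this diff. For reference, the metrics that verifyTestResults asserts reduce to this arithmetic (a sketch; these method names are hypothetical):

    // precision = true-positive predictions / all predicted anomaly points
    static double precision(double positives, double truePositives) {
        return positives > 0 ? truePositives / positives : 1;
    }

    // window recall = anomaly windows hit by at least one prediction / all labeled windows
    static double windowRecall(double windowsHit, double totalWindows) {
        return totalWindows > 0 ? windowsHit / totalWindows : 1;
    }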
- * @param detectorId Detector Id - * @param data Data in Json format - * @param trainTestSplit Training data size - * @param shingleSize Shingle size - * @param intervalMinutes Detector Interval - * @param client OpenSearch Client - * @throws Exception when failing to query/indexing from/to OpenSearch - */ - private void simulateSingleStreamStartDetector( - String detectorId, - List data, - int trainTestSplit, - int shingleSize, - int intervalMinutes, - RestClient client - ) throws Exception { - - Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString())); - - Instant begin = null; - Instant end = null; - for (int i = 0; i < shingleSize; i++) { - begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); - end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - getDetectionResult(detectorId, begin, end, client); - } catch (Exception e) {} - } - // It takes time to wait for model initialization - long startTime = System.currentTimeMillis(); - do { - try { - Thread.sleep(5_000); - getDetectionResult(detectorId, begin, end, client); - break; - } catch (Exception e) { - long duration = System.currentTimeMillis() - startTime; - // we wait at most 60 secs - if (duration > 60_000) { - throw new RuntimeException(e); - } - } - } while (true); - } - /** * Simulate starting the given HCAD detector. * @param detectorId Detector Id @@ -274,224 +100,6 @@ private void simulateHCADStartDetector( } while (duration <= 60_000); } - private String createDetector(String datasetName, int intervalMinutes, RestClient client, String categoryField, long windowDelayInMins) - throws Exception { - Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/"); - String requestBody = null; - if (Strings.isEmpty(categoryField)) { - requestBody = String - .format( - Locale.ROOT, - "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" - + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " - + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" - + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " - + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " - + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," - + "\"schema_version\": 0 }", - datasetName, - intervalMinutes, - windowDelayInMins - ); - } else { - requestBody = String - .format( - Locale.ROOT, - "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" - + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " - + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" - + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " - + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " - + "\"category_field\": [\"%s\"], " - + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," - + "\"schema_version\": 0 }", - datasetName, - intervalMinutes, - categoryField, - windowDelayInMins - ); - } - - request.setJsonEntity(requestBody); - Map response = 
entityAsMap(client.performRequest(request)); - String detectorId = (String) response.get("_id"); - Thread.sleep(1_000); - return detectorId; - } - - private List> getAnomalyWindows(String labalFileName) throws Exception { - JsonArray windows = JsonParser - .parseReader(new FileReader(new File(getClass().getResource(labalFileName).toURI()), Charset.defaultCharset())) - .getAsJsonArray(); - List> anomalies = new ArrayList<>(windows.size()); - for (int i = 0; i < windows.size(); i++) { - JsonArray window = windows.get(i).getAsJsonArray(); - Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(0).getAsString())); - Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(1).getAsString())); - anomalies.add(new SimpleEntry<>(begin, end)); - } - return anomalies; - } - - private void bulkIndexTrainData(String datasetName, List data, int trainTestSplit, RestClient client, String categoryField) - throws Exception { - Request request = new Request("PUT", datasetName); - String requestBody = null; - if (Strings.isEmpty(categoryField)) { - requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," - + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; - } else { - requestBody = String - .format( - Locale.ROOT, - "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," - + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" }," - + "\"%s\": { \"type\": \"keyword\"} } } }", - categoryField - ); - } - - request.setJsonEntity(requestBody); - setWarningHandler(request, false); - client.performRequest(request); - Thread.sleep(1_000); - - StringBuilder bulkRequestBuilder = new StringBuilder(); - for (int i = 0; i < trainTestSplit; i++) { - bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); - bulkRequestBuilder.append(data.get(i).toString()).append("\n"); - } - TestHelpers - .makeRequest( - client, - "POST", - "_bulk?refresh=true", - null, - toHttpEntity(bulkRequestBuilder.toString()), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) - ); - Thread.sleep(1_000); - waitAllSyncheticDataIngested(trainTestSplit, datasetName, client); - } - - private void bulkIndexTestData(List data, String datasetName, int trainTestSplit, RestClient client) throws Exception { - StringBuilder bulkRequestBuilder = new StringBuilder(); - for (int i = trainTestSplit; i < data.size(); i++) { - bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); - bulkRequestBuilder.append(data.get(i).toString()).append("\n"); - } - TestHelpers - .makeRequest( - client, - "POST", - "_bulk?refresh=true", - null, - toHttpEntity(bulkRequestBuilder.toString()), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) - ); - Thread.sleep(1_000); - waitAllSyncheticDataIngested(data.size(), datasetName, client); - } - - private void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { - int maxWaitCycles = 3; - do { - Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); - request - .setJsonEntity( - String - .format( - Locale.ROOT, - "{\"query\": {" - + " \"match_all\": {}" - + " }," - + " \"size\": 1," - + " \"sort\": [" - + " {" - + " \"timestamp\": {" - + " \"order\": \"desc\"" - + " }" - + " }" - + " ]}" - ) - ); - // Make sure all of the test data has 
been ingested - // Expected response: - // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} - Response response = client.performRequest(request); - JsonObject json = JsonParser - .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) - .getAsJsonObject(); - JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); - if (hits != null - && hits.size() == 1 - && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { - break; - } else { - request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); - client.performRequest(request); - } - Thread.sleep(1_000); - } while (maxWaitCycles-- >= 0); - } - - private void setWarningHandler(Request request, boolean strictDeprecationMode) { - RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); - options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); - request.setOptions(options.build()); - } - - private List getData(String datasetFileName) throws Exception { - JsonArray jsonArray = JsonParser - .parseReader(new FileReader(new File(getClass().getResource(datasetFileName).toURI()), Charset.defaultCharset())) - .getAsJsonArray(); - List list = new ArrayList<>(jsonArray.size()); - jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); - return list; - } - - private Map getDetectionResult(String detectorId, Instant begin, Instant end, RestClient client) { - try { - Request request = new Request( - "POST", - String.format(Locale.ROOT, "/_opendistro/_anomaly_detection/detectors/%s/_run", detectorId) - ); - request - .setJsonEntity( - String.format(Locale.ROOT, "{ \"period_start\": %d, \"period_end\": %d }", begin.toEpochMilli(), end.toEpochMilli()) - ); - return entityAsMap(client.performRequest(request)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - /** - * In real time AD, we mute a node for a detector if that node keeps returning - * ResourceNotFoundException (5 times in a row). This is a problem for batch mode - * testing as we issue a large amount of requests quickly. Due to the speed, we - * won't be able to finish cold start before the ResourceNotFoundException mutes - * a node. Since our test case has only one node, there is no other nodes to fall - * back on. Here we disable such fault tolerance by setting max retries before - * muting to a large number and the actual wait time during muting to 0. 
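The ingestion-wait loop above polls the newest document _id and forces a _refresh between attempts. Its bounded-polling shape, extracted as a standalone sketch (a hypothetical helper, not part of this PR):

    static void pollUntil(java.util.function.BooleanSupplier done, Runnable betweenAttempts, int maxCycles)
        throws InterruptedException {
        for (int cycle = 0; cycle <= maxCycles; cycle++) {
            if (done.getAsBoolean()) {
                return;                // e.g. the last expected doc is visible in a search
            }
            betweenAttempts.run();     // e.g. POST /<index>/_refresh
            Thread.sleep(1_000L);
        }
    }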
- * - * @throws IOException when failing to create http request body - */ - private void disableResourceNotFoundFaultTolerence() throws IOException { - XContentBuilder settingCommand = JsonXContent.contentBuilder(); - - settingCommand.startObject(); - settingCommand.startObject("persistent"); - settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); - settingCommand.field(BACKOFF_MINUTES.getKey(), 0); - settingCommand.endObject(); - settingCommand.endObject(); - Request request = new Request("PUT", "/_cluster/settings"); - request.setJsonEntity(Strings.toString(settingCommand)); - - adminClient().performRequest(request); - } - public void testValidationIntervalRecommendation() throws Exception { RestClient client = client(); long recDetectorIntervalMillis = 180000; diff --git a/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java new file mode 100644 index 000000000..710c8b6ec --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java @@ -0,0 +1,230 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.ad.TestHelpers.toHttpEntity; + +import java.io.File; +import java.io.FileReader; +import java.nio.charset.Charset; +import java.time.Instant; +import java.time.format.DateTimeFormatter; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.ad.TestHelpers; +import org.opensearch.client.RestClient; + +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class SingleStreamModelPerfIT extends AbstractSyntheticDataTest { + protected static final Logger LOG = (Logger) LogManager.getLogger(SingleStreamModelPerfIT.class); + + public void testDataset() throws Exception { + // TODO: this test case will run for a much longer time and timeout with security enabled + if (!isHttps()) { + disableResourceNotFoundFaultTolerence(); + verifyAnomaly("synthetic", 1, 1500, 8, .4, .9, 10); + } + } + + private void verifyAnomaly( + String datasetName, + int intervalMinutes, + int trainTestSplit, + int shingleSize, + double minPrecision, + double minRecall, + double maxError + ) throws Exception { + RestClient client = client(); + + String dataFileName = String.format(Locale.ROOT, "data/%s.data", datasetName); + String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); + + List data = getData(dataFileName); + List> anomalies = getAnomalyWindows(labelFileName); + + bulkIndexTrainData(datasetName, data, trainTestSplit, client, null); + // single-stream detector can use window delay 0 here because we give the run api the actual data time + String detectorId = createDetector(datasetName, intervalMinutes, client, null, 0); + 
simulateSingleStreamStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); + bulkIndexTestData(data, datasetName, trainTestSplit, client); + double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client); + verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError); + } + + private void verifyTestResults( + double[] testResults, + List> anomalies, + double minPrecision, + double minRecall, + double maxError + ) { + + double positives = testResults[0]; + double truePositives = testResults[1]; + double positiveAnomalies = testResults[2]; + double errors = testResults[3]; + + // precision = predicted anomaly points that are true / predicted anomaly points + double precision = positives > 0 ? truePositives / positives : 1; + assertTrue(precision >= minPrecision); + + // recall = windows containing predicted anomaly points / total anomaly windows + double recall = anomalies.size() > 0 ? positiveAnomalies / anomalies.size() : 1; + assertTrue(recall >= minRecall); + + assertTrue(errors <= maxError); + LOG.info("Precision: {}, Window recall: {}", precision, recall); + } + + private double[] getTestResults( + String detectorId, + List data, + int trainTestSplit, + int intervalMinutes, + List> anomalies, + RestClient client + ) throws Exception { + + double positives = 0; + double truePositives = 0; + Set positiveAnomalies = new HashSet<>(); + double errors = 0; + for (int i = trainTestSplit; i < data.size(); i++) { + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + Map response = getDetectionResult(detectorId, begin, end, client); + double anomalyGrade = (double) response.get("anomalyGrade"); + if (anomalyGrade > 0) { + positives++; + int result = isAnomaly(begin, anomalies); + if (result != -1) { + truePositives++; + positiveAnomalies.add(result); + } + } + } catch (Exception e) { + errors++; + logger.error("failed to get detection results", e); + } + } + return new double[] { positives, truePositives, positiveAnomalies.size(), errors }; + } + + private List> getAnomalyWindows(String labalFileName) throws Exception { + JsonArray windows = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(labalFileName).toURI()), Charset.defaultCharset())) + .getAsJsonArray(); + List> anomalies = new ArrayList<>(windows.size()); + for (int i = 0; i < windows.size(); i++) { + JsonArray window = windows.get(i).getAsJsonArray(); + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(0).getAsString())); + Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(1).getAsString())); + anomalies.add(new SimpleEntry<>(begin, end)); + } + return anomalies; + } + + /** + * Simulate starting detector without waiting for job scheduler to run. Our build process is already very slow (takes 10 mins+) + * to finish integration tests. This method triggers run API to simulate job scheduler execution in a fast-paced way. 
+ * @param detectorId Detector Id + * @param data Data in Json format + * @param trainTestSplit Training data size + * @param shingleSize Shingle size + * @param intervalMinutes Detector Interval + * @param client OpenSearch Client + * @throws Exception when failing to query/indexing from/to OpenSearch + */ + private void simulateSingleStreamStartDetector( + String detectorId, + List data, + int trainTestSplit, + int shingleSize, + int intervalMinutes, + RestClient client + ) throws Exception { + + Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString())); + + Instant begin = null; + Instant end = null; + for (int i = 0; i < shingleSize; i++) { + begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); + end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + getDetectionResult(detectorId, begin, end, client); + } catch (Exception e) {} + } + // It takes time to wait for model initialization + long startTime = System.currentTimeMillis(); + do { + try { + Thread.sleep(5_000); + getDetectionResult(detectorId, begin, end, client); + break; + } catch (Exception e) { + long duration = System.currentTimeMillis() - startTime; + // we wait at most 60 secs + if (duration > 60_000) { + throw new RuntimeException(e); + } + } + } while (true); + } + + private void bulkIndexTestData(List data, String datasetName, int trainTestSplit, RestClient client) throws Exception { + StringBuilder bulkRequestBuilder = new StringBuilder(); + for (int i = trainTestSplit; i < data.size(); i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + } + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + waitAllSyncheticDataIngested(data.size(), datasetName, client); + } + + private int isAnomaly(Instant time, List> labels) { + for (int i = 0; i < labels.size(); i++) { + Entry window = labels.get(i); + if (time.compareTo(window.getKey()) >= 0 && time.compareTo(window.getValue()) <= 0) { + return i; + } + } + return -1; + } +} diff --git a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java new file mode 100644 index 000000000..5159fe8ba --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java @@ -0,0 +1,256 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
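The model-initialization wait in simulateSingleStreamStartDetector above retries every 5 seconds and rethrows once 60 seconds have elapsed. The same retry-until-deadline shape as a generic sketch (hypothetical, assuming the call throws until the model is ready):

    static <T> T retryUntilDeadline(java.util.concurrent.Callable<T> call, long deadlineMillis) throws Exception {
        long start = System.currentTimeMillis();
        while (true) {
            try {
                Thread.sleep(5_000L);
                return call.call();    // succeeds once the model has initialized
            } catch (Exception e) {
                if (System.currentTimeMillis() - start > deadlineMillis) {
                    throw e;           // give up after the deadline, surfacing the last failure
                }
            }
        }
    }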
+ */
+
+package org.opensearch.ad.ml;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE;
+
+import java.time.Clock;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.opensearch.Version;
+import org.opensearch.action.ActionListener;
+import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.GetResponse;
+import org.opensearch.ad.AbstractADTest;
+import org.opensearch.ad.AnomalyDetectorPlugin;
+import org.opensearch.ad.MemoryTracker;
+import org.opensearch.ad.NodeStateManager;
+import org.opensearch.ad.TestHelpers;
+import org.opensearch.ad.dataprocessor.IntegerSensitiveSingleFeatureLinearUniformInterpolator;
+import org.opensearch.ad.dataprocessor.Interpolator;
+import org.opensearch.ad.dataprocessor.LinearUniformInterpolator;
+import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator;
+import org.opensearch.ad.feature.FeatureManager;
+import org.opensearch.ad.feature.SearchFeatureDao;
+import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.model.Entity;
+import org.opensearch.ad.model.IntervalTimeConfiguration;
+import org.opensearch.ad.ratelimit.CheckpointWriteWorker;
+import org.opensearch.ad.settings.AnomalyDetectorSettings;
+import org.opensearch.ad.settings.EnabledSetting;
+import org.opensearch.ad.util.ClientUtil;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.cluster.node.DiscoveryNodeRole;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Setting;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.test.ClusterServiceUtils;
+import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.ThreadPool;
+
+import com.google.common.collect.ImmutableList;
+
+public class AbstractCosineDataTest extends AbstractADTest {
+    int numMinSamples;
+    String modelId;
+    String entityName;
+    String detectorId;
+    ModelState<EntityModel> modelState;
+    Clock clock;
+    float priority;
+    EntityColdStarter entityColdStarter;
+    NodeStateManager stateManager;
+    SearchFeatureDao searchFeatureDao;
+    Interpolator interpolator;
+    CheckpointDao checkpoint;
+    FeatureManager featureManager;
+    Settings settings;
+    ThreadPool threadPool;
+    AtomicBoolean released;
+    Runnable releaseSemaphore;
+    ActionListener<Void> listener;
+    CountDownLatch inProgressLatch;
+    CheckpointWriteWorker checkpointWriteQueue;
+    Entity entity;
+    AnomalyDetector detector;
+    long rcfSeed;
+    ModelManager modelManager;
+    ClientUtil clientUtil;
+    ClusterService clusterService;
+    ClusterSettings clusterSettings;
+    DiscoveryNode discoveryNode;
+    Set<Setting<?>> nodestateSetting;
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES;
+
+        clock = mock(Clock.class);
+
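+        // The stubs below pin the mocked Clock: instant() is frozen at setup time and
+        // millis() at 1602401500000L (2020-10-11T07:31:40Z), keeping cold-start timing
+        // deterministic across runs. A Mockito-free alternative would be
+        // Clock.fixed(Instant.ofEpochMilli(1602401500000L), ZoneOffset.UTC), whose
+        // instant() and millis() always return that same point in time.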
when(clock.instant()).thenReturn(Instant.now()); + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + settings = Settings.EMPTY; + + Client client = mock(Client.class); + clientUtil = mock(ClientUtil.class); + + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .build(); + when(clock.millis()).thenReturn(1602401500000L); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); + nodestateSetting.add(BACKOFF_MINUTES); + nodestateSetting.add(CHECKPOINT_SAVING_FREQ); + clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + + discoveryNode = new DiscoveryNode( + "node1", + OpenSearchTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.BUILT_IN_ROLES, + Version.CURRENT + ); + + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + + SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = + new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); + interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); + + searchFeatureDao = mock(SearchFeatureDao.class); + checkpoint = mock(CheckpointDao.class); + + featureManager = new FeatureManager( + searchFeatureDao, + interpolator, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + ); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + rcfSeed = 2051L; + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.TRUE); + + detectorId = "123"; + modelId = "123_entity_abc"; + entityName = "abc"; + priority = 0.3f; + entity = Entity.createSingleAttributeEntity("field", entityName); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + releaseSemaphore = () -> 
{ + released.set(true); + inProgressLatch.countDown(); + }; + listener = ActionListener.wrap(releaseSemaphore); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); + } + + protected void checkSemaphoreRelease() throws InterruptedException { + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + assertTrue(released.get()); + } + + public int searchInsert(long[] timestamps, long target) { + int pivot, left = 0, right = timestamps.length - 1; + while (left <= right) { + pivot = left + (right - left) / 2; + if (timestamps[pivot] == target) + return pivot; + if (target < timestamps[pivot]) + right = pivot - 1; + else + left = pivot + 1; + } + return left; + } +} diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index ddea2510b..ebf0d2d8b 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -19,9 +19,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.io.IOException; import java.time.Clock; @@ -31,7 +28,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; -import java.util.HashSet; import java.util.List; import java.util.Map.Entry; import java.util.Optional; @@ -39,48 +35,29 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.junit.AfterClass; import org.junit.BeforeClass; -import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.AbstractADTest; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.AnomalyDetectionException; -import org.opensearch.ad.dataprocessor.IntegerSensitiveSingleFeatureLinearUniformInterpolator; -import org.opensearch.ad.dataprocessor.Interpolator; -import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; -import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; -import 
org.opensearch.ad.ratelimit.CheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.settings.EnabledSetting; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; -import org.opensearch.test.ClusterServiceUtils; -import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ThreadPool; import test.org.opensearch.ad.util.LabelledAnomalyGenerator; import test.org.opensearch.ad.util.MLUtil; @@ -91,33 +68,7 @@ import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; -public class EntityColdStarterTests extends AbstractADTest { - int numMinSamples; - String modelId; - String entityName; - String detectorId; - ModelState modelState; - Clock clock; - float priority; - EntityColdStarter entityColdStarter; - NodeStateManager stateManager; - SearchFeatureDao searchFeatureDao; - Interpolator interpolator; - CheckpointDao checkpoint; - FeatureManager featureManager; - Settings settings; - ThreadPool threadPool; - AtomicBoolean released; - Runnable releaseSemaphore; - ActionListener listener; - CountDownLatch inProgressLatch; - CheckpointWriteWorker checkpointWriteQueue; - Entity entity; - AnomalyDetector detector; - long rcfSeed; - ModelManager modelManager; - ClientUtil clientUtil; - ClusterService clusterService; +public class EntityColdStarterTests extends AbstractCosineDataTest { @BeforeClass public static void initOnce() { @@ -136,158 +87,12 @@ public static void clearOnce() { EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); } - @SuppressWarnings("unchecked") - @Override - public void setUp() throws Exception { - super.setUp(); - numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; - - clock = mock(Clock.class); - when(clock.instant()).thenReturn(Instant.now()); - - threadPool = mock(ThreadPool.class); - setUpADThreadPool(threadPool); - - settings = Settings.EMPTY; - - Client client = mock(Client.class); - clientUtil = mock(ClientUtil.class); - - detector = TestHelpers.AnomalyDetectorBuilder - .newInstance() - .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) - .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) - .build(); - when(clock.millis()).thenReturn(1602401500000L); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - - listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - - return null; - }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); - - Set> nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); - nodestateSetting.add(BACKOFF_MINUTES); - nodestateSetting.add(CHECKPOINT_SAVING_FREQ); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); - - DiscoveryNode discoveryNode = new DiscoveryNode( - "node1", - OpenSearchTestCase.buildNewFakeTransportAddress(), 
-            Collections.emptyMap(),
-            DiscoveryNodeRole.BUILT_IN_ROLES,
-            Version.CURRENT
-        );
-
-        clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings);
-
-        stateManager = new NodeStateManager(
-            client,
-            xContentRegistry(),
-            settings,
-            clientUtil,
-            clock,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
-            clusterService
-        );
-
-        SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator =
-            new IntegerSensitiveSingleFeatureLinearUniformInterpolator();
-        interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator);
-
-        searchFeatureDao = mock(SearchFeatureDao.class);
-        checkpoint = mock(CheckpointDao.class);
-
-        featureManager = new FeatureManager(
-            searchFeatureDao,
-            interpolator,
-            clock,
-            AnomalyDetectorSettings.MAX_TRAIN_SAMPLE,
-            AnomalyDetectorSettings.MAX_SAMPLE_STRIDE,
-            AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS,
-            AnomalyDetectorSettings.MIN_TRAIN_SAMPLES,
-            AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING,
-            AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE,
-            AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE,
-            AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
-            threadPool,
-            AnomalyDetectorPlugin.AD_THREAD_POOL_NAME
-        );
-
-        checkpointWriteQueue = mock(CheckpointWriteWorker.class);
-
-        rcfSeed = 2051L;
-        entityColdStarter = new EntityColdStarter(
-            clock,
-            threadPool,
-            stateManager,
-            AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
-            AnomalyDetectorSettings.NUM_TREES,
-            AnomalyDetectorSettings.TIME_DECAY,
-            numMinSamples,
-            AnomalyDetectorSettings.MAX_SAMPLE_STRIDE,
-            AnomalyDetectorSettings.MAX_TRAIN_SAMPLE,
-            interpolator,
-            searchFeatureDao,
-            AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE,
-            featureManager,
-            settings,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
-            checkpointWriteQueue,
-            rcfSeed,
-            AnomalyDetectorSettings.MAX_COLD_START_ROUNDS
-        );
-        EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.TRUE);
-
-        detectorId = "123";
-        modelId = "123_entity_abc";
-        entityName = "abc";
-        priority = 0.3f;
-        entity = Entity.createSingleAttributeEntity("field", entityName);
-
-        released = new AtomicBoolean();
-
-        inProgressLatch = new CountDownLatch(1);
-        releaseSemaphore = () -> {
-            released.set(true);
-            inProgressLatch.countDown();
-        };
-        listener = ActionListener.wrap(releaseSemaphore);
-
-        modelManager = new ModelManager(
-            mock(CheckpointDao.class),
-            mock(Clock.class),
-            AnomalyDetectorSettings.NUM_TREES,
-            AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
-            AnomalyDetectorSettings.TIME_DECAY,
-            AnomalyDetectorSettings.NUM_MIN_SAMPLES,
-            AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE,
-            AnomalyDetectorSettings.MIN_PREVIEW_SIZE,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
-            AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ,
-            entityColdStarter,
-            mock(FeatureManager.class),
-            mock(MemoryTracker.class),
-            settings,
-            clusterService
-
-        );
-    }
-
     @Override
     public void tearDown() throws Exception {
         EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.FALSE);
         super.tearDown();
     }

-    private void checkSemaphoreRelease() throws InterruptedException {
-        assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS));
-        assertTrue(released.get());
-    }
-
     // train using samples directly
     public void testTrainUsingSamples() throws InterruptedException {
         Queue<double[]> samples = MLUtil.createQueueSamples(numMinSamples);
@@ -757,7 +562,18 @@ private void accuracyTemplate(int detectorIntervalMins, float precisionThreshold
         LOG.info("seed = " + seed);
         // create labelled data
         MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator
-            .getMultiDimData(dataSize + detector.getShingleSize() - 1, 50, 100, 5, seed, baseDimension, false, trainTestSplit, delta);
+            .getMultiDimData(
+                dataSize + detector.getShingleSize() - 1,
+                50,
+                100,
+                5,
+                seed,
+                baseDimension,
+                false,
+                trainTestSplit,
+                delta,
+                false
+            );
         long[] timestamps = dataWithKeys.timestampsMs;
         double[][] data = dataWithKeys.data;
         when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]);
@@ -858,21 +674,6 @@ public int compare(Entry<Long, Long> p1, Entry<Long, Long> p2) {
         assertTrue("precision is " + prec, prec >= precisionThreshold);
         assertTrue("recall is " + recall, recall >= recallThreshold);
-
-        LOG.info("Interval {}, Precision: {}, recall: {}", detectorIntervalMins, prec, recall);
-    }
-
-    public int searchInsert(long[] timestamps, long target) {
-        int pivot, left = 0, right = timestamps.length - 1;
-        while (left <= right) {
-            pivot = left + (right - left) / 2;
-            if (timestamps[pivot] == target)
-                return pivot;
-            if (target < timestamps[pivot])
-                right = pivot - 1;
-            else
-                left = pivot + 1;
-        }
-        return left;
     }

     public void testAccuracyTenMinuteInterval() throws Exception {
diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java
new file mode 100644
index 000000000..5d2849401
--- /dev/null
+++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java
@@ -0,0 +1,342 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.ad.ml;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.when;
+
+import java.time.Clock;
+import java.time.temporal.ChronoUnit;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Optional;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.apache.lucene.tests.util.TimeUnits;
+import org.opensearch.action.ActionListener;
+import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.GetResponse;
+import org.opensearch.ad.AnomalyDetectorPlugin;
+import org.opensearch.ad.MemoryTracker;
+import org.opensearch.ad.TestHelpers;
+import org.opensearch.ad.feature.FeatureManager;
+import org.opensearch.ad.feature.SearchFeatureDao;
+import org.opensearch.ad.ml.ModelManager.ModelType;
+import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.model.Entity;
+import org.opensearch.ad.model.IntervalTimeConfiguration;
+import org.opensearch.ad.settings.AnomalyDetectorSettings;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.test.ClusterServiceUtils;
+
+import test.org.opensearch.ad.util.LabelledAnomalyGenerator;
+import test.org.opensearch.ad.util.MultiDimDataWithTime;
+
+import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite;
+import com.google.common.collect.ImmutableList;
+
+@TimeoutSuite(millis = 60 * TimeUnits.MINUTE) // rcf may be slow due to bounding box cache disabled
+public class HCADModelPerfTests extends AbstractCosineDataTest {
+
+    /**
+     * A template to perform precision/recall test by simulating HCAD logic with only one entity.
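+     * Each invocation runs a number of random trials: every trial regenerates labelled synthetic
+     * data with a fresh seed, cold-starts an entity model on the training split, scores the
+     * remaining points, and the precision/recall averaged over all trials are then checked
+     * against the given thresholds.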
+     *
+     * @param detectorIntervalMins detector interval in minutes
+     * @param precisionThreshold precision threshold
+     * @param recallThreshold recall threshold
+     * @param baseDimension the number of dimensions
+     * @param anomalyIndependent whether anomalies in each dimension are generated independently
+     * @throws Exception when failing to create the anomaly detector or the training data
+     */
+    @SuppressWarnings("unchecked")
+    private void averageAccuracyTemplate(
+        int detectorIntervalMins,
+        float precisionThreshold,
+        float recallThreshold,
+        int baseDimension,
+        boolean anomalyIndependent
+    ) throws Exception {
+        int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE;
+        int trainTestSplit = 300;
+        // detector interval
+        int interval = detectorIntervalMins;
+        int delta = 60000 * interval;
+
+        int numberOfTrials = 10;
+        double prec = 0;
+        double recall = 0;
+        double totalPrec = 0;
+        double totalRecall = 0;
+
+        // training data ranges from timestamps[0] ~ timestamps[trainTestSplit-1]
+        // set up detector
+        detector = TestHelpers.AnomalyDetectorBuilder
+            .newInstance()
+            .setDetectionInterval(new IntervalTimeConfiguration(interval, ChronoUnit.MINUTES))
+            .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5)))
+            .setShingleSize(AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE)
+            .build();
+
+        doAnswer(invocation -> {
+            ActionListener<GetResponse> listener = invocation.getArgument(2);
+
+            listener.onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX));
+            return null;
+        }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class));
+
+        for (int z = 1; z <= numberOfTrials; z++) {
+            long seed = z;
+            LOG.info("seed = " + seed);
+            // recreate in each loop; otherwise, we will run into heap overflow issues.
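+            // (Mockito mocks record every stubbed call, and the doAnswer stubs below close over
+            // each trial's data arrays; fresh instances let the previous trial's samples and
+            // trees be garbage collected.)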
+            searchFeatureDao = mock(SearchFeatureDao.class);
+            clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting);
+            clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings);
+
+            featureManager = new FeatureManager(
+                searchFeatureDao,
+                interpolator,
+                clock,
+                AnomalyDetectorSettings.MAX_TRAIN_SAMPLE,
+                AnomalyDetectorSettings.MAX_SAMPLE_STRIDE,
+                AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS,
+                AnomalyDetectorSettings.MIN_TRAIN_SAMPLES,
+                AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING,
+                AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE,
+                AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE,
+                AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES,
+                AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+                threadPool,
+                AnomalyDetectorPlugin.AD_THREAD_POOL_NAME
+            );
+
+            entityColdStarter = new EntityColdStarter(
+                clock,
+                threadPool,
+                stateManager,
+                AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
+                AnomalyDetectorSettings.NUM_TREES,
+                AnomalyDetectorSettings.TIME_DECAY,
+                numMinSamples,
+                AnomalyDetectorSettings.MAX_SAMPLE_STRIDE,
+                AnomalyDetectorSettings.MAX_TRAIN_SAMPLE,
+                interpolator,
+                searchFeatureDao,
+                AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE,
+                featureManager,
+                settings,
+                AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+                checkpointWriteQueue,
+                seed,
+                AnomalyDetectorSettings.MAX_COLD_START_ROUNDS
+            );
+
+            modelManager = new ModelManager(
+                mock(CheckpointDao.class),
+                mock(Clock.class),
+                AnomalyDetectorSettings.NUM_TREES,
+                AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
+                AnomalyDetectorSettings.TIME_DECAY,
+                AnomalyDetectorSettings.NUM_MIN_SAMPLES,
+                AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE,
+                AnomalyDetectorSettings.MIN_PREVIEW_SIZE,
+                AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+                AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ,
+                entityColdStarter,
+                mock(FeatureManager.class),
+                mock(MemoryTracker.class),
+                settings,
+                clusterService
+            );
+
+            // create labelled data
+            MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator
+                .getMultiDimData(
+                    dataSize + detector.getShingleSize() - 1,
+                    50,
+                    100,
+                    5,
+                    seed,
+                    baseDimension,
+                    false,
+                    trainTestSplit,
+                    delta,
+                    anomalyIndependent
+                );
+
+            long[] timestamps = dataWithKeys.timestampsMs;
+            double[][] data = dataWithKeys.data;
+            when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]);
+
+            doAnswer(invocation -> {
+                ActionListener<Optional<Long>> listener = invocation.getArgument(2);
+                listener.onResponse(Optional.of(timestamps[0]));
+                return null;
+            }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+
+            doAnswer(invocation -> {
+                List<Entry<Long, Long>> ranges = invocation.getArgument(1);
+                List<Optional<double[]>> coldStartSamples = new ArrayList<>();
+
+                Collections.sort(ranges, new Comparator<Entry<Long, Long>>() {
+                    @Override
+                    public int compare(Entry<Long, Long> p1, Entry<Long, Long> p2) {
+                        return Long.compare(p1.getKey(), p2.getKey());
+                    }
+                });
+                for (int j = 0; j < ranges.size(); j++) {
+                    Entry<Long, Long> range = ranges.get(j);
+                    Long start = range.getKey();
+                    int valueIndex = searchInsert(timestamps, start);
+                    coldStartSamples.add(Optional.of(data[valueIndex]));
+                }
+
+                ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
+                listener.onResponse(coldStartSamples);
+                return null;
+            }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());
+
+            entity = Entity.createSingleAttributeEntity("field", entityName + z);
+            EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null);
+            ModelState<EntityModel> modelState = new ModelState<>(
+                model,
+                entity.getModelId(detectorId).get(),
+                detector.getDetectorId(),
+                ModelType.ENTITY.getName(),
+                clock,
+                priority
+            );
+
+            released = new AtomicBoolean();
+
+            inProgressLatch = new CountDownLatch(1);
+            listener = ActionListener.wrap(() -> {
+                released.set(true);
+                inProgressLatch.countDown();
+            });
+
+            entityColdStarter.trainModel(entity, detector.getDetectorId(), modelState, listener);
+
+            checkSemaphoreRelease();
+            assertTrue(model.getTrcf().isPresent());
+
+            int tp = 0;
+            int fp = 0;
+            int fn = 0;
+            long[] changeTimestamps = dataWithKeys.changeTimeStampsMs;
+
+            for (int j = trainTestSplit; j < data.length; j++) {
+                ThresholdingResult result = modelManager
+                    .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize());
+                if (result.getGrade() > 0) {
+                    if (changeTimestamps[j] == 0) {
+                        fp++;
+                    } else {
+                        tp++;
+                    }
+                } else {
+                    if (changeTimestamps[j] != 0) {
+                        fn++;
+                    }
+                    // else ok
+                }
+            }
+
+            if (tp + fp == 0) {
+                prec = 1;
+            } else {
+                prec = tp * 1.0 / (tp + fp);
+            }
+
+            if (tp + fn == 0) {
+                recall = 1;
+            } else {
+                recall = tp * 1.0 / (tp + fn);
+            }
+
+            totalPrec += prec;
+            totalRecall += recall;
+            modelState = null;
+            dataWithKeys = null;
+            reset(searchFeatureDao);
+            searchFeatureDao = null;
+            clusterService = null;
+        }
+
+        double avgPrec = totalPrec / numberOfTrials;
+        double avgRecall = totalRecall / numberOfTrials;
+        LOG.info("{} features, Interval {}, Precision: {}, recall: {}", baseDimension, detectorIntervalMins, avgPrec, avgRecall);
+        assertTrue("average precision is " + avgPrec, avgPrec >= precisionThreshold);
+        assertTrue("average recall is " + avgRecall, avgRecall >= recallThreshold);
+    }
+
+    /**
+     * Average accuracy tests are split into two methods in case a single test times out.
+     * @throws Exception when failing to perform tests
+     */
+    public void testAverageAccuracyDependent() throws Exception {
+        LOG.info("Anomalies are injected dependently");
+
+        // 10 minute interval, 4 features
+        averageAccuracyTemplate(10, 0.4f, 0.3f, 4, false);
+
+        // 10 minute interval, 2 features
+        averageAccuracyTemplate(10, 0.4f, 0.4f, 2, false);
+
+        // 10 minute interval, 1 feature
+        averageAccuracyTemplate(10, 0.4f, 0.4f, 1, false);
+
+        // 5 minute interval, 4 features
+        averageAccuracyTemplate(5, 0.4f, 0.3f, 4, false);
+
+        // 5 minute interval, 2 features
+        averageAccuracyTemplate(5, 0.4f, 0.4f, 2, false);
+
+        // 5 minute interval, 1 feature
+        averageAccuracyTemplate(5, 0.4f, 0.4f, 1, false);
+    }
+
+    /**
+     * Average accuracy tests are split into two methods in case a single test times out.
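+     * Here anomalies are injected into each dimension independently, which is harder for RCF
+     * to detect; hence the lower precision/recall thresholds than in the dependent test.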
+     * @throws Exception when failing to perform tests
+     */
+    public void testAverageAccuracyIndependent() throws Exception {
+        LOG.info("Anomalies are injected independently");
+
+        // 10 minute interval, 4 features
+        averageAccuracyTemplate(10, 0.3f, 0.1f, 4, true);
+
+        // 10 minute interval, 2 features
+        averageAccuracyTemplate(10, 0.4f, 0.4f, 2, true);
+
+        // 10 minute interval, 1 feature
+        averageAccuracyTemplate(10, 0.3f, 0.4f, 1, true);
+
+        // 5 minute interval, 4 features
+        averageAccuracyTemplate(5, 0.2f, 0.1f, 4, true);
+
+        // 5 minute interval, 2 features
+        averageAccuracyTemplate(5, 0.4f, 0.4f, 2, true);
+
+        // 5 minute interval, 1 feature
+        averageAccuracyTemplate(5, 0.3f, 0.4f, 1, true);
+    }
+}
diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java
index 33ecc247c..b41d43083 100644
--- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java
+++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java
@@ -29,6 +29,7 @@

 import org.apache.hc.core5.http.ContentType;
 import org.apache.hc.core5.http.io.entity.StringEntity;
+import org.hamcrest.CoreMatchers;
 import org.junit.Assert;
 import org.opensearch.ad.AnomalyDetectorPlugin;
 import org.opensearch.ad.AnomalyDetectorRestTestCase;
@@ -1185,7 +1186,7 @@ public void testDeleteAnomalyDetectorWhileRunning() throws Exception {
             new DetectionDateRange(now.minus(10, ChronoUnit.DAYS), now),
             client()
         );
-        Assert.assertEquals(response.getStatusLine().toString(), "HTTP/1.1 200 OK");
+        Assert.assertThat(response.getStatusLine().toString(), CoreMatchers.containsString("200 OK"));

         // Deleting detector should fail while its running
         Exception exception = expectThrows(IOException.class, () -> { deleteAnomalyDetector(detector.getDetectorId(), client()); });
diff --git a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java
index f2ef3cc2d..f77c135fb 100644
--- a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java
+++ b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java
@@ -29,6 +29,7 @@ public class LabelledAnomalyGenerator {
      * @param useSlope whether to use slope in cosine data
      * @param historicalData the number of historical points relative to now
      * @param delta point interval
+     * @param anomalyIndependent whether anomalies in each dimension are generated independently
      * @return the labelled data
      */
     public static MultiDimDataWithTime getMultiDimData(
@@ -40,7 +41,8 @@ public static MultiDimDataWithTime getMultiDimData(
         int baseDimension,
         boolean useSlope,
         int historicalData,
-        int delta
+        int delta,
+        boolean anomalyIndependent
     ) {
         double[][] data = new double[num][];
         long[] timestamps = new long[num];
@@ -66,14 +68,34 @@ public static MultiDimDataWithTime getMultiDimData(
             startEpochMs += delta;
             data[i] = new double[baseDimension];
             double[] newChange = new double[baseDimension];
-            for (int j = 0; j < baseDimension; j++) {
-                data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble();
-                if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) {
-                    double factor = 5 * (1 + noiseprg.nextDouble());
-                    double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise;
-                    data[i][j] += newChange[j] = change;
-                    changedTimestamps[i] = timestamps[i];
-                    changes[i] = newChange;
+            // Decide whether to inject anomalies at this point.
+            // Deciding per dimension makes each dimension's anomalies independent,
+            // which is harder for RCF to detect; deciding at the point level makes
+            // the anomalies across dimensions correlated.
+            if (anomalyIndependent) {
+                for (int j = 0; j < baseDimension; j++) {
+                    data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble();
+                    if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) {
+                        double factor = 5 * (1 + noiseprg.nextDouble());
+                        double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise;
+                        data[i][j] += newChange[j] = change;
+                        changedTimestamps[i] = timestamps[i];
+                        changes[i] = newChange;
+                    }
+                }
+            } else {
+                boolean flag = (noiseprg.nextDouble() < 0.01);
+                for (int j = 0; j < baseDimension; j++) {
+                    data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble();
+                    // keep the < 0.3 condition so there is still some variance in whether each feature gets the anomaly
+                    if (flag && noiseprg.nextDouble() < 0.3) {
+                        double factor = 5 * (1 + noiseprg.nextDouble());
+                        double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise;
+                        data[i][j] += newChange[j] = change;
+                        changedTimestamps[i] = timestamps[i];
+                        changes[i] = newChange;
+                    }
                 }
             }
         }