diff --git a/build.gradle b/build.gradle index 468bffe3..469c4d38 100644 --- a/build.gradle +++ b/build.gradle @@ -12,6 +12,7 @@ buildscript { opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT") isSnapshot = "true" == System.getProperty("build.snapshot", "true") buildVersionQualifier = System.getProperty("build.version_qualifier", "") + kotlin_version = System.getProperty("kotlin.version", "1.8.21") } repositories { @@ -117,6 +118,7 @@ dependencies { implementation fileTree(dir: adJarDirectory, include: ["opensearch-time-series-analytics-${version}.jar"]) implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"]) compileOnly "org.opensearch:common-utils:${version}" + compileOnly "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}" // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${version}" diff --git a/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java new file mode 100644 index 00000000..3ade5b33 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java @@ -0,0 +1,182 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.GetAlertsRequest; +import org.opensearch.commons.alerting.action.GetAlertsResponse; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.commons.alerting.model.Table; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(SearchAlertsTool.TYPE) +public class SearchAlertsTool implements Tool { + public static final String TYPE = "SearchAlertsTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to search alerts."; + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String type; + @Getter + private String version; + + private Client client; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + public SearchAlertsTool(Client client) { + this.client = client; + + // probably keep this overridden output parser. need to ensure the output matches what's expected + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + @Override + public void run(Map parameters, ActionListener listener) { + final String tableSortOrder = parameters.getOrDefault("sortOrder", "asc"); + final String tableSortString = parameters.getOrDefault("sortString", "monitor_name.keyword"); + final int tableSize = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size")) + ? Integer.parseInt(parameters.get("size")) + : 20; + final int startIndex = parameters.containsKey("startIndex") && StringUtils.isNumeric(parameters.get("startIndex")) + ? Integer.parseInt(parameters.get("startIndex")) + : 0; + final String searchString = parameters.getOrDefault("searchString", null); + + // not exposing "missing" from the table, using default of null + final Table table = new Table(tableSortOrder, tableSortString, null, tableSize, startIndex, searchString); + + final String severityLevel = parameters.getOrDefault("severityLevel", "ALL"); + final String alertState = parameters.getOrDefault("alertState", "ALL"); + final String monitorId = parameters.getOrDefault("monitorId", null); + final String alertIndex = parameters.getOrDefault("alertIndex", null); + @SuppressWarnings("unchecked") + final List monitorIds = parameters.containsKey("monitorIds") + ? gson.fromJson(parameters.get("monitorIds"), List.class) + : null; + @SuppressWarnings("unchecked") + final List workflowIds = parameters.containsKey("workflowIds") + ? gson.fromJson(parameters.get("workflowIds"), List.class) + : null; + @SuppressWarnings("unchecked") + final List alertIds = parameters.containsKey("alertIds") ? gson.fromJson(parameters.get("alertIds"), List.class) : null; + + GetAlertsRequest getAlertsRequest = new GetAlertsRequest( + table, + severityLevel, + alertState, + monitorId, + alertIndex, + monitorIds, + workflowIds, + alertIds + ); + + // create response listener + // stringify the response, may change to a standard format in the future + ActionListener getAlertsListener = ActionListener.wrap(response -> { + StringBuilder sb = new StringBuilder(); + sb.append("Alerts=["); + for (Alert alert : response.getAlerts()) { + sb.append(alert.toString()); + } + sb.append("]"); + sb.append("TotalAlerts=").append(response.getTotalAlerts()); + listener.onResponse((T) sb.toString()); + }, e -> { + log.error("Failed to search alerts.", e); + listener.onFailure(e); + }); + + // execute the search + AlertingPluginInterface.INSTANCE.getAlerts((NodeClient) client, getAlertsRequest, getAlertsListener); + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Factory for the {@link SearchAlertsTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchAlertsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public SearchAlertsTool create(Map map) { + return new SearchAlertsTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } + +} diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java index 357668c9..de397521 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java @@ -30,7 +30,9 @@ import lombok.Getter; import lombok.Setter; +import lombok.extern.log4j.Log4j2; +@Log4j2 @ToolAnnotation(SearchAnomalyDetectorsTool.TYPE) public class SearchAnomalyDetectorsTool implements Tool { public static final String TYPE = "SearchAnomalyDetectorsTool"; @@ -140,7 +142,10 @@ public void run(Map parameters, ActionListener listener) sb.append("]"); sb.append("TotalAnomalyDetectors=").append(response.getHits().getTotalHits().value); listener.onResponse((T) sb.toString()); - }, e -> { listener.onFailure(e); }); + }, e -> { + log.error("Failed to search anomaly detectors.", e); + listener.onFailure(e); + }); adClient.searchAnomalyDetectors(searchDetectorRequest, searchDetectorListener); } diff --git a/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java new file mode 100644 index 00000000..21975080 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java @@ -0,0 +1,245 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.GetMonitorRequest; +import org.opensearch.commons.alerting.action.GetMonitorResponse; +import org.opensearch.commons.alerting.action.SearchMonitorRequest; +import org.opensearch.commons.alerting.model.Monitor; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.WildcardQueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(SearchMonitorsTool.TYPE) +public class SearchMonitorsTool implements Tool { + public static final String TYPE = "SearchMonitorsTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to search alerting monitors."; + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String type; + @Getter + private String version; + + private Client client; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + public SearchMonitorsTool(Client client) { + this.client = client; + + // probably keep this overridden output parser. need to ensure the output matches what's expected + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + // Response is currently in a simple string format including the list of monitors (only name and ID attached), and + // number of total monitors. The output will likely need to be updated, standardized, and include more fields in the + // future to cover a sufficient amount of potential questions the agent will need to handle. + @Override + public void run(Map parameters, ActionListener listener) { + final String monitorId = parameters.getOrDefault("monitorId", null); + final String monitorName = parameters.getOrDefault("monitorName", null); + final String monitorNamePattern = parameters.getOrDefault("monitorNamePattern", null); + final Boolean enabled = parameters.containsKey("enabled") ? Boolean.parseBoolean(parameters.get("enabled")) : null; + final Boolean hasTriggers = parameters.containsKey("hasTriggers") ? Boolean.parseBoolean(parameters.get("hasTriggers")) : null; + final String indices = parameters.getOrDefault("indices", null); + final String sortOrderStr = parameters.getOrDefault("sortOrder", "asc"); + final SortOrder sortOrder = "asc".equalsIgnoreCase(sortOrderStr) ? SortOrder.ASC : SortOrder.DESC; + final String sortString = parameters.getOrDefault("sortString", "monitor.name.keyword"); + final int size = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size")) + ? Integer.parseInt(parameters.get("size")) + : 20; + final int startIndex = parameters.containsKey("startIndex") && StringUtils.isNumeric(parameters.get("startIndex")) + ? Integer.parseInt(parameters.get("startIndex")) + : 0; + + // If a monitor ID is specified, all other params will be ignored. Simply return the monitor details based on that ID + // via the get monitor transport action + if (monitorId != null) { + GetMonitorRequest getMonitorRequest = new GetMonitorRequest(monitorId, 1L, RestRequest.Method.GET, null); + ActionListener getMonitorListener = ActionListener.wrap(response -> { + StringBuilder sb = new StringBuilder(); + Monitor monitor = response.getMonitor(); + if (monitor != null) { + sb.append("Monitors=["); + sb.append("{"); + sb.append("id=").append(monitor.getId()).append(","); + sb.append("name=").append(monitor.getName()); + sb.append("}]"); + sb.append("TotalMonitors=1"); + } else { + sb.append("Monitors=[]TotalMonitors=0"); + } + listener.onResponse((T) sb.toString()); + }, e -> { + log.error("Failed to search monitors.", e); + listener.onFailure(e); + }); + AlertingPluginInterface.INSTANCE.getMonitor((NodeClient) client, getMonitorRequest, getMonitorListener); + } else { + List mustList = new ArrayList(); + if (monitorName != null) { + mustList.add(new TermQueryBuilder("monitor.name.keyword", monitorName)); + } + if (monitorNamePattern != null) { + mustList.add(new WildcardQueryBuilder("monitor.name.keyword", monitorNamePattern)); + } + if (enabled != null) { + mustList.add(new TermQueryBuilder("monitor.enabled", enabled)); + } + if (hasTriggers != null) { + NestedQueryBuilder nestedTriggerQuery = new NestedQueryBuilder( + "monitor.triggers", + new ExistsQueryBuilder("monitor.triggers"), + ScoreMode.None + ); + + BoolQueryBuilder triggerQuery = new BoolQueryBuilder(); + if (hasTriggers) { + triggerQuery.must(nestedTriggerQuery); + } else { + triggerQuery.mustNot(nestedTriggerQuery); + } + mustList.add(triggerQuery); + } + if (indices != null) { + mustList + .add( + new NestedQueryBuilder( + "monitor.inputs", + new WildcardQueryBuilder("monitor.inputs.search.indices", indices), + ScoreMode.None + ) + ); + } + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must().addAll(mustList); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(boolQueryBuilder) + .size(size) + .from(startIndex) + .sort(sortString, sortOrder); + + SearchMonitorRequest searchMonitorRequest = new SearchMonitorRequest(new SearchRequest().source(searchSourceBuilder)); + + ActionListener searchMonitorListener = ActionListener.wrap(response -> { + StringBuilder sb = new StringBuilder(); + SearchHit[] hits = response.getHits().getHits(); + sb.append("Monitors=["); + for (SearchHit hit : hits) { + sb.append("{"); + sb.append("id=").append(hit.getId()).append(","); + sb.append("name=").append(hit.getSourceAsMap().get("name")); + sb.append("}"); + } + sb.append("]"); + sb.append("TotalMonitors=").append(response.getHits().getTotalHits().value); + listener.onResponse((T) sb.toString()); + }, e -> { + log.error("Failed to search monitors.", e); + listener.onFailure(e); + }); + AlertingPluginInterface.INSTANCE.searchMonitors((NodeClient) client, searchMonitorRequest, searchMonitorListener); + } + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Factory for the {@link SearchMonitorsTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchMonitorsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public SearchMonitorsTool create(Map map) { + return new SearchMonitorsTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } + +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchAlertsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAlertsToolTests.java new file mode 100644 index 00000000..ca1f8b99 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchAlertsToolTests.java @@ -0,0 +1,198 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionType; +import org.opensearch.client.AdminClient; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.action.GetAlertsResponse; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; + +public class SearchAlertsToolTests { + @Mock + private NodeClient nodeClient; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + + private Map nullParams; + private Map emptyParams; + private Map nonEmptyParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + SearchAlertsTool.Factory.getInstance().init(nodeClient); + + nullParams = null; + emptyParams = Collections.emptyMap(); + nonEmptyParams = Map.of("searchString", "foo"); + } + + @Test + public void testRunWithNoAlerts() throws Exception { + Tool tool = SearchAlertsTool.Factory.getInstance().create(Collections.emptyMap()); + GetAlertsResponse getAlertsResponse = new GetAlertsResponse(Collections.emptyList(), 0); + String expectedResponseStr = "Alerts=[]TotalAlerts=0"; + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getAlertsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(nonEmptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithAlerts() throws Exception { + Tool tool = SearchAlertsTool.Factory.getInstance().create(Collections.emptyMap()); + Alert alert1 = new Alert( + "alert-id-1", + 1234, + 1, + "monitor-id", + "workflow-id", + "workflow-name", + "monitor-name", + 1234, + null, + "trigger-id", + "trigger-name", + Collections.emptyList(), + Collections.emptyList(), + Alert.State.ACKNOWLEDGED, + Instant.now(), + null, + null, + null, + null, + Collections.emptyList(), + "test-severity", + Collections.emptyList(), + null, + null, + Collections.emptyList() + ); + Alert alert2 = new Alert( + "alert-id-2", + 1234, + 1, + "monitor-id", + "workflow-id", + "workflow-name", + "monitor-name", + 1234, + null, + "trigger-id", + "trigger-name", + Collections.emptyList(), + Collections.emptyList(), + Alert.State.ACKNOWLEDGED, + Instant.now(), + null, + null, + null, + null, + Collections.emptyList(), + "test-severity", + Collections.emptyList(), + null, + null, + Collections.emptyList() + ); + List mockAlerts = List.of(alert1, alert2); + + GetAlertsResponse getAlertsResponse = new GetAlertsResponse(mockAlerts, mockAlerts.size()); + String expectedResponseStr = new StringBuilder() + .append("Alerts=[") + .append(alert1.toString()) + .append(alert2.toString()) + .append("]TotalAlerts=2") + .toString(); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getAlertsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(nonEmptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testParseParams() throws Exception { + Tool tool = SearchAlertsTool.Factory.getInstance().create(Collections.emptyMap()); + Map validParams = new HashMap(); + validParams.put("sortOrder", "asc"); + validParams.put("sortString", "foo.bar"); + validParams.put("size", "10"); + validParams.put("startIndex", "0"); + validParams.put("searchString", "foo"); + validParams.put("severityLevel", "ALL"); + validParams.put("alertState", "ALL"); + validParams.put("monitorId", "foo"); + validParams.put("alertIndex", "foo"); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + assertDoesNotThrow(() -> tool.run(validParams, listener)); + assertDoesNotThrow(() -> tool.run(Map.of("monitorIds", "[]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("monitorIds", "[foo]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("workflowIds", "[]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("workflowIds", "[foo]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("alertIds", "[]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("alertIds", "[foo]"), listener)); + } + + @Test + public void testValidate() { + Tool tool = SearchAlertsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchAlertsTool.TYPE, tool.getType()); + assertTrue(tool.validate(emptyParams)); + assertTrue(tool.validate(nonEmptyParams)); + assertTrue(tool.validate(nullParams)); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchMonitorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchMonitorsToolTests.java new file mode 100644 index 00000000..37bc960f --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchMonitorsToolTests.java @@ -0,0 +1,256 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.time.Instant; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.client.AdminClient; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.alerting.action.GetMonitorResponse; +import org.opensearch.commons.alerting.model.CronSchedule; +import org.opensearch.commons.alerting.model.DataSources; +import org.opensearch.commons.alerting.model.Monitor; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregations; + +public class SearchMonitorsToolTests { + @Mock + private NodeClient nodeClient; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + + private Map nullParams; + private Map emptyParams; + private Map nonEmptyParams; + private Map monitorIdParams; + + private Monitor testMonitor; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + SearchMonitorsTool.Factory.getInstance().init(nodeClient); + + nullParams = null; + emptyParams = Collections.emptyMap(); + nonEmptyParams = Map.of("monitorName", "foo"); + monitorIdParams = Map.of("monitorId", "foo"); + testMonitor = new Monitor( + "monitor-1-id", + 0L, + "monitor-1", + true, + new CronSchedule("31 * * * *", ZoneId.of("Asia/Kolkata"), null), + Instant.now(), + Instant.now(), + Monitor.MonitorType.QUERY_LEVEL_MONITOR, + new User("test-user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), + 0, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyMap(), + new DataSources(), + "" + ); + } + + @Test + public void testRunWithNoMonitors() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + + SearchHit[] hits = new SearchHit[0]; + + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + + SearchResponse getMonitorsResponse = new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + String expectedResponseStr = String.format("Monitors=[]TotalMonitors=%d", hits.length); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getMonitorsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithMonitorId() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + + GetMonitorResponse getMonitorResponse = new GetMonitorResponse( + testMonitor.getId(), + 1L, + 2L, + 0L, + testMonitor, + Collections.emptyList() + ); + String expectedResponseStr = String + .format("Monitors=[{id=%s,name=%s}]TotalMonitors=%d", testMonitor.getId(), testMonitor.getName(), 1); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getMonitorResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(monitorIdParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithMonitorIdNotFound() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + + GetMonitorResponse responseWithNullMonitor = new GetMonitorResponse(testMonitor.getId(), 1L, 2L, 0L, null, Collections.emptyList()); + String expectedResponseStr = String.format("Monitors=[]TotalMonitors=0"); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(responseWithNullMonitor); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(monitorIdParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithSingleMonitor() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field("type", "monitor"); + content.field("name", testMonitor.getName()); + content.endObject(); + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, testMonitor.getId(), null, null).sourceRef(BytesReference.bytes(content)); + + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + + SearchResponse getMonitorsResponse = new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + String expectedResponseStr = String + .format("Monitors=[{id=%s,name=%s}]TotalMonitors=%d", testMonitor.getId(), testMonitor.getName(), hits.length); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getMonitorsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testParseParams() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + Map validParams = new HashMap(); + validParams.put("monitorName", "foo"); + validParams.put("enabled", "true"); + validParams.put("hasTriggers", "true"); + validParams.put("indices", "bar"); + validParams.put("sortOrder", "ASC"); + validParams.put("sortString", "baz"); + validParams.put("size", "10"); + validParams.put("startIndex", "0"); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + assertDoesNotThrow(() -> tool.run(validParams, listener)); + assertDoesNotThrow(() -> tool.run(Map.of("hasTriggers", "false"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("monitorNamePattern", "foo*"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("detectorId", "foo"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("sortOrder", "AsC"), listener)); + } + + @Test + public void testValidate() { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchMonitorsTool.TYPE, tool.getType()); + assertTrue(tool.validate(emptyParams)); + assertTrue(tool.validate(nonEmptyParams)); + assertTrue(tool.validate(monitorIdParams)); + assertTrue(tool.validate(nullParams)); + } +}