Skip to content

Commit

Permalink
Migrate alerting tools (#66) (#76)
Browse files Browse the repository at this point in the history
(cherry picked from commit 2f06f6f)

Signed-off-by: Tyler Ohlsen <ohltyler@amazon.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent b0db422 commit 0d0c1c1
Show file tree
Hide file tree
Showing 6 changed files with 889 additions and 1 deletion.
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ buildscript {
opensearch_version = System.getProperty("opensearch.version", "2.12.0-SNAPSHOT")
isSnapshot = "true" == System.getProperty("build.snapshot", "true")
buildVersionQualifier = System.getProperty("build.version_qualifier", "")
kotlin_version = System.getProperty("kotlin.version", "1.8.21")
}

repositories {
Expand Down Expand Up @@ -117,6 +118,7 @@ dependencies {
implementation fileTree(dir: adJarDirectory, include: ["opensearch-anomaly-detection-${version}.jar"])
implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"])
compileOnly "org.opensearch:common-utils:${version}"
compileOnly "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}"

// ZipArchive dependencies used for integration tests
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${version}"
Expand Down
182 changes: 182 additions & 0 deletions src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.commons.alerting.AlertingPluginInterface;
import org.opensearch.commons.alerting.action.GetAlertsRequest;
import org.opensearch.commons.alerting.action.GetAlertsResponse;
import org.opensearch.commons.alerting.model.Alert;
import org.opensearch.commons.alerting.model.Table;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Log4j2
@ToolAnnotation(SearchAlertsTool.TYPE)
public class SearchAlertsTool implements Tool {
public static final String TYPE = "SearchAlertsTool";
private static final String DEFAULT_DESCRIPTION = "Use this tool to search alerts.";

@Setter
@Getter
private String name = TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
@Getter
private String type;
@Getter
private String version;

private Client client;
@Setter
private Parser<?, ?> inputParser;
@Setter
private Parser<?, ?> outputParser;

public SearchAlertsTool(Client client) {
this.client = client;

// probably keep this overridden output parser. need to ensure the output matches what's expected
outputParser = new Parser<>() {
@Override
public Object parse(Object o) {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String tableSortOrder = parameters.getOrDefault("sortOrder", "asc");
final String tableSortString = parameters.getOrDefault("sortString", "monitor_name.keyword");
final int tableSize = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size"))
? Integer.parseInt(parameters.get("size"))
: 20;
final int startIndex = parameters.containsKey("startIndex") && StringUtils.isNumeric(parameters.get("startIndex"))
? Integer.parseInt(parameters.get("startIndex"))
: 0;
final String searchString = parameters.getOrDefault("searchString", null);

// not exposing "missing" from the table, using default of null
final Table table = new Table(tableSortOrder, tableSortString, null, tableSize, startIndex, searchString);

final String severityLevel = parameters.getOrDefault("severityLevel", "ALL");
final String alertState = parameters.getOrDefault("alertState", "ALL");
final String monitorId = parameters.getOrDefault("monitorId", null);
final String alertIndex = parameters.getOrDefault("alertIndex", null);
@SuppressWarnings("unchecked")
final List<String> monitorIds = parameters.containsKey("monitorIds")
? gson.fromJson(parameters.get("monitorIds"), List.class)
: null;
@SuppressWarnings("unchecked")
final List<String> workflowIds = parameters.containsKey("workflowIds")
? gson.fromJson(parameters.get("workflowIds"), List.class)
: null;
@SuppressWarnings("unchecked")
final List<String> alertIds = parameters.containsKey("alertIds") ? gson.fromJson(parameters.get("alertIds"), List.class) : null;

GetAlertsRequest getAlertsRequest = new GetAlertsRequest(
table,
severityLevel,
alertState,
monitorId,
alertIndex,
monitorIds,
workflowIds,
alertIds
);

// create response listener
// stringify the response, may change to a standard format in the future
ActionListener<GetAlertsResponse> getAlertsListener = ActionListener.<GetAlertsResponse>wrap(response -> {
StringBuilder sb = new StringBuilder();
sb.append("Alerts=[");
for (Alert alert : response.getAlerts()) {
sb.append(alert.toString());
}
sb.append("]");
sb.append("TotalAlerts=").append(response.getTotalAlerts());
listener.onResponse((T) sb.toString());
}, e -> {
log.error("Failed to search alerts.", e);
listener.onFailure(e);
});

// execute the search
AlertingPluginInterface.INSTANCE.getAlerts((NodeClient) client, getAlertsRequest, getAlertsListener);
}

@Override
public boolean validate(Map<String, String> parameters) {
return true;
}

@Override
public String getType() {
return TYPE;
}

/**
* Factory for the {@link SearchAlertsTool}
*/
public static class Factory implements Tool.Factory<SearchAlertsTool> {
private Client client;

private static Factory INSTANCE;

/**
* Create or return the singleton factory instance
*/
public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (SearchAlertsTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

/**
* Initialize this factory
* @param client The OpenSearch client
*/
public void init(Client client) {
this.client = client;
}

@Override
public SearchAlertsTool create(Map<String, Object> map) {
return new SearchAlertsTool(client);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Log4j2
@ToolAnnotation(SearchAnomalyDetectorsTool.TYPE)
public class SearchAnomalyDetectorsTool implements Tool {
public static final String TYPE = "SearchAnomalyDetectorsTool";
Expand Down Expand Up @@ -140,7 +142,10 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
sb.append("]");
sb.append("TotalAnomalyDetectors=").append(response.getHits().getTotalHits().value);
listener.onResponse((T) sb.toString());
}, e -> { listener.onFailure(e); });
}, e -> {
log.error("Failed to search anomaly detectors.", e);
listener.onFailure(e);
});

adClient.searchAnomalyDetectors(searchDetectorRequest, searchDetectorListener);
}
Expand Down
Loading

0 comments on commit 0d0c1c1

Please sign in to comment.