Skip to content

Commit

Permalink
Add LogPatternTool
Browse files Browse the repository at this point in the history
Signed-off-by: Heng Qian <qianheng@amazon.com>
  • Loading branch information
qianheng-aws committed Sep 26, 2024
1 parent 58553eb commit 2a5cbc2
Show file tree
Hide file tree
Showing 4 changed files with 316 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import org.opensearch.agent.tools.CreateAlertTool;
import org.opensearch.agent.tools.CreateAnomalyDetectorTool;
import org.opensearch.agent.tools.LogPatternTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.RAGTool;
Expand Down Expand Up @@ -71,6 +72,7 @@ public Collection<Object> createComponents(
SearchMonitorsTool.Factory.getInstance().init(client);
CreateAlertTool.Factory.getInstance().init(client);
CreateAnomalyDetectorTool.Factory.getInstance().init(client);
LogPatternTool.Factory.getInstance().init(client, xContentRegistry);
return Collections.emptyList();
}

Expand All @@ -87,7 +89,8 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchAnomalyResultsTool.Factory.getInstance(),
SearchMonitorsTool.Factory.getInstance(),
CreateAlertTool.Factory.getInstance(),
CreateAnomalyDetectorTool.Factory.getInstance()
CreateAnomalyDetectorTool.Factory.getInstance(),
LogPatternTool.Factory.getInstance()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ protected AbstractRetrieverTool(

protected abstract String getQueryBody(String queryText);

private static Map<String, Object> processResponse(SearchHit hit) {
protected static Map<String, Object> processResponse(SearchHit hit) {
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
Expand All @@ -77,7 +77,7 @@ private static Map<String, Object> processResponse(SearchHit hit) {
return docContent;
}

private <T> SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException {
protected <T> SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException {
String question = parameters.get(INPUT_FIELD);
if (StringUtils.isBlank(question)) {
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it.");
Expand Down
255 changes: 255 additions & 0 deletions src/main/java/org/opensearch/agent/tools/LogPatternTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.search.SearchHit;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* This tool supports generating log patterns on the input dsl and index. It's implemented by
* several steps:
* 1. Retrival [[${DOC_SIZE_FIELD}]] logs from index
* 2. Extract patterns for each retrieved log
* 2.1 Find Pattern Field: If users provide parameter [[${PATTERN_FIELD}]], use it as the pattern
* field; Otherwise, find the string field with the longest length on the first log.
* 2.2 Extract Pattern: If users provide parameter [[${PATTERN}]], compile it as a pattern;
* Otherwise, use [[${DEFAULT_IGNORED_CHARS}]]. It will remove all chars matching the pattern.
* 3. Group logs by their extracted patterns.
* 4. Find top N patterns with the largest sample log size.
* 5. For each found top N patterns, return [[${SAMPLE_LOG_SIZE}]] sample logs.
*/
@Log4j2
@Getter
@Setter
@ToolAnnotation(LogPatternTool.TYPE)
public class LogPatternTool extends AbstractRetrieverTool {
public static final String TYPE = "LogPatternTool";

public static String DEFAULT_DESCRIPTION = "Log Pattern Tool";
public static final String TOP_N_PATTERN = "top_n_pattern";
public static final String SAMPLE_LOG_SIZE = "sample_log_size";
public static final String DSL_FIELD = "dsl";
public static final String PATTERN_FIELD = "pattern_field";
public static final String PATTERN = "pattern";
public static final int LOG_PATTERN_DEFAULT_DOC_SIZE = 1000;
public static final int DEFAULT_TOP_N_PATTERN = 3;
public static final int DEFAULT_SAMPLE_LOG_SIZE = 20;
private static final ImmutableSet<Character> DEFAULT_IGNORED_CHARS = ImmutableSet
.copyOf("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".chars().mapToObj(c -> (char) c).toArray(Character[]::new));

private String name = TYPE;
private int topNPattern;
private int sampleLogSize;
@EqualsAndHashCode.Exclude
private Pattern pattern;

@Builder
public LogPatternTool(
Client client,
NamedXContentRegistry xContentRegistry,
int docSize,
int topNPattern,
int sampleLogSize,
String patternStr
) {
super(client, xContentRegistry, null, null, docSize);
this.topNPattern = topNPattern;
this.sampleLogSize = sampleLogSize;
if (pattern != null)
this.pattern = Pattern.compile(patternStr);
}

@Override
protected String getQueryBody(String queryText) {
return queryText;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
super.index = parameters.get(INDEX_FIELD);
if (parameters.containsKey(DOC_SIZE_FIELD))
super.docSize = Integer.parseInt(parameters.get(DOC_SIZE_FIELD));
if (parameters.containsKey(TOP_N_PATTERN))
topNPattern = Integer.parseInt(parameters.get(TOP_N_PATTERN));
if (parameters.containsKey(SAMPLE_LOG_SIZE))
sampleLogSize = Integer.parseInt(parameters.get(SAMPLE_LOG_SIZE));
if (parameters.containsKey(PATTERN))
this.pattern = Pattern.compile((parameters.get(PATTERN)));

SearchRequest searchRequest;
try {
searchRequest = buildSearchRequest(parameters);
} catch (Exception e) {
log.error("Failed to build search request.", e);
listener.onFailure(e);
return;
}

ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
String patternField = parameters.containsKey(PATTERN_FIELD)
? parameters.get(PATTERN_FIELD)
: findLongestField(hits[0].getSourceAsMap());
if (patternField == null) {
listener.onResponse((T) "Pattern field is not set and this index doesn't contain any string field");
}
Map<String, List<Map<String, Object>>> patternGroups = new HashMap<>();
for (SearchHit hit : hits) {
Map<String, Object> source = hit.getSourceAsMap();
String pattern = extractPattern((String) source.getOrDefault(patternField, ""), this.pattern);
List<Map<String, Object>> group = patternGroups.computeIfAbsent(pattern, k -> new ArrayList<>());
group.add(source);
}
List<Map<String, Object>> sortedEntries = patternGroups
.entrySet()
.stream()
.sorted(Comparator.comparingInt(entry -> -entry.getValue().size()))
.limit(topNPattern)
.map(
entry -> Map
.of(
"total count",
entry.getValue().size(),
"pattern",
entry.getKey(),
"sample logs",
entry.getValue().subList(0, Math.min(entry.getValue().size(), sampleLogSize))
)
)
.toList();

listener
.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(sortedEntries)));
} else {
listener.onResponse((T) "Can not get any match from search result.");
}
}, e -> {
log.error("Failed to search index.", e);
listener.onFailure(e);
});
client.search(searchRequest, actionListener);
}

@VisibleForTesting
public static String extractPattern(String rawString, Pattern pattern) {
if (pattern != null)
return pattern.matcher(rawString).replaceAll("");
char[] chars = rawString.toCharArray();
int pos = 0;
for (int i = 0; i < chars.length; i++) {
if (!DEFAULT_IGNORED_CHARS.contains(chars[i])) {
chars[pos++] = chars[i];
}
}
return new String(chars, 0, pos);
}

@VisibleForTesting
public static String findLongestField(Map<String, Object> sampleLogSource) {
String longestField = null;
int maxLength = 0;

for (Map.Entry<String, Object> entry : sampleLogSource.entrySet()) {
Object value = entry.getValue();
if (value instanceof String) { // 确保值是字符串类型
String stringValue = (String) value;
int length = stringValue.length();
if (length > maxLength) {
maxLength = length;
longestField = entry.getKey();
}
}
}
return longestField;
}

@Override
public String getType() {
return TYPE;
}

public static class Factory extends AbstractRetrieverTool.Factory<LogPatternTool> {
private static LogPatternTool.Factory INSTANCE;

public static LogPatternTool.Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (LogPatternTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new LogPatternTool.Factory();
return INSTANCE;
}
}

@Override
public LogPatternTool create(Map<String, Object> params) {
int docSize = params.containsKey(DOC_SIZE_FIELD)
? Integer.parseInt((String) params.get(DOC_SIZE_FIELD))
: LOG_PATTERN_DEFAULT_DOC_SIZE;
int topNPattern = params.containsKey(TOP_N_PATTERN)
? Integer.parseInt((String) params.get(TOP_N_PATTERN))
: DEFAULT_TOP_N_PATTERN;
int sampleLogSize = params.containsKey(SAMPLE_LOG_SIZE)
? Integer.parseInt((String) params.get(SAMPLE_LOG_SIZE))
: DEFAULT_SAMPLE_LOG_SIZE;
String patternStr = params.containsKey(PATTERN) ? (String) params.get(PATTERN) : null;
return LogPatternTool
.builder()
.client(client)
.xContentRegistry(xContentRegistry)
.docSize(docSize)
.topNPattern(topNPattern)
.sampleLogSize(sampleLogSize)
.patternStr(patternStr)
.build();
}

@Override
public String getDefaultType() {
return TYPE;
}

@Override
public String getDefaultVersion() {
return null;
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
55 changes: 55 additions & 0 deletions src/test/java/org/opensearch/agent/tools/LogPatternToolTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

import java.util.HashMap;
import java.util.Map;

import org.junit.Before;
import org.junit.Test;

import lombok.SneakyThrows;

/**
 * Unit tests for {@link LogPatternTool}: factory defaults, query-body pass-through and the
 * static pattern helpers.
 */
public class LogPatternToolTests {

    public static final String TEST_QUERY_TEXT = "123fsd23134sdfouh";
    private Map<String, Object> params = new HashMap<>();

    @Before
    public void setup() {}

    /** A tool created from empty parameters carries the class defaults and no pattern. */
    @Test
    @SneakyThrows
    public void testCreateTool() {
        LogPatternTool.Factory factory = LogPatternTool.Factory.getInstance();
        LogPatternTool created = factory.create(params);

        assertEquals(LogPatternTool.LOG_PATTERN_DEFAULT_DOC_SIZE, (int) created.docSize);
        assertEquals(LogPatternTool.DEFAULT_TOP_N_PATTERN, created.getTopNPattern());
        assertEquals(LogPatternTool.DEFAULT_SAMPLE_LOG_SIZE, created.getSampleLogSize());
        assertNull(created.getPattern());

        assertEquals("LogPatternTool", created.getType());
        assertEquals("LogPatternTool", created.getName());
        assertEquals(LogPatternTool.DEFAULT_DESCRIPTION, factory.getDefaultDescription());
    }

    /** The query body is the query text itself. */
    @Test
    public void testGetQueryBody() {
        LogPatternTool created = LogPatternTool.Factory.getInstance().create(params);
        String queryBody = created.getQueryBody(TEST_QUERY_TEXT);
        assertEquals(TEST_QUERY_TEXT, queryBody);
    }

    /** Only string values compete; the longest one wins. */
    @Test
    public void testFindLongestField() {
        Map<String, Object> sampleLog = Map.of("field1", "123", "field2", "1234", "filed3", 1234);
        String longest = LogPatternTool.findLongestField(sampleLog);
        assertEquals("field2", longest);
    }

    /** Without an explicit regex, letters and digits are stripped from the log line. */
    @Test
    public void testExtractPattern() {
        String extracted = LogPatternTool.extractPattern("123.abc/.AB/", null);
        assertEquals("././", extracted);
    }
}

0 comments on commit 2a5cbc2

Please sign in to comment.