diff --git a/build.gradle b/build.gradle index 4a7fc9ba..dd654eba 100644 --- a/build.gradle +++ b/build.gradle @@ -44,7 +44,7 @@ buildscript { plugins { id 'java-library' id 'com.diffplug.spotless' version '6.25.0' - id "io.freefair.lombok" version "8.6" + id "io.freefair.lombok" version "8.10" id "de.undercouch.download" version "5.6.0" } @@ -121,7 +121,7 @@ dependencies { compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.23.1" compileOnly group: 'org.json', name: 'json', version: '20240303' compileOnly("com.google.guava:guava:33.2.1-jre") - compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0' + compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.16.0' compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.12.0' // Plugin dependencies @@ -148,12 +148,12 @@ dependencies { testImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.json', name: 'json', version: '20240303' - testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.11.0' + testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.13.0' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0' testImplementation("net.bytebuddy:byte-buddy:1.14.9") testImplementation("net.bytebuddy:byte-buddy-agent:1.14.12") testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.2' - testImplementation 'org.mockito:mockito-junit-jupiter:5.11.0' + testImplementation 'org.mockito:mockito-junit-jupiter:5.13.0' testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0" testImplementation "com.cronutils:cron-utils:9.2.1" testImplementation "commons-validator:commons-validator:1.8.0" diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 2c352119..a4b76b95 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 68e8816d..2b189974 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=d725d707bfabd4dfdc958c624003b3c80accc03f7037b5122c4b1d0ef15cecab -distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip +distributionSha256Sum=5b9c5eb3f9fc2c94abaea57d90bd78747ca117ddbbf96c859d3741181a12bf2a +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 8fa1ac40..b964ef9f 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -14,6 +14,7 @@ import org.opensearch.agent.tools.CreateAnomalyDetectorTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; +import org.opensearch.agent.tools.PainlessScriptTool; import org.opensearch.agent.tools.RAGTool; import org.opensearch.agent.tools.SearchAlertsTool; import org.opensearch.agent.tools.SearchAnomalyDetectorsTool; @@ -71,6 +72,7 @@ public Collection createComponents( SearchMonitorsTool.Factory.getInstance().init(client); CreateAlertTool.Factory.getInstance().init(client); CreateAnomalyDetectorTool.Factory.getInstance().init(client); + PainlessScriptTool.Factory.getInstance().init(scriptService); return Collections.emptyList(); } @@ -87,7 +89,8 @@ public List> getToolFactories() { SearchAnomalyResultsTool.Factory.getInstance(), SearchMonitorsTool.Factory.getInstance(), CreateAlertTool.Factory.getInstance(), - CreateAnomalyDetectorTool.Factory.getInstance() + CreateAnomalyDetectorTool.Factory.getInstance(), + PainlessScriptTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java index 1d6b1c36..34d85700 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -18,7 +18,6 @@ import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; -import org.apache.logging.log4j.util.Strings; import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.indices.get.GetIndexRequest; import org.opensearch.action.support.IndicesOptions; @@ -27,6 +26,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.logging.LoggerMessageFormat; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -275,7 +275,7 @@ public void init(Client client) { @Override public CreateAlertTool create(Map params) { String modelId = (String) params.get(MODEL_ID); - if (Strings.isBlank(modelId)) { + if (Strings.isNullOrEmpty(modelId) || modelId.isBlank()) { throw new IllegalArgumentException("model_id cannot be null or blank."); } String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString()); diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 52811e61..bd018698 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -31,6 +31,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -44,7 +45,6 @@ import com.google.common.collect.ImmutableMap; -import joptsimple.internal.Strings; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; diff --git a/src/main/java/org/opensearch/agent/tools/PainlessScriptTool.java b/src/main/java/org/opensearch/agent/tools/PainlessScriptTool.java new file mode 100644 index 00000000..69b1f87a --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/PainlessScriptTool.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptService; +import org.opensearch.script.ScriptType; +import org.opensearch.script.TemplateScript; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * use case for this tool will only focus on flow agent + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(PainlessScriptTool.TYPE) +public class PainlessScriptTool implements Tool { + public static final String TYPE = "PainlessTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to execute painless script"; + + @Setter + @Getter + private String name = TYPE; + + @Getter + private String type = TYPE; + + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + @Getter + private String version; + + private ScriptService scriptService; + private String scriptCode; + + public PainlessScriptTool(ScriptService scriptEngine, String script) { + this.scriptService = scriptEngine; + this.scriptCode = script; + } + + @Override + public void run(Map parameters, ActionListener listener) { + Script script = new Script(ScriptType.INLINE, "painless", scriptCode, Collections.emptyMap()); + Map flattenedParameters = getFlattenedParameters(parameters); + TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(flattenedParameters); + try { + String result = templateScript.execute(); + listener.onResponse(result == null ? (T) "" : (T) result); + } catch (Exception e) { + listener.onFailure(e); + } + } + + @Override + public boolean validate(Map map) { + return true; + } + + Map getFlattenedParameters(Map parameters) { + Map flattenedParameters = new HashMap<>(); + for (Map.Entry entry : parameters.entrySet()) { + // keep both original values and flatten + flattenedParameters.put(entry.getKey(), entry.getValue()); + try { + // default is json parser, we may add more... + String value = org.apache.commons.text.StringEscapeUtils.unescapeJson(entry.getValue()); + Map map = StringUtils.fromJson(value, ""); + flattenMap(map, flattenedParameters, entry.getKey()); + } catch (Throwable ignored) {} + } + return flattenedParameters; + } + + void flattenMap(Map map, Map flatMap, String prefix) { + for (Map.Entry entry : map.entrySet()) { + String key = entry.getKey(); + if (prefix != null && !prefix.isEmpty()) { + key = prefix + "." + entry.getKey(); + } + Object value = entry.getValue(); + if (value instanceof Map) { + flattenMap((Map) value, flatMap, key); + } else { + flatMap.put(key, value); + } + } + } + + public static class Factory implements Tool.Factory { + private ScriptService scriptService; + + private static PainlessScriptTool.Factory INSTANCE; + + public static PainlessScriptTool.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (PainlessScriptTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new PainlessScriptTool.Factory(); + return INSTANCE; + } + } + + public void init(ScriptService scriptService) { + this.scriptService = scriptService; + } + + @Override + public PainlessScriptTool create(Map map) { + String script = (String) map.get("script"); + if (Strings.isNullOrEmpty(script)) { + throw new IllegalArgumentException("script is required"); + } + return new PainlessScriptTool(scriptService, script); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + + } +} diff --git a/src/test/java/org/opensearch/agent/tools/PainlessScriptToolTests.java b/src/test/java/org/opensearch/agent/tools/PainlessScriptToolTests.java new file mode 100644 index 00000000..71aa8633 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/PainlessScriptToolTests.java @@ -0,0 +1,147 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.text.StringEscapeUtils; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.script.ScriptService; +import org.opensearch.script.TemplateScript; + +import com.google.gson.Gson; + +/** + * this is a test file to test PainlessTool with junit + */ +public class PainlessScriptToolTests { + @Mock + private ScriptService scriptService; + @Mock + private TemplateScript templateScript; + @Mock + private ActionListener actionListener; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + TemplateScript.Factory factory = new TemplateScript.Factory() { + @Override + public TemplateScript newInstance(Map params) { + return templateScript; + } + }; + + when(scriptService.compile(any(), any())).thenReturn(factory); + + PainlessScriptTool.Factory.getInstance().init(scriptService); + } + + @Test + public void testRun() { + String script = "return 'Hello World';"; + PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script)); + when(templateScript.execute()).thenReturn("hello"); + tool.run(Map.of(), actionListener); + + verify(templateScript).execute(); + verify(scriptService).compile(any(), any()); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + assertEquals("hello", responseCaptor.getValue()); + } + + // test run wit exception + @Test + public void testRun_with_exception() { + String script = "return 'Hello World';"; + PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script)); + when(templateScript.execute()).thenThrow(new RuntimeException("error")); + tool.run(Map.of(), actionListener); + + verify(templateScript).execute(); + verify(scriptService).compile(any(), any()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("error", exceptionCaptor.getValue().getMessage()); + } + + // test factory create + @Test + public void testFactory_create() { + String script = "return 'Hello World';"; + PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script)); + assertEquals(PainlessScriptTool.TYPE, tool.getType()); + assertEquals("PainlessTool", tool.getName()); + assertEquals("Use this tool to execute painless script", tool.getDescription()); + } + + // test factory create with exception + @Test(expected = IllegalArgumentException.class) + public void testFactory_create_with_exception() { + PainlessScriptTool.Factory.getInstance().create(Map.of()); + } + + // test flattenMap + @Test + public void testFlattenMap_without_prefix() { + String script = "return 'Hello World';"; + PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script)); + Map map = Map.of("a", Map.of("b", "c"), "k", "v"); + Map resultMap = new HashMap<>(); + tool.flattenMap(map, resultMap, ""); + assertEquals(Map.of("a.b", "c", "k", "v"), resultMap); + } + + // with prefix + @Test + public void testFlattenMap_with_prefix() { + String script = "return 'Hello World';"; + PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script)); + Map map = Map.of("a", Map.of("b", "c"), "k", "v"); + Map resultMap = new HashMap<>(); + tool.flattenMap(map, resultMap, "prefix"); + assertEquals(Map.of("prefix.a.b", "c", "prefix.k", "v"), resultMap); + } + + // nest map with depth 3 + @Test + public void testFlattenMap_with_depth_3() { + String script = "return 'Hello World';"; + PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script)); + Map map = Map.of("a", Map.of("b", Map.of("c", "d"), "k", "v")); + Gson gson = new Gson(); + System.out.println(StringEscapeUtils.escapeJson(gson.toJson(map))); + Map resultMap = new HashMap<>(); + tool.flattenMap(map, resultMap, ""); + assertEquals(Map.of("a.b.c", "d", "a.k", "v"), resultMap); + } + + // test getFlattenedParameters + @Test + public void testGetFlattenedParameters() { + String script = "return 'Hello World';"; + PainlessScriptTool tool = PainlessScriptTool.Factory.getInstance().create(Map.of("script", script)); + Map map = Map.of("k", "{\\\"a\\\":{\\\"k\\\":\\\"v\\\",\\\"b\\\":{\\\"c\\\":\\\"d\\\"}}}"); + Map resultMap = tool.getFlattenedParameters(map); + assertEquals( + Map.of("k.a.b.c", "d", "k.a.k", "v", "k", "{\\\"a\\\":{\\\"k\\\":\\\"v\\\",\\\"b\\\":{\\\"c\\\":\\\"d\\\"}}}"), + resultMap + ); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java index e0b04336..d5ab8530 100644 --- a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java @@ -93,7 +93,8 @@ public void setup() { null, null, null, - null + null, + true ); } diff --git a/src/test/java/org/opensearch/integTest/PainlessToolIT.java b/src/test/java/org/opensearch/integTest/PainlessToolIT.java new file mode 100644 index 00000000..70d2c11d --- /dev/null +++ b/src/test/java/org/opensearch/integTest/PainlessToolIT.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.junit.Assert; +import org.junit.Before; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class PainlessToolIT extends BaseAgentToolsIT { + + private String registerAgentRequestBody; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + registerAgentRequestBody = Files + .readString( + Path.of(this.getClass().getClassLoader().getResource("org/opensearch/agent/tools/register_painless_agent.json").toURI()) + ); + } + + public void test_execute() { + String script = "def x = new HashMap(); x.abc = '5'; return x.abc;"; + String agentRequestBody = registerAgentRequestBody.replaceAll("