Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PainlessScript tool #380

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ buildscript {
plugins {
id 'java-library'
id 'com.diffplug.spotless' version '6.25.0'
id "io.freefair.lombok" version "8.6"
id "io.freefair.lombok" version "8.10"
id "de.undercouch.download" version "5.6.0"
}

Expand Down Expand Up @@ -121,7 +121,7 @@ dependencies {
compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.23.1"
compileOnly group: 'org.json', name: 'json', version: '20240303'
compileOnly("com.google.guava:guava:33.2.1-jre")
compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.16.0'
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.12.0'

// Plugin dependencies
Expand All @@ -148,12 +148,12 @@ dependencies {
testImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.json', name: 'json', version: '20240303'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.11.0'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.13.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0'
testImplementation("net.bytebuddy:byte-buddy:1.14.9")
testImplementation("net.bytebuddy:byte-buddy-agent:1.14.12")
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.2'
testImplementation 'org.mockito:mockito-junit-jupiter:5.11.0'
testImplementation 'org.mockito:mockito-junit-jupiter:5.13.0'
testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0"
testImplementation "com.cronutils:cron-utils:9.2.1"
testImplementation "commons-validator:commons-validator:1.8.0"
Expand Down
Binary file modified gradle/wrapper/gradle-wrapper.jar
Binary file not shown.
4 changes: 2 additions & 2 deletions gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionSha256Sum=d725d707bfabd4dfdc958c624003b3c80accc03f7037b5122c4b1d0ef15cecab
distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip
distributionSha256Sum=5b9c5eb3f9fc2c94abaea57d90bd78747ca117ddbbf96c859d3741181a12bf2a
distributionUrl=https\://services.gradle.org/distributions/gradle-8.10-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.agent.tools.CreateAnomalyDetectorTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.PainlessScriptTool;
import org.opensearch.agent.tools.RAGTool;
import org.opensearch.agent.tools.SearchAlertsTool;
import org.opensearch.agent.tools.SearchAnomalyDetectorsTool;
Expand Down Expand Up @@ -71,6 +72,7 @@ public Collection<Object> createComponents(
SearchMonitorsTool.Factory.getInstance().init(client);
CreateAlertTool.Factory.getInstance().init(client);
CreateAnomalyDetectorTool.Factory.getInstance().init(client);
PainlessScriptTool.Factory.getInstance().init(scriptService);
return Collections.emptyList();
}

Expand All @@ -87,7 +89,8 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchAnomalyResultsTool.Factory.getInstance(),
SearchMonitorsTool.Factory.getInstance(),
CreateAlertTool.Factory.getInstance(),
CreateAnomalyDetectorTool.Factory.getInstance()
CreateAnomalyDetectorTool.Factory.getInstance(),
PainlessScriptTool.Factory.getInstance()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.indices.get.GetIndexRequest;
import org.opensearch.action.support.IndicesOptions;
Expand All @@ -27,6 +26,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.logging.LoggerMessageFormat;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand Down Expand Up @@ -275,7 +275,7 @@ public void init(Client client) {
@Override
public CreateAlertTool create(Map<String, Object> params) {
String modelId = (String) params.get(MODEL_ID);
if (Strings.isBlank(modelId)) {
if (Strings.isNullOrEmpty(modelId) || modelId.isBlank()) {
throw new IllegalArgumentException("model_id cannot be null or blank.");
}
String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -44,7 +45,6 @@

import com.google.common.collect.ImmutableMap;

import joptsimple.internal.Strings;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
Expand Down
154 changes: 154 additions & 0 deletions src/main/java/org/opensearch/agent/tools/PainlessScriptTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptService;
import org.opensearch.script.ScriptType;
import org.opensearch.script.TemplateScript;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* use case for this tool will only focus on flow agent
*/
@Log4j2
@Setter
@Getter
@ToolAnnotation(PainlessScriptTool.TYPE)
public class PainlessScriptTool implements Tool {
public static final String TYPE = "PainlessTool";
private static final String DEFAULT_DESCRIPTION = "Use this tool to execute painless script";

@Setter
@Getter
private String name = TYPE;

@Getter
private String type = TYPE;

@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

@Getter
private String version;

private ScriptService scriptService;
private String scriptCode;

public PainlessScriptTool(ScriptService scriptEngine, String script) {
this.scriptService = scriptEngine;
this.scriptCode = script;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
Script script = new Script(ScriptType.INLINE, "painless", scriptCode, Collections.emptyMap());
Map<String, Object> flattenedParameters = getFlattenedParameters(parameters);
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(flattenedParameters);
try {
String result = templateScript.execute();
listener.onResponse(result == null ? (T) "" : (T) result);
} catch (Exception e) {
listener.onFailure(e);
}
}

@Override
public boolean validate(Map<String, String> map) {
return true;
}

Map<String, Object> getFlattenedParameters(Map<String, String> parameters) {
Map<String, Object> flattenedParameters = new HashMap<>();
for (Map.Entry<String, String> entry : parameters.entrySet()) {
// keep both original values and flatten
flattenedParameters.put(entry.getKey(), entry.getValue());
try {
// default is json parser, we may add more...
String value = org.apache.commons.text.StringEscapeUtils.unescapeJson(entry.getValue());
Map<String, ?> map = StringUtils.fromJson(value, "");
flattenMap(map, flattenedParameters, entry.getKey());
} catch (Throwable ignored) {}
}
return flattenedParameters;
}

void flattenMap(Map<String, ?> map, Map<String, Object> flatMap, String prefix) {
for (Map.Entry<String, ?> entry : map.entrySet()) {
String key = entry.getKey();
if (prefix != null && !prefix.isEmpty()) {
key = prefix + "." + entry.getKey();
}
Object value = entry.getValue();
if (value instanceof Map) {
flattenMap((Map<String, ?>) value, flatMap, key);
} else {
flatMap.put(key, value);
}
}
}

public static class Factory implements Tool.Factory<PainlessScriptTool> {
private ScriptService scriptService;

private static PainlessScriptTool.Factory INSTANCE;

public static PainlessScriptTool.Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (PainlessScriptTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new PainlessScriptTool.Factory();
return INSTANCE;
}
}

public void init(ScriptService scriptService) {
this.scriptService = scriptService;
}

@Override
public PainlessScriptTool create(Map<String, Object> map) {
String script = (String) map.get("script");
if (Strings.isNullOrEmpty(script)) {
throw new IllegalArgumentException("script is required");
}
return new PainlessScriptTool(scriptService, script);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public String getDefaultType() {
return TYPE;
}

@Override
public String getDefaultVersion() {
return null;
}

}
}
Loading