Skip to content

Commit

Permalink
Fix tools ordering class casting bug (opensearch-project#289)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <widdis@gmail.com>
  • Loading branch information
dbwiddis committed Dec 18, 2023
1 parent 597640f commit 466b998
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,10 @@ private CommonValue() {}
public static final String RESOURCE_TYPE = "resource_type";
/** The field name for the resource id */
public static final String RESOURCE_ID = "resource_id";
/** The tools' field for an agent */
/** The tools field for an agent */
public static final String TOOLS_FIELD = "tools";
/** The tools order field for an agent */
public static final String TOOLS_ORDER_FIELD = "tools_order";
/** The memory field for an agent */
public static final String MEMORY_FIELD = "memory";
/** The app type field for an agent */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToObjectMap;
Expand Down Expand Up @@ -94,7 +94,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
for (PipelineProcessor p : (PipelineProcessor[]) e.getValue()) {
xContentBuilder.value(p);
}
} else if (TOOLS_FIELD.equals(e.getKey())) {
} else if (TOOLS_ORDER_FIELD.equals(e.getKey())) {
for (String t : (String[]) e.getValue()) {
xContentBuilder.value(t);
}
Expand Down Expand Up @@ -156,7 +156,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
processorList.add(PipelineProcessor.parse(parser));
}
userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0]));
} else if (TOOLS_FIELD.equals(inputFieldName)) {
} else if (TOOLS_ORDER_FIELD.equals(inputFieldName)) {
List<String> toolsList = new ArrayList<>();
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
toolsList.add(parser.text());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap;

Expand All @@ -64,16 +65,13 @@ public class RegisterAgentStep implements WorkflowStep {
private static final String LLM_MODEL_ID = "llm.model_id";
private static final String LLM_PARAMETERS = "llm.parameters";

private List<MLToolSpec> mlToolSpecList;

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public RegisterAgentStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) {
this.mlClient = mlClient;
this.mlToolSpecList = new ArrayList<>();
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
}

Expand Down Expand Up @@ -136,6 +134,7 @@ public void onFailure(Exception e) {
LLM_MODEL_ID,
LLM_PARAMETERS,
TOOLS_FIELD,
TOOLS_ORDER_FIELD,
PARAMETERS_FIELD,
MEMORY_FIELD,
CREATED_TIME,
Expand All @@ -157,8 +156,8 @@ public void onFailure(Exception e) {
String description = (String) inputs.get(DESCRIPTION_FIELD);
String llmModelId = (String) inputs.get(LLM_MODEL_ID);
Map<String, String> llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS);
String[] tools = (String[]) inputs.get(TOOLS_FIELD);
List<MLToolSpec> toolsList = getTools(tools, previousNodeInputs, outputs);
String[] toolsOrder = (String[]) inputs.get(TOOLS_ORDER_FIELD);
List<MLToolSpec> toolsList = getTools(toolsOrder, previousNodeInputs, outputs);
Map<String, String> parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD);
MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD));
Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME));
Expand Down Expand Up @@ -285,7 +284,6 @@ private MLMemorySpec getMLMemorySpec(Object mlMemory) {
sessionId = (String) map.get(MLMemorySpec.SESSION_ID_FIELD);
windowSize = (Integer) map.get(MLMemorySpec.WINDOW_SIZE_FIELD);

@SuppressWarnings("unchecked")
MLMemorySpec.MLMemorySpecBuilder builder = MLMemorySpec.builder();

builder.type(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public void testNode() throws IOException {
Map.entry("baz", new Map<?, ?>[] { Map.of("A", "a"), Map.of("B", "b") }),
Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }),
Map.entry("created_time", 1689793598499L),
Map.entry("tools", new String[] { "foo", "bar" })
Map.entry("tools_order", new String[] { "foo", "bar" })
)
);
assertEquals("A", nodeA.id());
Expand All @@ -47,7 +47,7 @@ public void testNode() throws IOException {
assertEquals("test-type", pp[0].type());
assertEquals(Map.of("key2", "value2"), pp[0].params());
assertEquals(1689793598499L, map.get("created_time"));
assertArrayEquals(new String[] { "foo", "bar" }, (String[]) map.get("tools"));
assertArrayEquals(new String[] { "foo", "bar" }, (String[]) map.get("tools_order"));

// node equality is based only on ID
WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of(), Map.of("bar", "baz"));
Expand All @@ -65,7 +65,7 @@ public void testNode() throws IOException {
assertTrue(json.contains("\"bar\":{\"key\":\"value\"}"));
assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]"));
assertTrue(json.contains("\"created_time\":1689793598499"));
assertTrue(json.contains("\"tools\":[\"foo\",\"bar\"]"));
assertTrue(json.contains("\"tools_order\":[\"foo\",\"bar\"]"));

WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json));
assertEquals("A", nodeX.id());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -52,6 +54,10 @@ public void setUp() throws Exception {
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
MockitoAnnotations.openMocks(this);

MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false);

LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap());

Map<?, ?> mlMemorySpec = Map.ofEntries(
Map.entry(MLMemorySpec.MEMORY_TYPE_FIELD, "type"),
Map.entry(MLMemorySpec.SESSION_ID_FIELD, "abc"),
Expand All @@ -65,7 +71,8 @@ public void setUp() throws Exception {
Map.entry("type", "type"),
Map.entry("llm.model_id", "xyz"),
Map.entry("llm.parameters", Collections.emptyMap()),
Map.entry("tools", new String[] { "abc", "xyz" }),
Map.entry("tools", tools),
Map.entry("tools_order", new String[] { "abc", "xyz" }),
Map.entry("parameters", Collections.emptyMap()),
Map.entry("memory", mlMemorySpec),
Map.entry("created_time", 1689793598499L),
Expand Down

0 comments on commit 466b998

Please sign in to comment.