Skip to content

Commit

Permalink
feat: parse connector id from tool parameters map (#846)
Browse files Browse the repository at this point in the history
* feat: parse connector id from tool parameters map

Signed-off-by: yuye-aws <yuyezhu@amazon.com>

* update changelog

Signed-off-by: yuye-aws <yuyezhu@amazon.com>

* implement unit test for connector, model and agent id

Signed-off-by: yuye-aws <yuyezhu@amazon.com>

* tool step id: make node id unique

Signed-off-by: yuye-aws <yuyezhu@amazon.com>

* integration test: create agent with connector tool

Signed-off-by: yuye-aws <yuyezhu@amazon.com>

* integration test: update with get agent and get workflow

Signed-off-by: yuye-aws <yuyezhu@amazon.com>

* optimize: iterate through connector_id model_id and agent_id

Signed-off-by: yuye-aws <yuyezhu@amazon.com>

* update changelog

Signed-off-by: yuye-aws <yuyezhu@amazon.com>

---------

Signed-off-by: yuye-aws <yuyezhu@amazon.com>
  • Loading branch information
yuye-aws authored Aug 28, 2024
1 parent 60458a6 commit b3f9d65
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 99 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
### Features
- Adds reprovision API to support updating search pipelines, ingest pipelines index settings ([#804](https://github.com/opensearch-project/flow-framework/pull/804))
- Adds user level access control based on backend roles ([#838](https://github.com/opensearch-project/flow-framework/pull/838))
- Support parsing connector_id when creating tools ([#846](https://github.com/opensearch-project/flow-framework/pull/846))

### Enhancements
### Bug Fixes
Expand Down
53 changes: 26 additions & 27 deletions src/main/java/org/opensearch/flowframework/workflow/ToolStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;

/**
Expand Down Expand Up @@ -64,7 +65,15 @@ public PlainActionFuture<WorkflowData> execute(
String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
Boolean includeOutputInAgentResponse = ParseUtils.parseIfExists(inputs, INCLUDE_OUTPUT_IN_AGENT_RESPONSE, Boolean.class);
Map<String, String> parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs);

// parse connector_id, model_id and agent_id from previous node inputs
Set<String> toolParameterKeys = Set.of(CONNECTOR_ID, MODEL_ID, AGENT_ID);
Map<String, String> parameters = getToolsParametersMap(
inputs.get(PARAMETERS_FIELD),
previousNodeInputs,
outputs,
toolParameterKeys
);

MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();

Expand Down Expand Up @@ -110,39 +119,29 @@ public String getName() {
private Map<String, String> getToolsParametersMap(
Object parameters,
Map<String, String> previousNodeInputs,
Map<String, WorkflowData> outputs
Map<String, WorkflowData> outputs,
Set<String> toolParameterKeys
) {
@SuppressWarnings("unchecked")
Map<String, String> parametersMap = (Map<String, String>) parameters;
Optional<String> previousNodeModel = previousNodeInputs.entrySet()
.stream()
.filter(e -> MODEL_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

Optional<String> previousNodeAgent = previousNodeInputs.entrySet()
.stream()
.filter(e -> AGENT_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

// Case when modelId is passed through previousSteps and not present already in parameters
if (previousNodeModel.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeModel.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString());
}
}

// Case when agentId is passed through previousSteps and not present already in parameters
if (previousNodeAgent.isPresent() && !parametersMap.containsKey(AGENT_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeAgent.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(AGENT_ID)) {
parametersMap.put(AGENT_ID, previousNodeOutput.getContent().get(AGENT_ID).toString());
for (String toolParameterKey : toolParameterKeys) {
Optional<String> previousNodeParameter = previousNodeInputs.entrySet()
.stream()
.filter(e -> toolParameterKey.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

// Case when toolParameterKey is passed through previousSteps and not present already in parameters
if (previousNodeParameter.isPresent() && !parametersMap.containsKey(toolParameterKey)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeParameter.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(toolParameterKey)) {
parametersMap.put(toolParameterKey, previousNodeOutput.getContent().get(toolParameterKey).toString());
}
}
}

// For other cases where modelId is already present in the parameters or not return the parametersMap
// For other cases where toolParameterKey is already present in the parameters or not return the parametersMap
return parametersMap;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,23 @@ protected Response getWorkflowStep(RestClient client) throws Exception {
);
}

/**
* Helper method to invoke the Get Agent Rest Action
* @param client the rest client
* @return rest response
* @throws Exception
*/
protected Response getAgent(RestClient client, String agentId) throws Exception {
return TestHelpers.makeRequest(
client,
"GET",
String.format(Locale.ROOT, "/_plugins/_ml/agents/%s", agentId),
Collections.emptyMap(),
"",
null
);
}

/**
* Helper method to invoke the Search Workflow Rest Action with the given query
* @param client the rest client
Expand All @@ -668,7 +685,6 @@ protected Response getWorkflowStep(RestClient client) throws Exception {
* @throws Exception if the request fails
*/
protected SearchResponse searchWorkflows(RestClient client, String query) throws Exception {

// Execute search
Response restSearchResponse = TestHelpers.makeRequest(
client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
Expand All @@ -56,7 +57,6 @@ public void waitToStart() throws Exception {
}

public void testSearchWorkflows() throws Exception {

// Create a Workflow that has a credential 12345
Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json");
Response response = createWorkflow(client(), template);
Expand Down Expand Up @@ -228,7 +228,6 @@ public void testCreateAndProvisionCyclicalTemplate() throws Exception {
}

public void testCreateAndProvisionRemoteModelWorkflow() throws Exception {

// Using a 3 step template to create a connector, register remote model and deploy model
Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json");

Expand Down Expand Up @@ -331,6 +330,79 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception {
assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS);
}

public void testCreateAndProvisionConnectorToolAgentFrameworkWorkflow() throws Exception {
// Create a Workflow that has a credential 12345
Template template = TestHelpers.createTemplateFromFile("createconnector-createconnectortool-createflowagent.json");

// Hit Create Workflow API to create agent-framework template, with template validation check and provision parameter
Response response = createWorkflowWithProvision(client(), template);
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
Map<String, Object> responseMap = entityAsMap(response);
String workflowId = (String) responseMap.get(WORKFLOW_ID);
// wait and ensure state is completed/done
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); },
120,
TimeUnit.SECONDS
);

// Assert based on the agent-framework template
List<ResourceCreated> resourcesCreated = getResourcesCreated(client(), workflowId, 120);
Map<String, ResourceCreated> resourceMap = resourcesCreated.stream()
.collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r));
assertEquals(2, resourceMap.size());
assertTrue(resourceMap.containsKey("create_connector"));
assertTrue(resourceMap.containsKey("register_agent"));
String connectorId = resourceMap.get("create_connector").resourceId();
String agentId = resourceMap.get("register_agent").resourceId();
assertNotNull(connectorId);
assertNotNull(agentId);

// Assert that the agent contains the correct connector_id
response = getAgent(client(), agentId);
Map<String, Object> agentResponse = entityAsMap(response);
assertTrue(agentResponse.containsKey("tools"));
@SuppressWarnings("unchecked")
ArrayList<Map<String, Object>> tools = (ArrayList<Map<String, Object>>) agentResponse.get("tools");
assertEquals(1, tools.size());
Map<String, Object> tool = tools.getFirst();
assertTrue(tool.containsKey("parameters"));
@SuppressWarnings("unchecked")
Map<String, String> toolParameters = (Map<String, String>) tool.get("parameters");
assertEquals(toolParameters, Map.of("connector_id", connectorId));

// Hit Deprovision API
// By design, this may not completely deprovision the first time if it takes >2s to process removals
Response deprovisionResponse = deprovisionWorkflow(client(), workflowId);
try {
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
30,
TimeUnit.SECONDS
);
} catch (ComparisonFailure e) {
// 202 return if still processing
assertEquals(RestStatus.ACCEPTED, TestHelpers.restStatus(deprovisionResponse));
}
if (TestHelpers.restStatus(deprovisionResponse) == RestStatus.ACCEPTED) {
// Short wait before we try again
Thread.sleep(10000);
deprovisionResponse = deprovisionWorkflow(client(), workflowId);
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
30,
TimeUnit.SECONDS
);
}
assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse));
// Hit Delete API
Response deleteResponse = deleteWorkflow(client(), workflowId);
assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse));

// Verify state doc is deleted
assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS);
}

public void testReprovisionWorkflow() throws Exception {
// Begin with a template to register a local pretrained model
Template template = TestHelpers.createTemplateFromFile("registerremotemodel.json");
Expand Down Expand Up @@ -650,7 +722,6 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception {
}

public void testDefaultCohereUseCase() throws Exception {

// Hit Create Workflow API with original template
Response response = createWorkflowWithUseCaseWithNoValidation(
client(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,26 @@
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ExecutionException;

import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;

public class ToolStepTests extends OpenSearchTestCase {
private WorkflowData inputData;
private WorkflowData inputDataWithConnectorId;
private WorkflowData inputDataWithModelId;
private WorkflowData inputDataWithAgentId;
private static final String mockedConnectorId = "mocked-connector-id";
private static final String mockedModelId = "mocked-model-id";
private static final String mockedAgentId = "mocked-agent-id";
private static final String createConnectorNodeId = "create_connector_node_id";
private static final String createModelNodeId = "create_model_node_id";
private static final String createAgentNodeId = "create_agent_node_id";

private WorkflowData boolStringInputData;
private WorkflowData badBoolInputData;

Expand All @@ -39,6 +52,9 @@ public void setUp() throws Exception {
"test-id",
"test-node-id"
);
inputDataWithConnectorId = new WorkflowData(Map.of(CONNECTOR_ID, mockedConnectorId), "test-id", createConnectorNodeId);
inputDataWithModelId = new WorkflowData(Map.of(MODEL_ID, mockedModelId), "test-id", createModelNodeId);
inputDataWithAgentId = new WorkflowData(Map.of(AGENT_ID, mockedAgentId), "test-id", createAgentNodeId);
boolStringInputData = new WorkflowData(
Map.ofEntries(
Map.entry("type", "type"),
Expand All @@ -63,7 +79,7 @@ public void setUp() throws Exception {
);
}

public void testTool() throws IOException, ExecutionException, InterruptedException {
public void testTool() throws ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
Expand All @@ -88,7 +104,7 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept
assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass());
}

public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException {
public void testBoolParseFail() {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
Expand All @@ -100,10 +116,61 @@ public void testBoolParseFail() throws IOException, ExecutionException, Interrup
);

assertTrue(future.isDone());
ExecutionException e = assertThrows(ExecutionException.class, () -> future.get());
ExecutionException e = assertThrows(ExecutionException.class, future::get);
assertEquals(WorkflowStepException.class, e.getCause().getClass());
WorkflowStepException w = (WorkflowStepException) e.getCause();
assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage());
assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus());
}

public void testToolWithConnectorId() throws ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
inputData.getNodeId(),
inputData,
Map.of(createConnectorNodeId, inputDataWithConnectorId),
Map.of(createConnectorNodeId, CONNECTOR_ID),
Collections.emptyMap()
);
assertTrue(future.isDone());
Object tools = future.get().getContent().get("tools");
assertEquals(MLToolSpec.class, tools.getClass());
MLToolSpec mlToolSpec = (MLToolSpec) tools;
assertEquals(mlToolSpec.getParameters(), Map.of(CONNECTOR_ID, mockedConnectorId));
}

public void testToolWithModelId() throws ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
inputData.getNodeId(),
inputData,
Map.of(createModelNodeId, inputDataWithModelId),
Map.of(createModelNodeId, MODEL_ID),
Collections.emptyMap()
);
assertTrue(future.isDone());
Object tools = future.get().getContent().get("tools");
assertEquals(MLToolSpec.class, tools.getClass());
MLToolSpec mlToolSpec = (MLToolSpec) tools;
assertEquals(mlToolSpec.getParameters(), Map.of(MODEL_ID, mockedModelId));
}

public void testToolWithAgentId() throws ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
inputData.getNodeId(),
inputData,
Map.of(createAgentNodeId, inputDataWithAgentId),
Map.of(createAgentNodeId, AGENT_ID),
Collections.emptyMap()
);
assertTrue(future.isDone());
Object tools = future.get().getContent().get("tools");
assertEquals(MLToolSpec.class, tools.getClass());
MLToolSpec mlToolSpec = (MLToolSpec) tools;
assertEquals(mlToolSpec.getParameters(), Map.of(AGENT_ID, mockedAgentId));
}
}
Loading

0 comments on commit b3f9d65

Please sign in to comment.