Skip to content

Commit

Permalink
integration test: update with get agent and get workflow
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <yuyezhu@amazon.com>
  • Loading branch information
yuye-aws committed Aug 28, 2024
1 parent de1c044 commit a8872d6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -661,51 +661,35 @@ protected Response getWorkflowStep(RestClient client) throws Exception {
}

/**
* Helper method to invoke the Search Workflow Rest Action with the given query
* Helper method to invoke the Get Agent Rest Action
* @param client the rest client
* @param query the search query
* @return rest response
* @throws Exception if the request fails
* @throws Exception
*/
protected SearchResponse searchWorkflows(RestClient client, String query) throws Exception {

// Execute search
Response restSearchResponse = TestHelpers.makeRequest(
protected Response getAgent(RestClient client, String agentId) throws Exception {
return TestHelpers.makeRequest(
client,
"GET",
String.format(Locale.ROOT, "%s/_search", WORKFLOW_URI),
String.format(Locale.ROOT, "/_plugins/_ml/agents/%s", agentId),
Collections.emptyMap(),
query,
"",
null
);
assertEquals(RestStatus.OK, TestHelpers.restStatus(restSearchResponse));

// Parse entity content into SearchResponse
MediaType mediaType = MediaType.fromMediaType(restSearchResponse.getEntity().getContentType());
try (
XContentParser parser = mediaType.xContent()
.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
restSearchResponse.getEntity().getContent()
)
) {
return SearchResponse.fromXContent(parser);
}
}

/**
* Helper method to invoke the Search Workflow State Rest Action
* Helper method to invoke the Search Workflow Rest Action with the given query
* @param client the rest client
* @param query the search query
* @return
* @throws Exception
* @return rest response
* @throws Exception if the request fails
*/
protected SearchResponse searchWorkflowState(RestClient client, String query) throws Exception {
protected SearchResponse searchWorkflows(RestClient client, String query) throws Exception {
// Execute search
Response restSearchResponse = TestHelpers.makeRequest(
client,
"GET",
String.format(Locale.ROOT, "%s/state/_search", WORKFLOW_URI),
String.format(Locale.ROOT, "%s/_search", WORKFLOW_URI),
Collections.emptyMap(),
query,
null
Expand All @@ -727,17 +711,17 @@ protected SearchResponse searchWorkflowState(RestClient client, String query) th
}

/**
* Helper method to invoke the Search Agent Rest Action
* Helper method to invoke the Search Workflow State Rest Action
* @param client the rest client
* @param query the search query
* @return
* @throws Exception
*/
protected SearchResponse searchAgent(RestClient client, String query) throws Exception {
protected SearchResponse searchWorkflowState(RestClient client, String query) throws Exception {
Response restSearchResponse = TestHelpers.makeRequest(
client,
"GET",
"/_plugins/_ml/agents/_search",
String.format(Locale.ROOT, "%s/state/_search", WORKFLOW_URI),
Collections.emptyMap(),
query,
null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.search.SearchHit;
import org.junit.Before;
import org.junit.ComparisonFailure;

Expand Down Expand Up @@ -347,36 +346,24 @@ public void testCreateAndProvisionConnectorToolAgentFrameworkWorkflow() throws E
TimeUnit.SECONDS
);

// Hit Search State API with the workflow id created above
String query = "{\"query\":{\"ids\":{\"values\":[\"" + workflowId + "\"]}}}";
SearchResponse searchResponse = searchWorkflowState(client(), query);
assertEquals(1, searchResponse.getHits().getTotalHits().value);
String searchHitSource = searchResponse.getHits().getAt(0).getSourceAsString();
WorkflowState searchHitWorkflowState = WorkflowState.parse(searchHitSource);

// Assert based on the agent-framework template
List<ResourceCreated> resourcesCreated = searchHitWorkflowState.resourcesCreated();
Set<String> expectedStepNames = new HashSet<>();
expectedStepNames.add("create_connector");
expectedStepNames.add("create_flow_agent");
Set<String> stepNames = resourcesCreated.stream().map(ResourceCreated::workflowStepId).collect(Collectors.toSet());

assertEquals(2, resourcesCreated.size());
assertEquals(stepNames, expectedStepNames);
String connectorId = resourcesCreated.getFirst().resourceId();
String agentId = resourcesCreated.get(1).resourceId();
List<ResourceCreated> resourcesCreated = getResourcesCreated(client(), workflowId, 120);
Map<String, ResourceCreated> resourceMap = resourcesCreated.stream()
.collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r));
assertEquals(2, resourceMap.size());
assertTrue(resourceMap.containsKey("create_connector"));
assertTrue(resourceMap.containsKey("register_agent"));
String connectorId = resourceMap.get("create_connector").resourceId();
String agentId = resourceMap.get("register_agent").resourceId();
assertNotNull(connectorId);
assertNotNull(agentId);

query = "{\"query\":{\"ids\":{\"values\":[\"" + agentId + "\"]}}}";
searchResponse = searchAgent(client(), query);
assertEquals(1, searchResponse.getHits().getTotalHits().value);
SearchHit searchHit = searchResponse.getHits().getAt(0);
Map<String, Object> searchHitSourceMap = searchHit.getSourceAsMap();
assertTrue(searchHitSourceMap.containsKey("tools"));

// Assert that the agent contains the correct connector_id
response = getAgent(client(), agentId);
Map<String, Object> agentResponse = entityAsMap(response);
assertTrue(agentResponse.containsKey("tools"));
@SuppressWarnings("unchecked")
ArrayList<Map<String, Object>> tools = (ArrayList<Map<String, Object>>) searchHitSourceMap.get("tools");
ArrayList<Map<String, Object>> tools = (ArrayList<Map<String, Object>>) agentResponse.get("tools");
assertEquals(1, tools.size());
Map<String, Object> tool = tools.getFirst();
assertTrue(tool.containsKey("parameters"));
Expand Down Expand Up @@ -735,7 +722,6 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception {
}

public void testDefaultCohereUseCase() throws Exception {

// Hit Create Workflow API with original template
Response response = createWorkflowWithUseCaseWithNoValidation(
client(),
Expand Down

0 comments on commit a8872d6

Please sign in to comment.