Skip to content

Commit

Permalink
[Feature/agent_framework] Adds a Search Workflow State API (opensearc…
Browse files Browse the repository at this point in the history
…h-project#284)

* Modifying workflow state index mapping and resources created

Signed-off-by: Joshua Palis <jpalis@amazon.com>

* Adding Search workflow state API

Signed-off-by: Joshua Palis <jpalis@amazon.com>

* Adding rest unit tests

Signed-off-by: Joshua Palis <jpalis@amazon.com>

* Transport unit tests

Signed-off-by: Joshua Palis <jpalis@amazon.com>

* Moving resourceType determination outside of the resources created class

Signed-off-by: Joshua Palis <jpalis@amazon.com>

---------

Signed-off-by: Joshua Palis <jpalis@amazon.com>
  • Loading branch information
joshpalis authored and dbwiddis committed Dec 15, 2023
1 parent 0875701 commit 7d1dac1
Show file tree
Hide file tree
Showing 13 changed files with 403 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.flowframework.rest.RestGetWorkflowStateAction;
import org.opensearch.flowframework.rest.RestProvisionWorkflowAction;
import org.opensearch.flowframework.rest.RestSearchWorkflowAction;
import org.opensearch.flowframework.rest.RestSearchWorkflowStateAction;
import org.opensearch.flowframework.transport.CreateWorkflowAction;
import org.opensearch.flowframework.transport.CreateWorkflowTransportAction;
import org.opensearch.flowframework.transport.GetWorkflowAction;
Expand All @@ -41,6 +42,8 @@
import org.opensearch.flowframework.transport.ProvisionWorkflowAction;
import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction;
import org.opensearch.flowframework.transport.SearchWorkflowAction;
import org.opensearch.flowframework.transport.SearchWorkflowStateAction;
import org.opensearch.flowframework.transport.SearchWorkflowStateTransportAction;
import org.opensearch.flowframework.transport.SearchWorkflowTransportAction;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
Expand Down Expand Up @@ -130,7 +133,8 @@ public List<RestHandler> getRestHandlers(
new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting),
new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting),
new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting),
new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting)
new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting),
new RestSearchWorkflowStateAction(flowFrameworkFeatureEnabledSetting)
);
}

Expand All @@ -141,7 +145,8 @@ public List<RestHandler> getRestHandlers(
new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class),
new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class),
new ActionHandler<>(GetWorkflowStateAction.INSTANCE, GetWorkflowStateTransportAction.class),
new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class)
new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class),
new ActionHandler<>(SearchWorkflowStateAction.INSTANCE, SearchWorkflowStateTransportAction.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ private CommonValue() {}
public static final String WORKFLOW_STEP_NAME = "workflow_step_name";
/** The field name for the step ID where a resource is created */
public static final String WORKFLOW_STEP_ID = "workflow_step_id";
/** The field name for the resource type */
public static final String RESOURCE_TYPE = "resource_type";
/** The field name for the resource id */
public static final String RESOURCE_ID = "resource_id";
/** The tools' field for an agent */
public static final String TOOLS_FIELD = "tools";
/** The memory field for an agent */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.model.ProvisioningProgress;
import org.opensearch.flowframework.model.ResourceCreated;
Expand Down Expand Up @@ -500,7 +501,12 @@ public void updateResourceInStateIndex(
String resourceId,
ActionListener<UpdateResponse> listener
) throws IOException {
ResourceCreated newResource = new ResourceCreated(workflowStepName, nodeId, resourceId);
ResourceCreated newResource = new ResourceCreated(
workflowStepName,
nodeId,
WorkflowResources.getResourceByWorkflowStep(workflowStepName),
resourceId
);
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.exception.FlowFrameworkException;

import java.io.IOException;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.RESOURCE_ID;
import static org.opensearch.flowframework.common.CommonValue.RESOURCE_TYPE;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_NAME;

Expand All @@ -36,17 +37,20 @@ public class ResourceCreated implements ToXContentObject, Writeable {

private final String workflowStepName;
private final String workflowStepId;
private final String resourceType;
private final String resourceId;

/**
* Create this resources created object with given workflow step name, ID and resource ID.
* @param workflowStepName The workflow step name associating to the step where it was created
* @param workflowStepId The workflow step ID associating to the step where it was created
* @param resourceType The resource type
* @param resourceId The resources ID for relating to the created resource
*/
public ResourceCreated(String workflowStepName, String workflowStepId, String resourceId) {
public ResourceCreated(String workflowStepName, String workflowStepId, String resourceType, String resourceId) {
this.workflowStepName = workflowStepName;
this.workflowStepId = workflowStepId;
this.resourceType = resourceType;
this.resourceId = resourceId;
}

Expand All @@ -58,6 +62,7 @@ public ResourceCreated(String workflowStepName, String workflowStepId, String re
public ResourceCreated(StreamInput input) throws IOException {
this.workflowStepName = input.readString();
this.workflowStepId = input.readString();
this.resourceType = input.readString();
this.resourceId = input.readString();
}

Expand All @@ -66,14 +71,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
XContentBuilder xContentBuilder = builder.startObject()
.field(WORKFLOW_STEP_NAME, workflowStepName)
.field(WORKFLOW_STEP_ID, workflowStepId)
.field(WorkflowResources.getResourceByWorkflowStep(workflowStepName), resourceId);
.field(RESOURCE_TYPE, resourceType)
.field(RESOURCE_ID, resourceId);
return xContentBuilder.endObject();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(workflowStepName);
out.writeString(workflowStepId);
out.writeString(resourceType);
out.writeString(resourceId);
}

Expand All @@ -86,6 +93,15 @@ public String resourceId() {
return resourceId;
}

/**
* Gets the resource type.
*
* @return the resource type.
*/
public String resourceType() {
return resourceType;
}

/**
* Gets the workflow step name associated to the created resource
*
Expand Down Expand Up @@ -114,6 +130,7 @@ public String workflowStepId() {
public static ResourceCreated parse(XContentParser parser) throws IOException {
String workflowStepName = null;
String workflowStepId = null;
String resourceType = null;
String resourceId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
Expand All @@ -128,15 +145,14 @@ public static ResourceCreated parse(XContentParser parser) throws IOException {
case WORKFLOW_STEP_ID:
workflowStepId = parser.text();
break;
case RESOURCE_TYPE:
resourceType = parser.text();
break;
case RESOURCE_ID:
resourceId = parser.text();
break;
default:
if (!isValidFieldName(fieldName)) {
throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object.");
} else {
if (fieldName.equals(WorkflowResources.getResourceByWorkflowStep(workflowStepName))) {
resourceId = parser.text();
}
break;
}
throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object.");
}
}
if (workflowStepName == null) {
Expand All @@ -147,25 +163,25 @@ public static ResourceCreated parse(XContentParser parser) throws IOException {
logger.error("Resource created object failed parsing: workflowStepId: {}", workflowStepId);
throw new FlowFrameworkException("A ResourceCreated object requires workflowStepId", RestStatus.BAD_REQUEST);
}
if (resourceType == null) {
logger.error("Resource created object failed parsing: resourceType: {}", resourceType);
throw new FlowFrameworkException("A ResourceCreated object requires resourceType", RestStatus.BAD_REQUEST);
}
if (resourceId == null) {
logger.error("Resource created object failed parsing: resourceId: {}", resourceId);
throw new FlowFrameworkException("A ResourceCreated object requires resourceId", RestStatus.BAD_REQUEST);
}
return new ResourceCreated(workflowStepName, workflowStepId, resourceId);
}

private static boolean isValidFieldName(String fieldName) {
return (WORKFLOW_STEP_NAME.equals(fieldName)
|| WORKFLOW_STEP_ID.equals(fieldName)
|| WorkflowResources.getAllResourcesCreated().contains(fieldName));
return new ResourceCreated(workflowStepName, workflowStepId, resourceType, resourceId);
}

@Override
public String toString() {
return "resources_Created [workflow_step_name= "
+ workflowStepName
+ ", workflow_step_id= "
+ workflowStepName
+ workflowStepId
+ ", resource_type= "
+ resourceType
+ ", resource_id= "
+ resourceId
+ "]";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.rest;

import com.google.common.collect.ImmutableList;
import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.transport.SearchWorkflowStateAction;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;

/**
* Rest Action to facilitate requests to search workflow states
*/
public class RestSearchWorkflowStateAction extends AbstractSearchWorkflowAction<WorkflowState> {

private static final String SEARCH_WORKFLOW_STATE_ACTION = "search_workflow_state_action";
private static final String SEARCH_WORKFLOW_STATE_PATH = WORKFLOW_URI + "/state/_search";

/**
* Instantiates a new RestSearchWorkflowStateAction
*
* @param flowFrameworkFeatureEnabledSetting Whether this API is enabled
*/
public RestSearchWorkflowStateAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) {
super(
ImmutableList.of(SEARCH_WORKFLOW_STATE_PATH),
WORKFLOW_STATE_INDEX,
WorkflowState.class,
SearchWorkflowStateAction.INSTANCE,
flowFrameworkFeatureEnabledSetting
);
}

@Override
public String getName() {
return SEARCH_WORKFLOW_STATE_ACTION;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
public class GetWorkflowStateResponse extends ActionResponse implements ToXContentObject {

/** The workflow state */
public WorkflowState workflowState;
private final WorkflowState workflowState;
/** Flag to indicate if the entire state should be returned */
public boolean allStatus;
private final boolean allStatus;

/**
* Instantiates a new GetWorkflowStateResponse from an input stream
Expand All @@ -44,6 +44,7 @@ public GetWorkflowStateResponse(StreamInput in) throws IOException {
* @param allStatus whether to return all fields in state index
*/
public GetWorkflowStateResponse(WorkflowState workflowState, boolean allStatus) {
this.allStatus = allStatus;
if (allStatus) {
this.workflowState = workflowState;
} else {
Expand All @@ -58,10 +59,27 @@ public GetWorkflowStateResponse(WorkflowState workflowState, boolean allStatus)
@Override
public void writeTo(StreamOutput out) throws IOException {
workflowState.writeTo(out);
out.writeBoolean(allStatus);
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return workflowState.toXContent(xContentBuilder, params);
}

/**
* Gets the workflow state.
* @return the workflow state
*/
public WorkflowState getWorkflowState() {
return workflowState;
}

/**
* Gets whether to return the entire state.
* @return true if the entire state should be returned
*/
public boolean isAllStatus() {
return allStatus;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.transport;

import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchResponse;

import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX;

/**
* External Action for public facing RestSearchWorkflowStateAction
*/
public class SearchWorkflowStateAction extends ActionType<SearchResponse> {

/** The name of this action */
public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow_state/search";
/** An instance of this action */
public static final SearchWorkflowStateAction INSTANCE = new SearchWorkflowStateAction();

private SearchWorkflowStateAction() {
super(NAME, SearchResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.transport;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
* Transport Action to search workflow states
*/
public class SearchWorkflowStateTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {

private Client client;

/**
* Intantiates a new SearchWorkflowStateTransportAction
* @param transportService the TransportService
* @param actionFilters action filters
* @param client The client used to make the request to OS
*/
@Inject
public SearchWorkflowStateTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(SearchWorkflowStateAction.NAME, transportService, actionFilters, SearchRequest::new);
this.client = client;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
// TODO: AccessController should take care of letting the user with right permission to view the workflow
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.search(request, ActionListener.runBefore(actionListener, () -> context.restore()));
} catch (Exception e) {
actionListener.onFailure(e);
}
}
}
Loading

0 comments on commit 7d1dac1

Please sign in to comment.