Skip to content

Commit

Permalink
[Backport 2.x] Manual Back Port of #350 (#372)
Browse files Browse the repository at this point in the history
Adds deploy model flag support for local model registration, fixes integration tests (#350)

* Fixing local model integration test



* Added deploy model flag support for local model registration, added associated integration test



* Fixing comment



* Fixing deprovision workflow transport action, removing use of template, ascertaining deprovision sequence from created resources



* Removing rest status checks for deprovision API tests



* Increasing wait time for deprovision status



* Removing deprovision status checks for model deployment tests



* increasing timeout for local model registration test template



* Reverting timeout increase, setting ML Commons native memory threshold to 100 to avoid opening circuit breaker



* Passing an action listener to retryableGetMlTask



* Addressing PR comments, preserving order of resource map



* Testing if a wait time after deprovisioning will mitigate circuit breaker issues



* Increasing mlconfig index creation wait time



* Combining local model registration tests into one



* removing resource map from deprovision workflow transport action



* Fixing getResourceFromDeprovisionNode and tests



* Separating out local model registration tests, using ml jvm heap memory setting instead of native memory heap setting



* Testing : removing second local model registration test



* Reducing model registration tests, testing local model registration with deployed flag, testing remote model registration with deploy step



* Removing suffix from simulated deploy model step



---------

Signed-off-by: Joshua Palis <jpalis@amazon.com>
  • Loading branch information
joshpalis authored Jan 5, 2024
1 parent 5e4efcc commit efad064
Show file tree
Hide file tree
Showing 9 changed files with 279 additions and 322 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
Expand All @@ -24,31 +24,25 @@
import org.opensearch.flowframework.model.ProvisioningProgress;
import org.opensearch.flowframework.model.ResourceCreated;
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD;
import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD;
import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep;
Expand All @@ -65,84 +59,50 @@ public class DeprovisionWorkflowTransportAction extends HandledTransportAction<W

private final ThreadPool threadPool;
private final Client client;
private final WorkflowProcessSorter workflowProcessSorter;
private final WorkflowStepFactory workflowStepFactory;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private final EncryptorUtils encryptorUtils;

/**
* Instantiates a new DeprovisionWorkflowTransportAction
* @param transportService The TransportService
* @param actionFilters action filters
* @param threadPool The OpenSearch thread pool
* @param client The node client to retrieve a stored use case template
* @param workflowProcessSorter Utility class to generate a topologically sorted list of Process nodes
* @param workflowStepFactory The factory instantiating workflow steps
* @param flowFrameworkIndicesHandler Class to handle all internal system indices actions
* @param encryptorUtils Utility class to handle encryption/decryption
*/
@Inject
public DeprovisionWorkflowTransportAction(
TransportService transportService,
ActionFilters actionFilters,
ThreadPool threadPool,
Client client,
WorkflowProcessSorter workflowProcessSorter,
WorkflowStepFactory workflowStepFactory,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
EncryptorUtils encryptorUtils
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
super(DeprovisionWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.threadPool = threadPool;
this.client = client;
this.workflowProcessSorter = workflowProcessSorter;
this.workflowStepFactory = workflowStepFactory;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.encryptorUtils = encryptorUtils;
}

@Override
protected void doExecute(Task task, WorkflowRequest request, ActionListener<WorkflowResponse> listener) {
// Retrieve use case template from global context
String workflowId = request.getWorkflowId();
GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId);
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);

// Stash thread context to interact with system index
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(response -> {
client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> {
context.restore();

if (!response.isExists()) {
listener.onFailure(
new FlowFrameworkException(
"Failed to retrieve template (" + workflowId + ") from global context.",
RestStatus.NOT_FOUND
)
);
return;
}

// Parse template from document source
Template template = Template.parse(response.getSourceAsString());

// Decrypt template
template = encryptorUtils.decryptTemplateCredentials(template);

// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
workflowProcessSorter.validate(provisionProcessSequence);

// We have a valid template and sorted nodes, get the created resources
getResourcesAndExecute(request.getWorkflowId(), provisionProcessSequence, listener);
// Retrieve resources from workflow state and deprovision
executeDeprovisionSequence(workflowId, response.getWorkflowState().resourcesCreated(), listener);
}, exception -> {
if (exception instanceof FlowFrameworkException) {
logger.error("Workflow validation failed for workflow : " + workflowId);
listener.onFailure(exception);
} else {
logger.error("Failed to retrieve template from global context.", exception);
listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
}
String message = "Failed to get workflow state for workflow " + workflowId;
logger.error(message, exception);
listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception)));
}));
} catch (Exception e) {
String message = "Failed to retrieve template from global context.";
Expand All @@ -151,64 +111,38 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
}
}

/**
* Fetches the workflow state for the given workflow, indexes its created resources by
* workflow step id, and passes that map on to executeDeprovisionSequence.
* @param workflowId The id of the workflow whose state is looked up
* @param provisionProcessSequence The provision process nodes, forwarded unchanged to executeDeprovisionSequence
* @param listener Notified with a FlowFrameworkException if the workflow state request fails
*/
private void getResourcesAndExecute(
String workflowId,
List<ProcessNode> provisionProcessSequence,
ActionListener<WorkflowResponse> listener
) {
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);
client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> {
// Get a map of step id to created resources
// NOTE: toMap without a merge function throws IllegalStateException on duplicate keys,
// so this presumes each workflow step id appears at most once in resourcesCreated
final Map<String, ResourceCreated> resourceMap = response.getWorkflowState()
.resourcesCreated()
.stream()
.collect(Collectors.toMap(ResourceCreated::workflowStepId, Function.identity()));

// Now finally do the deprovision
executeDeprovisionSequence(workflowId, resourceMap, provisionProcessSequence, listener);
}, exception -> {
// Surface the lookup failure to the caller, keeping the REST status derived from the cause
String message = "Failed to get workflow state for workflow " + workflowId;
logger.error(message, exception);
listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception)));
}));
}

private void executeDeprovisionSequence(
String workflowId,
Map<String, ResourceCreated> resourceMap,
List<ProcessNode> provisionProcessSequence,
List<ResourceCreated> resourcesCreated,
ActionListener<WorkflowResponse> listener
) {

// Create a list of ProcessNodes with the corresponding deprovision workflow steps
List<ProcessNode> deprovisionProcessSequence = provisionProcessSequence.stream()
// Only include nodes that created a resource
.filter(pn -> resourceMap.containsKey(pn.id()))
// Create a new ProcessNode with a deprovision step
.map(pn -> {
String stepName = pn.workflowStep().getName();
String deprovisionStep = getDeprovisionStepByWorkflowStep(stepName);
// Unimplemented steps presently return null, so skip
if (deprovisionStep == null) {
return null;
}
// New ID is old ID with deprovision added
String deprovisionStepId = pn.id() + DEPROVISION_SUFFIX;
return new ProcessNode(
List<ProcessNode> deprovisionProcessSequence = new ArrayList<>();
for (ResourceCreated resource : resourcesCreated) {
String workflowStepId = resource.workflowStepId();

String stepName = resource.workflowStepName();
String deprovisionStep = getDeprovisionStepByWorkflowStep(stepName);
// Unimplemented steps presently return null, so skip
if (deprovisionStep == null) {
continue;
}
// New ID is old ID with deprovision added
String deprovisionStepId = workflowStepId + DEPROVISION_SUFFIX;
deprovisionProcessSequence.add(
new ProcessNode(
deprovisionStepId,
workflowStepFactory.createStep(deprovisionStep),
Collections.emptyMap(),
new WorkflowData(
Map.of(getResourceByWorkflowStep(stepName), resourceMap.get(pn.id()).resourceId()),
workflowId,
deprovisionStepId
),
new WorkflowData(Map.of(getResourceByWorkflowStep(stepName), resource.resourceId()), workflowId, deprovisionStepId),
Collections.emptyList(),
this.threadPool,
pn.nodeTimeout()
);
})
.filter(Objects::nonNull)
.collect(Collectors.toList());
TimeValue.ZERO
)
);
}

// Deprovision in reverse order of provisioning to minimize risk of dependencies
Collections.reverse(deprovisionProcessSequence);
logger.info("Deprovisioning steps: {}", deprovisionProcessSequence.stream().map(ProcessNode::id).collect(Collectors.joining(", ")));
Expand All @@ -219,7 +153,7 @@ private void executeDeprovisionSequence(
Iterator<ProcessNode> iter = deprovisionProcessSequence.iterator();
while (iter.hasNext()) {
ProcessNode deprovisionNode = iter.next();
ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourceMap);
ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourcesCreated);
String resourceNameAndId = getResourceNameAndId(resource);
CompletableFuture<WorkflowData> deprovisionFuture = deprovisionNode.execute();
try {
Expand Down Expand Up @@ -265,7 +199,7 @@ private void executeDeprovisionSequence(
}
// Get corresponding resources
List<ResourceCreated> remainingResources = deprovisionProcessSequence.stream()
.map(pn -> getResourceFromDeprovisionNode(pn, resourceMap))
.map(pn -> getResourceFromDeprovisionNode(pn, resourcesCreated))
.collect(Collectors.toList());
logger.info("Resources remaining: {}", remainingResources);
updateWorkflowState(workflowId, remainingResources, listener);
Expand Down Expand Up @@ -322,10 +256,18 @@ private void updateWorkflowState(
}
}

/**
* Looks up the created resource that corresponds to a deprovision node.
* The node id is treated as a workflow step id followed by DEPROVISION_SUFFIX; the part
* before the suffix is matched against the workflow step ids of the created resources.
* @param deprovisionNode The deprovision ProcessNode whose id is inspected
* @param resourcesCreated The created resources to search
* @return The matching resource, or null if the suffix is not found at a position greater than zero or nothing matches
*/
private static ResourceCreated getResourceFromDeprovisionNode(ProcessNode deprovisionNode, Map<String, ResourceCreated> resourceMap) {
private static ResourceCreated getResourceFromDeprovisionNode(ProcessNode deprovisionNode, List<ResourceCreated> resourcesCreated) {
String deprovisionId = deprovisionNode.id();
int pos = deprovisionId.indexOf(DEPROVISION_SUFFIX);
return pos > 0 ? resourceMap.get(deprovisionId.substring(0, pos)) : null;
ResourceCreated resource = null;
// pos > 0 rather than >= 0: an id that starts with the suffix leaves no step id to look up
if (pos > 0) {
for (ResourceCreated resourceCreated : resourcesCreated) {
if (resourceCreated.workflowStepId().equals(deprovisionId.substring(0, pos))) {
// No break here, so if several entries share the step id the last one is returned
resource = resourceCreated;
}
}
}
return resource;
}

private static String getResourceNameAndId(ResourceCreated resource) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.threadpool.ThreadPool;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;

/**
Expand Down Expand Up @@ -66,13 +64,15 @@ protected AbstractRetryableWorkflowStep(
* @param future the workflow step future
* @param taskId the ml task id
* @param workflowStep the workflow step which requires a retry get ml task functionality
* @param mlTaskListener the ML Task Listener
*/
protected void retryableGetMlTask(
String workflowId,
String nodeId,
CompletableFuture<WorkflowData> future,
String taskId,
String workflowStep
String workflowStep,
ActionListener<MLTask> mlTaskListener
) {
AtomicInteger retries = new AtomicInteger();
CompletableFuture.runAsync(() -> {
Expand All @@ -91,46 +91,37 @@ protected void retryableGetMlTask(
id,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
future.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, id),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId,
nodeId
)
);
mlTaskListener.onResponse(response);
}, exception -> {
logger.error("Failed to update new created resource", exception);
future.completeExceptionally(
mlTaskListener.onFailure(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);
} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
mlTaskListener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
break;
case FAILED:
case COMPLETED_WITH_ERROR:
String errorMessage = workflowStep + " failed with error : " + response.getError();
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
break;
case CANCELLED:
errorMessage = workflowStep + " task was cancelled.";
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
break;
default:
// Task started or running, do nothing
}
}, exception -> {
String errorMessage = workflowStep + " failed with error : " + exception.getMessage();
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
}));
// Wait long enough for future to possibly complete
try {
Expand All @@ -143,7 +134,7 @@ protected void retryableGetMlTask(
if (!future.isDone()) {
String errorMessage = workflowStep + " did not complete after " + maxRetry + " retries";
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
}
}, threadPool.executor(PROVISION_THREAD_POOL));
}
Expand Down
Loading

0 comments on commit efad064

Please sign in to comment.