diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 5e8509373..deeabdd76 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -31,6 +31,7 @@ import java.util.Locale; import static org.opensearch.flowframework.common.CommonValue.DRY_RUN; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -91,8 +92,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String workflowId = request.param(WORKFLOW_ID); Template template = Template.parse(request.content().utf8ToString()); boolean dryRun = request.paramAsBoolean(DRY_RUN, false); + boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun, requestTimeout, maxWorkflows); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun, provision, requestTimeout, maxWorkflows); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 6ca1c4661..765c9cae5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -135,7 +135,28 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.info("create state workflow doc"); - listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + if (request.isProvision()) { + logger.info("provision parameter"); + WorkflowRequest workflowRequest = new WorkflowRequest(globalContextResponse.getId(), null); + client.execute( + ProvisionWorkflowAction.INSTANCE, + workflowRequest, + ActionListener.wrap(provisionResponse -> { + listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) + ); + } + logger.error("Failed to send back provision workflow exception", exception); + }) + ); + } else { + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + } }, exception -> { logger.error("Failed to save workflow state : {}", exception.getMessage()); if (exception instanceof FlowFrameworkException) { @@ -246,5 +267,4 @@ private void validateWorkflows(Template template) throws Exception { workflowProcessSorter.validateGraph(sortedNodes); } } - } diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index d049be8f6..057f13d01 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -38,6 +38,11 @@ public class WorkflowRequest extends ActionRequest { */ private boolean dryRun; + /** + * Provision flag + */ + private boolean provision; + /** * Timeout for request */ @@ -54,7 +59,7 @@ public class WorkflowRequest extends ActionRequest { * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, false, null, null); + this(workflowId, template, false, false, null, null); } /** @@ -70,7 +75,7 @@ public WorkflowRequest( @Nullable TimeValue requestTimeout, @Nullable Integer maxWorkflows ) { - this(workflowId, template, false, requestTimeout, maxWorkflows); + this(workflowId, template, false, false, requestTimeout, maxWorkflows); } /** @@ -78,6 +83,7 @@ public WorkflowRequest( * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow * @param dryRun flag to indicate if validation is necessary + * @param provision flag to indicate if provision is necessary * @param requestTimeout timeout of the request * @param maxWorkflows max number of workflows */ @@ -85,12 +91,14 @@ public WorkflowRequest( @Nullable String workflowId, @Nullable Template template, boolean dryRun, + boolean provision, @Nullable TimeValue requestTimeout, @Nullable Integer maxWorkflows ) { this.workflowId = workflowId; this.template = template; this.dryRun = dryRun; + this.provision = provision; this.requestTimeout = requestTimeout; this.maxWorkflows = maxWorkflows; } @@ -106,6 +114,7 @@ public WorkflowRequest(StreamInput in) throws IOException { String templateJson = in.readOptionalString(); this.template = templateJson == null ? null : Template.parse(templateJson); this.dryRun = in.readBoolean(); + this.provision = in.readBoolean(); this.requestTimeout = in.readOptionalTimeValue(); this.maxWorkflows = in.readOptionalInt(); } @@ -136,6 +145,14 @@ public boolean isDryRun() { return this.dryRun; } + /** + * Gets the provision flag + * @return the provision boolean + */ + public boolean isProvision() { + return this.provision; + } + /** * Gets the timeout of the request * @return the requestTimeout @@ -158,6 +175,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(workflowId); out.writeOptionalString(template == null ? null : template.toJson()); out.writeBoolean(dryRun); + out.writeBoolean(provision); out.writeOptionalTimeValue(requestTimeout); out.writeOptionalInt(maxWorkflows); } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 6856a2122..2e67b59d8 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -29,6 +30,7 @@ import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -52,8 +54,9 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -89,7 +92,16 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings); + + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + WorkflowStepFactory factory = new WorkflowStepFactory( + Settings.EMPTY, + clusterService, + client, + mlClient, + flowFrameworkIndicesHandler + ); + this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool, clusterService, settings); this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( mock(TransportService.class), @@ -129,7 +141,16 @@ public void setUp() throws Exception { ); } - public void testFailedDryRunValidation() { + public void testDryRunValidation_withoutProvision_Success() { + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, true, false, null, null); + createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + } + + public void testDryRunValidation_Failed() { WorkflowNode createConnector = new WorkflowNode( "workflow_step_1", @@ -183,7 +204,7 @@ public void testFailedDryRunValidation() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, null, null); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, false, null, null); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -198,6 +219,7 @@ public void testMaxWorkflow() { null, template, false, + false, WORKFLOW_REQUEST_TIMEOUT.get(settings), MAX_WORKFLOWS.get(settings) ); @@ -234,6 +256,7 @@ public void testFailedToCreateNewWorkflow() { null, template, false, + false, WORKFLOW_REQUEST_TIMEOUT.get(settings), MAX_WORKFLOWS.get(settings) ); @@ -271,6 +294,7 @@ public void testCreateNewWorkflow() { null, template, false, + false, WORKFLOW_REQUEST_TIMEOUT.get(settings), MAX_WORKFLOWS.get(settings) ); @@ -352,4 +376,166 @@ public void testUpdateWorkflow() { assertEquals("1", responseCaptor.getValue().getWorkflowId()); } + + public void testCreateWorkflow_withDryRun_withProvision_Success() { + + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + true, + true, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + responseListener.onResponse(response); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); + + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); + assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + } + + public void testCreateWorkflow_withDryRun_withProvision_FailedProvisioning() { + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + true, + true, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + responseListener.onFailure(new Exception("failed")); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("failed", exceptionCaptor.getValue().getMessage()); + } + + private Template generateValidTemplate() { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(), + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(), + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + WorkflowResources.DEPLOY_MODEL.getWorkflowStep(), + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + + Template validTemplate = new Template( + "test", + "description", + "use case", + Version.fromString("1.0.0"), + List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), + Map.of("workflow", workflow), + Map.of(), + TestHelpers.randomUser() + ); + + return validTemplate; + } }