From 5152ad31be2c0d5201cf4350846140065bd65e21 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Wed, 25 Oct 2023 20:02:18 +0000 Subject: [PATCH 01/19] Simplifying Template format, removing operations, resources created, user outputs Signed-off-by: Joshua Palis --- .../flowframework/model/Template.java | 127 +----------------- .../resources/mappings/global-context.json | 9 -- .../flowframework/model/TemplateTests.java | 18 +-- .../rest/RestCreateWorkflowActionTests.java | 6 +- .../CreateWorkflowTransportActionTests.java | 6 +- ...ProvisionWorkflowTransportActionTests.java | 6 +- .../WorkflowRequestResponseTests.java | 6 +- 7 files changed, 10 insertions(+), 168 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index a1da67a4d..a1a526153 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -25,7 +25,6 @@ import java.util.Map.Entry; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.TemplateUtil.parseStringToStringMap; /** * The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API. @@ -38,8 +37,6 @@ public class Template implements ToXContentObject { public static final String DESCRIPTION_FIELD = "description"; /** The template field name for template use case */ public static final String USE_CASE_FIELD = "use_case"; - /** The template field name for template operations */ - public static final String OPERATIONS_FIELD = "operations"; /** The template field name for template version information */ public static final String VERSION_FIELD = "version"; /** The template field name for template version */ @@ -48,20 +45,13 @@ public class Template implements ToXContentObject { public static final String COMPATIBILITY_FIELD = "compatibility"; /** The template field name for template workflows */ public static final String WORKFLOWS_FIELD = "workflows"; - /** The template field name for template user outputs */ - public static final String USER_OUTPUTS_FIELD = "user_outputs"; - /** The template field name for template resources created */ - public static final String RESOURCES_CREATED_FIELD = "resources_created"; private final String name; private final String description; private final String useCase; // probably an ENUM actually - private final List operations; // probably an ENUM actually private final Version templateVersion; private final List compatibilityVersion; private final Map workflows; - private final Map userOutputs; - private final Map resourcesCreated; /** * Instantiate the object representing a use case template @@ -69,33 +59,24 @@ public class Template implements ToXContentObject { * @param name The template's name * @param description A description of the template's use case * @param useCase A string defining the internal use case type - * @param operations Expected operations of this template. Should match defined workflows. * @param templateVersion The version of this template * @param compatibilityVersion OpenSearch version compatibility of this template * @param workflows Workflow graph definitions corresponding to the defined operations. - * @param userOutputs A map of essential API responses for backend to use and lookup. - * @param resourcesCreated A map of all the resources created. 
*/ public Template( String name, String description, String useCase, - List operations, Version templateVersion, List compatibilityVersion, - Map workflows, - Map userOutputs, - Map resourcesCreated + Map workflows ) { this.name = name; this.description = description; this.useCase = useCase; - this.operations = List.copyOf(operations); this.templateVersion = templateVersion; this.compatibilityVersion = List.copyOf(compatibilityVersion); this.workflows = Map.copyOf(workflows); - this.userOutputs = Map.copyOf(userOutputs); - this.resourcesCreated = Map.copyOf(resourcesCreated); } @Override @@ -104,11 +85,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(NAME_FIELD, this.name); xContentBuilder.field(DESCRIPTION_FIELD, this.description); xContentBuilder.field(USE_CASE_FIELD, this.useCase); - xContentBuilder.startArray(OPERATIONS_FIELD); - for (String op : this.operations) { - xContentBuilder.value(op); - } - xContentBuilder.endArray(); if (this.templateVersion != null || !this.compatibilityVersion.isEmpty()) { xContentBuilder.startObject(VERSION_FIELD); @@ -131,18 +107,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } xContentBuilder.endObject(); - xContentBuilder.startObject(USER_OUTPUTS_FIELD); - for (Entry e : userOutputs.entrySet()) { - xContentBuilder.field(e.getKey(), e.getValue()); - } - xContentBuilder.endObject(); - - xContentBuilder.startObject(RESOURCES_CREATED_FIELD); - for (Entry e : resourcesCreated.entrySet()) { - xContentBuilder.field(e.getKey(), e.getValue()); - } - xContentBuilder.endObject(); - return xContentBuilder.endObject(); } @@ -157,12 +121,9 @@ public static Template parse(XContentParser parser) throws IOException { String name = null; String description = ""; String useCase = ""; - List operations = new ArrayList<>(); Version templateVersion = null; List compatibilityVersion = new ArrayList<>(); Map workflows = new HashMap<>(); - Map userOutputs = new HashMap<>(); - Map resourcesCreated = new HashMap<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -178,12 +139,6 @@ public static Template parse(XContentParser parser) throws IOException { case USE_CASE_FIELD: useCase = parser.text(); break; - case OPERATIONS_FIELD: - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - operations.add(parser.text()); - } - break; case VERSION_FIELD: ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -212,42 +167,6 @@ public static Template parse(XContentParser parser) throws IOException { workflows.put(workflowFieldName, Workflow.parse(parser)); } break; - case USER_OUTPUTS_FIELD: - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String userOutputsFieldName = parser.currentName(); - switch (parser.nextToken()) { - case VALUE_STRING: - userOutputs.put(userOutputsFieldName, parser.text()); - break; - case START_OBJECT: - userOutputs.put(userOutputsFieldName, parseStringToStringMap(parser)); - break; - default: - throw new IOException("Unable to parse field [" + userOutputsFieldName + "] in a user_outputs object."); - } - } - break; - - case RESOURCES_CREATED_FIELD: - 
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String resourcesCreatedField = parser.currentName(); - switch (parser.nextToken()) { - case VALUE_STRING: - resourcesCreated.put(resourcesCreatedField, parser.text()); - break; - case START_OBJECT: - resourcesCreated.put(resourcesCreatedField, parseStringToStringMap(parser)); - break; - default: - throw new IOException( - "Unable to parse field [" + resourcesCreatedField + "] in a resources_created object." - ); - } - } - break; - default: throw new IOException("Unable to parse field [" + fieldName + "] in a template object."); } @@ -256,17 +175,7 @@ public static Template parse(XContentParser parser) throws IOException { throw new IOException("An template object requires a name."); } - return new Template( - name, - description, - useCase, - operations, - templateVersion, - compatibilityVersion, - workflows, - userOutputs, - resourcesCreated - ); + return new Template(name, description, useCase, templateVersion, compatibilityVersion, workflows); } /** @@ -338,14 +247,6 @@ public String useCase() { return useCase; } - /** - * Operations this use case supports - * @return the operations - */ - public List operations() { - return operations; - } - /** * The version of this template * @return the templateVersion @@ -363,29 +264,13 @@ public List compatibilityVersion() { } /** - * Workflows encoded in this template, generally corresponding to the operations returned by {@link #operations()}. + * Workflows encoded in this template * @return the workflows */ public Map workflows() { return workflows; } - /** - * A map of essential API responses - * @return the userOutputs - */ - public Map userOutputs() { - return userOutputs; - } - - /** - * A map of all the resources created - * @return the resources created - */ - public Map resourcesCreated() { - return resourcesCreated; - } - @Override public String toString() { return "Template [name=" @@ -394,18 +279,12 @@ public String toString() { + description + ", useCase=" + useCase - + ", operations=" - + operations + ", templateVersion=" + templateVersion + ", compatibilityVersion=" + compatibilityVersion + ", workflows=" + workflows - + ", userOutputs=" - + userOutputs - + ", resourcesCreated=" - + resourcesCreated + "]"; } } diff --git a/src/main/resources/mappings/global-context.json b/src/main/resources/mappings/global-context.json index bb1256dee..5190d4c95 100644 --- a/src/main/resources/mappings/global-context.json +++ b/src/main/resources/mappings/global-context.json @@ -22,9 +22,6 @@ "use_case": { "type": "keyword" }, - "operations": { - "type": "keyword" - }, "version": { "type": "nested", "properties": { @@ -38,12 +35,6 @@ }, "workflows": { "type": "object" - }, - "user_outputs": { - "type": "object" - }, - "resources_created": { - "type": "object" } } } diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java index 78548c46a..695a31ca4 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java @@ -8,8 +8,6 @@ */ package org.opensearch.flowframework.model; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.Version; import org.opensearch.test.OpenSearchTestCase; @@ -19,13 +17,9 @@ public class TemplateTests extends OpenSearchTestCase { - 
private final Logger logger = LogManager.getLogger(TemplateTests.class); - private String expectedTemplate = - "{\"name\":\"test\",\"description\":\"a test template\",\"use_case\":\"test use case\",\"operations\":[\"operation\"],\"version\":{\"template\":\"1.2.3\",\"compatibility\":[\"4.5.6\",\"7.8.9\"]}," - + "\"workflows\":{\"workflow\":{\"user_params\":{\"key\":\"value\"},\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}},{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}],\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}}," - + "\"user_outputs\":{\"responsesMapKey\":{\"nestedKey\":\"nestedValue\"},\"responsesKey\":\"testValue\"}," - + "\"resources_created\":{\"resourcesMapKey\":{\"nestedKey\":\"nestedValue\"},\"resourcesKey\":\"resourceValue\"}}"; + "{\"name\":\"test\",\"description\":\"a test template\",\"use_case\":\"test use case\",\"version\":{\"template\":\"1.2.3\",\"compatibility\":[\"4.5.6\",\"7.8.9\"]}," + + "\"workflows\":{\"workflow\":{\"user_params\":{\"key\":\"value\"},\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}},{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}],\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}}}"; @Override public void setUp() throws Exception { @@ -35,7 +29,6 @@ public void setUp() throws Exception { public void testTemplate() throws IOException { Version templateVersion = Version.fromString("1.2.3"); List compatibilityVersion = List.of(Version.fromString("4.5.6"), Version.fromString("7.8.9")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); @@ -47,18 +40,14 @@ public void testTemplate() throws IOException { "test", "a test template", "test use case", - List.of("operation"), templateVersion, compatibilityVersion, - Map.of("workflow", workflow), - Map.ofEntries(Map.entry("responsesKey", "testValue"), Map.entry("responsesMapKey", Map.of("nestedKey", "nestedValue"))), - Map.ofEntries(Map.entry("resourcesKey", "resourceValue"), Map.entry("resourcesMapKey", Map.of("nestedKey", "nestedValue"))) + Map.of("workflow", workflow) ); assertEquals("test", template.name()); assertEquals("a test template", template.description()); assertEquals("test use case", template.useCase()); - assertEquals(List.of("operation"), template.operations()); assertEquals(templateVersion, template.templateVersion()); assertEquals(compatibilityVersion, template.compatibilityVersion()); Workflow wf = template.workflows().get("workflow"); @@ -71,7 +60,6 @@ public void testTemplate() throws IOException { assertEquals("test", templateX.name()); assertEquals("a test template", templateX.description()); assertEquals("test use case", templateX.useCase()); - assertEquals(List.of("operation"), templateX.operations()); assertEquals(templateVersion, templateX.templateVersion()); assertEquals(compatibilityVersion, templateX.compatibilityVersion()); Workflow wfX = templateX.workflows().get("workflow"); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 11e6c61fa..141ea61b6 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -41,7 +41,6 @@ public class RestCreateWorkflowActionTests extends OpenSearchTestCase 
{ public void setUp() throws Exception { super.setUp(); - List operations = List.of("operation"); Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); @@ -55,12 +54,9 @@ public void setUp() throws Exception { "test", "description", "use case", - operations, templateVersion, compatibilityVersions, - Map.of("workflow", workflow), - Map.of("outputKey", "outputValue"), - Map.of("resourceKey", "resourceValue") + Map.of("workflow", workflow) ); // Invalid template configuration, wrong field name diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 1f7df11d9..dc3840d44 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -50,7 +50,6 @@ public void setUp() throws Exception { globalContextHandler ); - List operations = List.of("operation"); Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); @@ -64,12 +63,9 @@ public void setUp() throws Exception { "test", "description", "use case", - operations, templateVersion, compatibilityVersions, - Map.of("workflow", workflow), - Map.of("outputKey", "outputValue"), - Map.of("resourceKey", "resourceValue") + Map.of("workflow", workflow) ); } diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index 7c0ae9ef9..d4f37261a 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -66,7 +66,6 @@ public void setUp() throws Exception { workflowProcessSorter ); - List operations = List.of("operation"); Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); @@ -80,12 +79,9 @@ public void setUp() throws Exception { "test", "description", "use case", - operations, templateVersion, compatibilityVersions, - Map.of("provision", workflow), - Map.of("outputKey", "outputValue"), - Map.of("resourceKey", "resourceValue") + Map.of("provision", workflow) ); ThreadPool clientThreadPool = mock(ThreadPool.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index 063490c41..057088aac 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -33,7 +33,6 @@ public class WorkflowRequestResponseTests extends OpenSearchTestCase { @Override public void setUp() throws Exception { super.setUp(); - List operations = List.of("operation"); Version templateVersion = Version.fromString("1.0.0"); List 
compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); @@ -47,12 +46,9 @@ public void setUp() throws Exception { "test", "description", "use case", - operations, templateVersion, compatibilityVersions, - Map.of("workflow", workflow), - Map.of("outputKey", "outputValue"), - Map.of("resourceKey", "resourceValue") + Map.of("workflow", workflow) ); } From 6ddd3b20924d8788f88f6ae0a15d1c0924058b42 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 27 Oct 2023 00:44:45 +0000 Subject: [PATCH 02/19] Initial commit, modifies use case template to seperate workflow inputs into previous_node_inputs and user_inputs, adds graph validation after topologically sorting a workflow into a list of ProcessNode Signed-off-by: Joshua Palis --- .../flowframework/model/WorkflowNode.java | 54 +++++++---- .../model/WorkflowStepValidator.java | 91 +++++++++++++++++++ .../model/WorkflowValidator.java | 87 ++++++++++++++++++ .../ProvisionWorkflowTransportAction.java | 9 +- .../workflow/CreateIndexStep.java | 4 +- .../workflow/CreateIngestPipelineStep.java | 2 +- .../flowframework/workflow/ProcessNode.java | 13 +++ .../workflow/WorkflowProcessSorter.java | 62 ++++++++++++- .../resources/mappings/workflow-steps.json | 58 ++++++++++++ .../model/TemplateTestJsonUtil.java | 2 +- .../flowframework/model/TemplateTests.java | 6 +- .../model/WorkflowNodeTests.java | 20 ++-- .../flowframework/model/WorkflowTests.java | 8 +- .../rest/RestCreateWorkflowActionTests.java | 4 +- .../CreateWorkflowTransportActionTests.java | 4 +- ...ProvisionWorkflowTransportActionTests.java | 4 +- .../WorkflowRequestResponseTests.java | 4 +- .../workflow/CreateIndexStepTests.java | 4 +- .../CreateIngestPipelineStepTests.java | 4 +- .../workflow/ProcessNodeTests.java | 7 +- 20 files changed, 388 insertions(+), 59 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java create mode 100644 src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java create mode 100644 src/main/resources/mappings/workflow-steps.json diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index e34c4ddec..d3fb56578 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -39,8 +39,10 @@ public class WorkflowNode implements ToXContentObject { public static final String ID_FIELD = "id"; /** The template field name for node type */ public static final String TYPE_FIELD = "type"; + /** The template field name for previous node inputs */ + public static final String PREVIOUS_NODE_INPUTS_FIELD = "previous_node_inputs"; /** The template field name for node inputs */ - public static final String INPUTS_FIELD = "inputs"; + public static final String USER_INPUTS_FIELD = "user_inputs"; /** The field defining processors in the inputs for search and ingest pipelines */ public static final String PROCESSORS_FIELD = "processors"; /** The field defining the timeout value for this node */ @@ -50,19 +52,22 @@ public class WorkflowNode implements ToXContentObject { private final String id; // unique id private final String type; // maps to a WorkflowStep - private final Map inputs; // maps to WorkflowData + private final Map previousNodeInputs; + private final Map userInputs; // maps to WorkflowData /** * Create this node with 
the id and type, and any user input. * * @param id A unique string identifying this node * @param type The type of {@link WorkflowStep} to create for the corresponding {@link ProcessNode} - * @param inputs Optional input to populate params in {@link WorkflowData} + * @param previousNodeInputs Optional input to identify inputs coming from predecessor nodes + * @param userInputs Optional input to populate params in {@link WorkflowData} */ - public WorkflowNode(String id, String type, Map inputs) { + public WorkflowNode(String id, String type, Map previousNodeInputs, Map userInputs) { this.id = id; this.type = type; - this.inputs = Map.copyOf(inputs); + this.previousNodeInputs = Map.copyOf(previousNodeInputs); + this.userInputs = Map.copyOf(userInputs); } @Override @@ -71,8 +76,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(ID_FIELD, this.id); xContentBuilder.field(TYPE_FIELD, this.type); - xContentBuilder.startObject(INPUTS_FIELD); - for (Entry e : inputs.entrySet()) { + xContentBuilder.field(PREVIOUS_NODE_INPUTS_FIELD); + buildStringToStringMap(xContentBuilder, previousNodeInputs); + + xContentBuilder.startObject(USER_INPUTS_FIELD); + for (Entry e : userInputs.entrySet()) { xContentBuilder.field(e.getKey()); if (e.getValue() instanceof String) { xContentBuilder.value(e.getValue()); @@ -107,7 +115,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static WorkflowNode parse(XContentParser parser) throws IOException { String id = null; String type = null; - Map inputs = new HashMap<>(); + Map previousNodeInputs = new HashMap<>(); + Map userInputs = new HashMap<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -120,16 +129,19 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { case TYPE_FIELD: type = parser.text(); break; - case INPUTS_FIELD: + case PREVIOUS_NODE_INPUTS_FIELD: + previousNodeInputs = parseStringToStringMap(parser); + break; + case USER_INPUTS_FIELD: ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String inputFieldName = parser.currentName(); switch (parser.nextToken()) { case VALUE_STRING: - inputs.put(inputFieldName, parser.text()); + userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - inputs.put(inputFieldName, parseStringToStringMap(parser)); + userInputs.put(inputFieldName, parseStringToStringMap(parser)); break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { @@ -137,13 +149,13 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { while (parser.nextToken() != XContentParser.Token.END_ARRAY) { processorList.add(PipelineProcessor.parse(parser)); } - inputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); + userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); } else { List> mapList = new ArrayList<>(); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { mapList.add(parseStringToStringMap(parser)); } - inputs.put(inputFieldName, mapList.toArray(new Map[0])); + userInputs.put(inputFieldName, mapList.toArray(new Map[0])); } break; default: @@ -159,7 +171,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { throw new IOException("An node object requires both an id and type field."); } - 
return new WorkflowNode(id, type, inputs); + return new WorkflowNode(id, type, previousNodeInputs, userInputs); } /** @@ -179,11 +191,19 @@ public String type() { } /** - * Return this node's input data + * Return this node's user input data + * @return the inputs + */ + public Map userInputs() { + return userInputs; + } + + /** + * Return this node's predecessor inputs * @return the inputs */ - public Map inputs() { - return inputs; + public Map previousNodeInputs() { + return previousNodeInputs; } @Override diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java new file mode 100644 index 000000000..e0c21ce0e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * This represents the an object of workflow steps json which maps each step to expected inputs and outputs + */ +public class WorkflowStepValidator { + + /** Inputs field name */ + private static final String INPUTS_FIELD = "inputs"; + /** Outputs field name */ + private static final String OUTPUTS_FIELD = "outputs"; + + private List inputs; + private List outputs; + + /** + * Intantiate the object representing a Workflow Step validator + * @param inputs the workflow step inputs + * @param outputs the workflow step outputs + */ + public WorkflowStepValidator(List inputs, List outputs) { + this.inputs = inputs; + this.outputs = outputs; + } + + /** + * Parse raw json content into a WorkflowStepValidator instance + * @param parser json based content parser + * @return an instance of the WorkflowStepValidator + * @throws IOException if the content cannot be parsed correctly + */ + public static WorkflowStepValidator parse(XContentParser parser) throws IOException { + List parsedInputs = new ArrayList<>(); + List parsedOutputs = new ArrayList<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case INPUTS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + parsedInputs.add(parser.text()); + } + break; + case OUTPUTS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + parsedOutputs.add(parser.text()); + } + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object."); + } + } + return new WorkflowStepValidator(parsedInputs, parsedOutputs); + } + + /** + * Get the required inputs + * @return the inputs + */ + public List getInputs() { + return inputs; + } + + /** + * Get the required outputs + * @return the outputs + */ + public List getOutputs() { + return outputs; + } +} diff 
--git a/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java new file mode 100644 index 000000000..b2f849d3b --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import com.google.common.base.Charsets; +import com.google.common.io.Resources; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * This represents the workflow steps json which maps each step to expected inputs and outputs + */ +public class WorkflowValidator { + + private Map workflowStepValidators; + + /** + * Intantiate the object representing a Workflow validator + * @param workflowStepValidators a map of {@link WorkflowStepValidator} + */ + public WorkflowValidator(Map workflowStepValidators) { + this.workflowStepValidators = workflowStepValidators; + } + + /** + * Parse raw json content into a WorkflowValidator instance + * @param parser json based content parser + * @return an instance of the WorkflowValidator + * @throws IOException if the content cannot be parsed correctly + */ + private static WorkflowValidator parse(XContentParser parser) throws IOException { + + Map workflowStepValidators = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String type = parser.currentName(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + workflowStepValidators.put(type, WorkflowStepValidator.parse(parser)); + } + return new WorkflowValidator(workflowStepValidators); + } + + /** + * Parse a workflow step JSON file into a WorkflowValidator object + * + * @param file the file name of the workflow step json + * @return A {@link WorkflowValidator} represented by the JSON + * @throws IOException on failure to read and parse the json file + */ + public static WorkflowValidator parse(String file) throws IOException { + + URL url = WorkflowValidator.class.getClassLoader().getResource(file); + String json = Resources.toString(url, Charsets.UTF_8); + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + json + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return parse(parser); + } + + /** + * Get the map of WorkflowStepValidators + * @return the map of WorkflowStepValidators + */ + public Map getWorkflowStepValidators() { + return this.workflowStepValidators; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 45cac92bf..21c782d26 100644 --- 
a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -30,9 +30,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; @@ -145,8 +143,11 @@ private void executeWorkflow(Workflow workflow, ActionListener workflowL // Attempt to topologically sort the workflow graph List processSequence = workflowProcessSorter.sortProcessNodes(workflow); - List> workflowFutureList = new ArrayList<>(); + // Validate the topologically sorted graph + workflowProcessSorter.validateGraph(processSequence); + + List> workflowFutureList = new ArrayList<>(); for (ProcessNode processNode : processSequence) { List predecessors = processNode.predecessors(); @@ -173,7 +174,7 @@ private void executeWorkflow(Workflow workflow, ActionListener workflowL } catch (IllegalArgumentException e) { workflowListener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST)); - } catch (CancellationException | CompletionException ex) { + } catch (Exception ex) { workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 2b2f7338d..52c942573 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -73,7 +73,7 @@ public CompletableFuture execute(List data) { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of("index-name", createIndexResponse.index()))); + future.complete(new WorkflowData(Map.of("index_name", createIndexResponse.index()))); } @Override @@ -89,7 +89,7 @@ public void onFailure(Exception e) { for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); - index = (String) content.get("index-name"); + index = (String) content.get("index_name"); type = (String) content.get("type"); if (index != null && type != null && settings != null) { break; diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 4770b94a9..2c4f20b9b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -127,7 +127,7 @@ public CompletableFuture execute(List data) { logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead - createIngestPipelineFuture.complete(new WorkflowData(Map.of("pipelineId", putPipelineRequest.getId()))); + createIngestPipelineFuture.complete(new WorkflowData(Map.of("pipeline_id", putPipelineRequest.getId()))); // TODO : Use node client to index response data to global context (pending global context index implementation) diff --git 
a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 2f902755c..4f1ddf3cb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -16,6 +16,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -30,6 +31,7 @@ public class ProcessNode { private final String id; private final WorkflowStep workflowStep; + private final Map previousNodeInputs; private final WorkflowData input; private final List predecessors; private final ThreadPool threadPool; @@ -42,6 +44,7 @@ public class ProcessNode { * * @param id A string identifying the workflow step * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. + * @param previousNodeInputs A map of expected inputs coming from predecessor nodes used in graph validation * @param input Input required by the node encoded in a {@link WorkflowData} instance. * @param predecessors Nodes preceding this one in the workflow * @param threadPool The OpenSearch thread pool @@ -50,6 +53,7 @@ public class ProcessNode { public ProcessNode( String id, WorkflowStep workflowStep, + Map previousNodeInputs, WorkflowData input, List predecessors, ThreadPool threadPool, @@ -57,6 +61,7 @@ public ProcessNode( ) { this.id = id; this.workflowStep = workflowStep; + this.previousNodeInputs = previousNodeInputs; this.input = input; this.predecessors = predecessors; this.threadPool = threadPool; @@ -79,6 +84,14 @@ public WorkflowStep workflowStep() { return workflowStep; } + /** + * Returns the node's expected predecessor node input + * @return the expected predecessor node inputs + */ + public Map previousNodeInputs() { + return previousNodeInputs; + } + /** * Returns the input data for this node. * @return the input data diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 71c44514e..342d8e98f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -14,6 +14,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.threadpool.ThreadPool; import java.util.ArrayDeque; @@ -28,9 +29,9 @@ import java.util.function.Function; import java.util.stream.Collectors; -import static org.opensearch.flowframework.model.WorkflowNode.INPUTS_FIELD; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; +import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD; /** * Converts a workflow of nodes and edges into a topologically sorted list of Process Nodes. 
@@ -65,7 +66,7 @@ public List sortProcessNodes(Workflow workflow) { Map idToNodeMap = new HashMap<>(); for (WorkflowNode node : sortedNodes) { WorkflowStep step = workflowStepFactory.createStep(node.type()); - WorkflowData data = new WorkflowData(node.inputs(), workflow.userParams()); + WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams()); List predecessorNodes = workflow.edges() .stream() .filter(e -> e.destination().equals(node.id())) @@ -74,7 +75,15 @@ public List sortProcessNodes(Workflow workflow) { .collect(Collectors.toList()); TimeValue nodeTimeout = parseTimeout(node); - ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, threadPool, nodeTimeout); + ProcessNode processNode = new ProcessNode( + node.id(), + step, + node.previousNodeInputs(), + data, + predecessorNodes, + threadPool, + nodeTimeout + ); idToNodeMap.put(processNode.id(), processNode); nodes.add(processNode); } @@ -82,9 +91,52 @@ public List sortProcessNodes(Workflow workflow) { return nodes; } + /** + * Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs + * @param processNodes A list of process nodes + * @throws Exception on validation failure + */ + public void validateGraph(List processNodes) throws Exception { + + WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json"); + + // Iterate through process nodes in graph + for (ProcessNode processNode : processNodes) { + + // Get predecessor nodes types of this processNode + List predecessorNodeTypes = processNode.predecessors() + .stream() + .map(x -> x.workflowStep().getName()) + .collect(Collectors.toList()); + + // Compile a list of outputs from the predecessor nodes based on type + List predecessorOutputs = new ArrayList<>(); + for (String nodeType : predecessorNodeTypes) { + List nodeTypeOutputs = validator.getWorkflowStepValidators().get(nodeType).getOutputs(); + predecessorOutputs.addAll(nodeTypeOutputs); + } + + // Retrieve all the user input data from this node + List currentNodeUserInputs = new ArrayList(processNode.input().getContent().keySet()); + + // Combine both predecessor outputs and current node user inputs + List allInputs = new ArrayList<>(); + allInputs.addAll(predecessorOutputs); + allInputs.addAll(currentNodeUserInputs); + + // Retrieve list of required inputs from the current process node and compare + List expectedInputs = validator.getWorkflowStepValidators().get(processNode.workflowStep().getName()).getInputs(); + if (!allInputs.containsAll(expectedInputs)) { + expectedInputs.removeAll(allInputs); + throw new IllegalArgumentException("Invalid graph, missing the following required inputs : " + expectedInputs.toString()); + } + } + + } + private TimeValue parseTimeout(WorkflowNode node) { - String timeoutValue = (String) node.inputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE); - String fieldName = String.join(".", node.id(), INPUTS_FIELD, NODE_TIMEOUT_FIELD); + String timeoutValue = (String) node.userInputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE); + String fieldName = String.join(".", node.id(), USER_INPUTS_FIELD, NODE_TIMEOUT_FIELD); TimeValue timeValue = TimeValue.parseTimeValue(timeoutValue, fieldName); if (timeValue.millis() < 0) { throw new IllegalArgumentException( diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json new file mode 100644 index 000000000..4ef3e17c3 --- 
/dev/null +++ b/src/main/resources/mappings/workflow-steps.json @@ -0,0 +1,58 @@ +{ + "create_index": { + "inputs":[ + "index_name", + "type" + ], + "outputs":[ + "index_name" + ] + }, + "create_ingest_pipeline": { + "inputs":[ + "id", + "description", + "type", + "model_id", + "input_field_name", + "output_field_name" + ], + "outputs":[ + "pipeline_id" + ] + }, + "create_connector": { + "inputs":[ + "name", + "description", + "version", + "protocol", + "parameters", + "credentials", + "actions" + ], + "outputs":[ + "connector_id" + ] + }, + "register_model": { + "inputs":[ + "function_name", + "name", + "description", + "connector_id" + ], + "outputs":[ + "model_id", + "register_model_status" + ] + }, + "deploy_model": { + "inputs":[ + "model_id" + ], + "outputs":[ + "deploy_model_status" + ] + } +} diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java index b38346b29..bd67454b1 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java @@ -44,7 +44,7 @@ public static String nodeWithTypeAndTimeout(String id, String type, String timeo + "\": \"" + type + "\", \"" - + WorkflowNode.INPUTS_FIELD + + WorkflowNode.USER_INPUTS_FIELD + "\": {\"" + WorkflowNode.NODE_TIMEOUT_FIELD + "\": \"" diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java index 695a31ca4..2bcfc9780 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java @@ -19,7 +19,7 @@ public class TemplateTests extends OpenSearchTestCase { private String expectedTemplate = "{\"name\":\"test\",\"description\":\"a test template\",\"use_case\":\"test use case\",\"version\":{\"template\":\"1.2.3\",\"compatibility\":[\"4.5.6\",\"7.8.9\"]}," - + "\"workflows\":{\"workflow\":{\"user_params\":{\"key\":\"value\"},\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}},{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}],\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}}}"; + + "\"workflows\":{\"workflow\":{\"user_params\":{\"key\":\"value\"},\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"user_inputs\":{\"foo\":\"bar\"}},{\"id\":\"B\",\"type\":\"b-type\",\"user_inputs\":{\"baz\":\"qux\"}}],\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}}}"; @Override public void setUp() throws Exception { @@ -29,8 +29,8 @@ public void setUp() throws Exception { public void testTemplate() throws IOException { Version templateVersion = Version.fromString("1.2.3"); List compatibilityVersion = List.of(Version.fromString("4.5.6"), Version.fromString("7.8.9")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java index 46d897b42..700e1d0d2 100644 --- 
a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -24,6 +24,7 @@ public void testNode() throws IOException { WorkflowNode nodeA = new WorkflowNode( "A", "a-type", + Map.of("foo", "field"), Map.ofEntries( Map.entry("foo", "a string"), Map.entry("bar", Map.of("key", "value")), @@ -33,7 +34,8 @@ public void testNode() throws IOException { ); assertEquals("A", nodeA.id()); assertEquals("a-type", nodeA.type()); - Map map = nodeA.inputs(); + assertEquals(Map.of("foo", "field"), nodeA.previousNodeInputs()); + Map map = nodeA.userInputs(); assertEquals("a string", (String) map.get("foo")); assertEquals(Map.of("key", "value"), (Map) map.get("bar")); assertArrayEquals(new Map[] { Map.of("A", "a"), Map.of("B", "b") }, (Map[]) map.get("baz")); @@ -43,14 +45,16 @@ public void testNode() throws IOException { assertEquals(Map.of("key2", "value2"), pp[0].params()); // node equality is based only on ID - WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of("bar", "baz")); + WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of(), Map.of("bar", "baz")); assertEquals(nodeA, nodeA2); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("A", "foo"), Map.of("baz", "qux")); assertNotEquals(nodeA, nodeB); String json = TemplateTestJsonUtil.parseToJson(nodeA); - assertTrue(json.startsWith("{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":")); + logger.info("TESTING : " + json); + assertTrue(json.startsWith("{\"id\":\"A\",\"type\":\"a-type\",\"previous_node_inputs\":{\"foo\":\"field\"},")); + assertTrue(json.contains("\"user_inputs\":{")); assertTrue(json.contains("\"foo\":\"a string\"")); assertTrue(json.contains("\"baz\":[{\"A\":\"a\"},{\"B\":\"b\"}]")); assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); @@ -59,7 +63,9 @@ public void testNode() throws IOException { WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); assertEquals("A", nodeX.id()); assertEquals("a-type", nodeX.type()); - Map mapX = nodeX.inputs(); + Map previousNodeInputs = nodeX.previousNodeInputs(); + assertEquals("field", previousNodeInputs.get("foo")); + Map mapX = nodeX.userInputs(); assertEquals("a string", mapX.get("foo")); assertEquals(Map.of("key", "value"), mapX.get("bar")); assertArrayEquals(new Map[] { Map.of("A", "a"), Map.of("B", "b") }, (Map[]) map.get("baz")); @@ -70,11 +76,11 @@ public void testNode() throws IOException { } public void testExceptions() throws IOException { - String badJson = "{\"badField\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}}"; + String badJson = "{\"badField\":\"A\",\"type\":\"a-type\",\"user_inputs\":{\"foo\":\"bar\"}}"; IOException e = assertThrows(IOException.class, () -> WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(badJson))); assertEquals("Unable to parse field [badField] in a node object.", e.getMessage()); - String missingJson = "{\"id\":\"A\",\"inputs\":{\"foo\":\"bar\"}}"; + String missingJson = "{\"id\":\"A\",\"user_inputs\":{\"foo\":\"bar\"}}"; e = assertThrows(IOException.class, () -> WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(missingJson))); assertEquals("An node object requires both an id and type field.", e.getMessage()); } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java index db070da4b..03b57aaac 100644 --- 
a/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java @@ -23,8 +23,8 @@ public void setUp() throws Exception { } public void testWorkflow() throws IOException { - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("A", "foo"), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); @@ -35,8 +35,8 @@ public void testWorkflow() throws IOException { assertEquals(List.of(edgeAB), workflow.edges()); String expectedJson = "{\"user_params\":{\"key\":\"value\"}," - + "\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}}," - + "{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}]," + + "\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"previous_node_inputs\":{},\"user_inputs\":{\"foo\":\"bar\"}}," + + "{\"id\":\"B\",\"type\":\"b-type\",\"previous_node_inputs\":{\"A\":\"foo\"},\"user_inputs\":{\"baz\":\"qux\"}}]," + "\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}"; String json = TemplateTestJsonUtil.parseToJson(workflow); assertEquals(expectedJson, json); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 141ea61b6..97c7921cb 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -43,8 +43,8 @@ public void setUp() throws Exception { Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index dc3840d44..c908592ff 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -52,8 +52,8 @@ public void setUp() throws Exception { Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); 
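For reference, the node format introduced by this commit replaces the single "inputs" object with separate "previous_node_inputs" and "user_inputs" objects. A minimal sketch of a two-node workflow in the new shape, reusing only the field values from the test fixtures above, would be:

    {
      "user_params": {"key": "value"},
      "nodes": [
        {"id": "A", "type": "a-type", "previous_node_inputs": {}, "user_inputs": {"foo": "bar"}},
        {"id": "B", "type": "b-type", "previous_node_inputs": {"A": "foo"}, "user_inputs": {"baz": "qux"}}
      ],
      "edges": [{"source": "A", "dest": "B"}]
    }

Here node B's "previous_node_inputs" indicates it expects the "foo" field produced by node A, while the new validateGraph step separately checks each node's combined predecessor outputs and user inputs against the expected inputs listed in mappings/workflow-steps.json.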
diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index d4f37261a..7b774c763 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -68,8 +68,8 @@ public void setUp() throws Exception { Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index 057088aac..f20cc3026 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -35,8 +35,8 @@ public void setUp() throws Exception { super.setUp(); Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 036714ba8..21209f65b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -74,7 +74,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index-name", "demo"), Map.entry("type", "knn"))); + inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn"))); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -103,7 +103,7 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio assertTrue(future.isDone() && !future.isCompletedExceptionally()); - Map outputData = Map.of("index-name", "demo"); + Map outputData = Map.of("index_name", "demo"); assertEquals(outputData, future.get().getContent()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java 
b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 9dab2a8d7..039b0384f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -54,7 +54,7 @@ public void setUp() throws Exception { ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipelineId", "pipelineId"))); + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId"))); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -109,7 +109,7 @@ public void testMissingData() throws InterruptedException { // Data with missing input and output fields WorkflowData incorrectData = new WorkflowData( Map.ofEntries( - Map.entry("id", "pipelineId"), + Map.entry("id", "pipeline_id"), Map.entry("description", "some description"), Map.entry("type", "text_embedding"), Map.entry("model_id", "model_id") diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 1e421c58c..0cac95b49 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -67,6 +67,7 @@ public String getName() { return "test"; } }, + Map.of(), new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar")), List.of(successfulNode), testThreadPool, @@ -103,7 +104,7 @@ public CompletableFuture execute(List data) { public String getName() { return "test"; } - }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(250)); + }, Map.of(), WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(250)); assertEquals("B", nodeB.id()); assertEquals("test", nodeB.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeB.input()); @@ -129,7 +130,7 @@ public CompletableFuture execute(List data) { public String getName() { return "sleepy"; } - }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(100)); + }, Map.of(), WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(100)); assertEquals("Zzz", nodeZ.id()); assertEquals("sleepy", nodeZ.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeZ.input()); @@ -156,7 +157,7 @@ public CompletableFuture execute(List data) { public String getName() { return "test"; } - }, WorkflowData.EMPTY, List.of(successfulNode, failedNode), testThreadPool, TimeValue.timeValueSeconds(15)); + }, Map.of(), WorkflowData.EMPTY, List.of(successfulNode, failedNode), testThreadPool, TimeValue.timeValueSeconds(15)); assertEquals("E", nodeE.id()); assertEquals("test", nodeE.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeE.input()); From e116c102549d2c920ce27d20591778b5ef93cb1a Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 27 Oct 2023 22:17:31 +0000 Subject: [PATCH 03/19] Adding tests Signed-off-by: Joshua Palis --- .../model/WorkflowValidator.java | 2 +- .../workflow/WorkflowStepFactory.java | 31 ++----- .../model/WorkflowStepValidatorTests.java | 46 ++++++++++ .../model/WorkflowValidatorTests.java | 87 +++++++++++++++++++ 4 files changed, 142 insertions(+), 24 deletions(-) create mode 100644 src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java create mode 100644 
src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java index b2f849d3b..af45ce77f 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java @@ -43,7 +43,7 @@ public WorkflowValidator(Map workflowStepValidato * @return an instance of the WorkflowValidator * @throws IOException if the content cannot be parsed correctly */ - private static WorkflowValidator parse(XContentParser parser) throws IOException { + public static WorkflowValidator parse(XContentParser parser) throws IOException { Map workflowStepValidators = new HashMap<>(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 5aabd679f..48e26e5a6 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -13,11 +13,7 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; - -import demo.DemoWorkflowStep; /** * Generates instances implementing {@link WorkflowStep}. @@ -44,25 +40,6 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); - - // TODO: These are from the demo class as placeholders, remove when demos are deleted - stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); - stepMap.put("demo_delay_5", new DemoWorkflowStep(5000)); - - // Use as a default until all the actual implementations are ready - stepMap.put("placeholder", new WorkflowStep() { - @Override - public CompletableFuture execute(List data) { - CompletableFuture future = new CompletableFuture<>(); - future.complete(WorkflowData.EMPTY); - return future; - } - - @Override - public String getName() { - return "placeholder"; - } - }); } /** @@ -78,4 +55,12 @@ public WorkflowStep createStep(String type) { // https://github.com/opensearch-project/opensearch-ai-flow-framework/pull/43 return stepMap.get("placeholder"); } + + /** + * Gets the step map + * @return the step map + */ + public Map getStepMap() { + return this.stepMap; + } } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java new file mode 100644 index 000000000..646e8f8af --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class WorkflowStepValidatorTests extends OpenSearchTestCase { + + private String validValidator; + private String invalidValidator; + + @Override + public void setUp() throws Exception { + super.setUp(); + validValidator = "{\"inputs\":[\"input_value\"],\"outputs\":[\"output_value\"]}"; + invalidValidator = "{\"inputs\":[\"input_value\"],\"invalid_field\":[\"output_value\"]}"; + } + + public void testParseWorkflowStepValidator() throws IOException { + XContentParser parser = TemplateTestJsonUtil.jsonToParser(validValidator); + WorkflowStepValidator workflowStepValidator = WorkflowStepValidator.parse(parser); + + assertEquals(1, workflowStepValidator.getInputs().size()); + assertEquals(1, workflowStepValidator.getOutputs().size()); + + assertEquals("input_value", workflowStepValidator.getInputs().get(0)); + assertEquals("output_value", workflowStepValidator.getOutputs().get(0)); + } + + public void testFailedParseWorkflowStepValidator() throws IOException { + XContentParser parser = TemplateTestJsonUtil.jsonToParser(invalidValidator); + IOException ex = expectThrows(IOException.class, () -> WorkflowStepValidator.parse(parser)); + assertEquals("Unable to parse field [invalid_field] in a WorkflowStepValidator object.", ex.getMessage()); + + } + +} diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java new file mode 100644 index 000000000..61d5238d4 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.model; + +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class WorkflowValidatorTests extends OpenSearchTestCase { + + private String validWorkflowStepJson; + private String invalidWorkflowStepJson; + + @Override + public void setUp() throws Exception { + super.setUp(); + validWorkflowStepJson = + "{\"workflow_step_1\":{\"inputs\":[\"input_1\",\"input_2\"],\"outputs\":[\"output_1\"]},\"workflow_step_2\":{\"inputs\":[\"input_1\",\"input_2\",\"input_3\"],\"outputs\":[\"output_1\",\"output_2\",\"output_3\"]}}"; + invalidWorkflowStepJson = + "{\"workflow_step_1\":{\"bad_field\":[\"input_1\",\"input_2\"],\"outputs\":[\"output_1\"]},\"workflow_step_2\":{\"inputs\":[\"input_1\",\"input_2\",\"input_3\"],\"outputs\":[\"output_1\",\"output_2\",\"output_3\"]}}"; + } + + public void testParseWorkflowValidator() throws IOException { + + XContentParser parser = TemplateTestJsonUtil.jsonToParser(validWorkflowStepJson); + WorkflowValidator validator = WorkflowValidator.parse(parser); + + assertEquals(2, validator.getWorkflowStepValidators().size()); + assertTrue(validator.getWorkflowStepValidators().keySet().contains("workflow_step_1")); + assertEquals(2, validator.getWorkflowStepValidators().get("workflow_step_1").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("workflow_step_1").getOutputs().size()); + assertTrue(validator.getWorkflowStepValidators().keySet().contains("workflow_step_2")); + assertEquals(3, validator.getWorkflowStepValidators().get("workflow_step_2").getInputs().size()); + assertEquals(3, validator.getWorkflowStepValidators().get("workflow_step_2").getOutputs().size()); + } + + public void testFailedParseWorkflowValidator() throws IOException { + XContentParser parser = TemplateTestJsonUtil.jsonToParser(invalidWorkflowStepJson); + IOException ex = expectThrows(IOException.class, () -> WorkflowValidator.parse(parser)); + assertEquals("Unable to parse field [bad_field] in a WorkflowStepValidator object.", ex.getMessage()); + } + + public void testWorkflowStepFactoryHasValidators() throws IOException { + + ClusterService clusterService = mock(ClusterService.class); + ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class); + AdminClient adminClient = mock(AdminClient.class); + Client client = mock(Client.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); + + // Read in workflow-steps.json + WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json"); + + // Get all workflow step validator types + List registeredWorkflowValidatorTypes = new ArrayList(workflowValidator.getWorkflowStepValidators().keySet()); + + // Get all registered workflow step types in the workflow step factory + List 
registeredWorkflowStepTypes = new ArrayList(workflowStepFactory.getStepMap().keySet()); + + // Check if each registered step has a corresponding validator definition + assertEquals(registeredWorkflowValidatorTypes, registeredWorkflowStepTypes); + + } + +} From becc5107a64d0cd44e8fb58fd9d5e3a09739807f Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 27 Oct 2023 23:35:46 +0000 Subject: [PATCH 04/19] Adding validate graph test Signed-off-by: Joshua Palis --- .../workflow/WorkflowProcessSorterTests.java | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index f728dd7b1..0a6815b82 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -14,6 +14,8 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; @@ -24,6 +26,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -199,4 +202,31 @@ public void testExceptions() throws IOException { ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "C"))))); assertEquals("Edge destination C does not correspond to a node.", ex.getMessage()); } + + public void testValidateGraph() { + + // Create Register Model workflow node with missing connector_id field + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_1", + RegisterModelStep.NAME, + Map.of(), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_2", + RegisterModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", "model_id")), + Map.of() + ); + WorkflowEdge edge = new WorkflowEdge(registerModel.id(), deployModel.id()); + Workflow workflow = new Workflow(Map.of(), List.of(registerModel, deployModel), List.of(edge)); + + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> workflowProcessSorter.validateGraph(sortedProcessNodes) + ); + assertEquals("Invalid graph, missing the following required inputs : [connector_id]", ex.getMessage()); + + } } From ee8f3cb87bfdbfbcd9b2795e46172e19a0e16aec Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 31 Oct 2023 00:03:39 +0000 Subject: [PATCH 05/19] Addressing PR comments, moving sorting/validating prior to executing async, adding success test case for graph validation Signed-off-by: Joshua Palis --- .../ProvisionWorkflowTransportAction.java | 37 +++++++++-------- .../workflow/WorkflowProcessSorter.java | 17 ++++---- .../workflow/WorkflowProcessSorterTests.java | 41 ++++++++++++++++++- 3 files changed, 69 insertions(+), 26 deletions(-) diff --git 
a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 21c782d26..c0bd9c592 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -97,12 +97,23 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow); + workflowProcessSorter.validateGraph(provisionProcessSequence); + // Respond to rest action then execute provisioning workflow async listener.onResponse(new WorkflowResponse(workflowId)); - executeWorkflowAsync(workflowId, template.workflows().get(PROVISION_WORKFLOW)); + executeWorkflowAsync(workflowId, provisionProcessSequence); + }, exception -> { - logger.error("Failed to retrieve template from global context.", exception); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + if (exception instanceof IllegalArgumentException) { + logger.error("Workflow validation failed for workflow : " + workflowId); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + } else { + logger.error("Failed to retrieve template from global context.", exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } })); } catch (Exception e) { logger.error("Failed to retrieve template from global context.", e); @@ -113,9 +124,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener workflowSequence) { // TODO : Update Action listener type to State index Request ActionListener provisionWorkflowListener = ActionListener.wrap(response -> { logger.info("Provisioning completed successuflly for workflow {}", workflowId); @@ -127,28 +138,22 @@ private void executeWorkflowAsync(String workflowId, Workflow workflow) { // TODO : Create State index request to update STATE entry status to FAILED }); try { - threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflow, provisionWorkflowListener); }); + threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, provisionWorkflowListener); }); } catch (Exception exception) { provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } } /** - * Topologically sorts a given workflow into a sequence of ProcessNodes and executes the workflow - * @param workflow The workflow to execute + * Executes the given workflow sequence + * @param workflowSequence The topologically sorted workflow to execute * @param workflowListener The listener that updates the status of a workflow execution */ - private void executeWorkflow(Workflow workflow, ActionListener workflowListener) { + private void executeWorkflow(List workflowSequence, ActionListener workflowListener) { try { - // Attempt to topologically sort the workflow graph - List processSequence = workflowProcessSorter.sortProcessNodes(workflow); - - // Validate the topologically sorted graph - workflowProcessSorter.validateGraph(processSequence); - List> workflowFutureList = new ArrayList<>(); - for (ProcessNode processNode : processSequence) { + for (ProcessNode processNode : workflowSequence) { List predecessors = processNode.predecessors(); 
logger.info( diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 342d8e98f..233539ef5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -19,6 +19,7 @@ import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -28,6 +29,7 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; @@ -110,22 +112,21 @@ public void validateGraph(List processNodes) throws Exception { .collect(Collectors.toList()); // Compile a list of outputs from the predecessor nodes based on type - List predecessorOutputs = new ArrayList<>(); - for (String nodeType : predecessorNodeTypes) { - List nodeTypeOutputs = validator.getWorkflowStepValidators().get(nodeType).getOutputs(); - predecessorOutputs.addAll(nodeTypeOutputs); - } + List predecessorOutputs = predecessorNodeTypes.stream() + .map(nodeType -> validator.getWorkflowStepValidators().get(nodeType).getOutputs()) + .flatMap(Collection::stream) + .collect(Collectors.toList()); // Retrieve all the user input data from this node List currentNodeUserInputs = new ArrayList(processNode.input().getContent().keySet()); // Combine both predecessor outputs and current node user inputs - List allInputs = new ArrayList<>(); - allInputs.addAll(predecessorOutputs); - allInputs.addAll(currentNodeUserInputs); + List allInputs = Stream.concat(predecessorOutputs.stream(), currentNodeUserInputs.stream()) + .collect(Collectors.toList()); // Retrieve list of required inputs from the current process node and compare List expectedInputs = validator.getWorkflowStepValidators().get(processNode.workflowStep().getName()).getInputs(); + if (!allInputs.containsAll(expectedInputs)) { expectedInputs.removeAll(allInputs); throw new IllegalArgumentException("Invalid graph, missing the following required inputs : " + expectedInputs.toString()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 0a6815b82..f06a14962 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -203,7 +203,44 @@ public void testExceptions() throws IOException { assertEquals("Edge destination C does not correspond to a node.", ex.getMessage()); } - public void testValidateGraph() { + public void testSuccessfulGraphValidation() throws Exception { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credentials", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + RegisterModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + 
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + DeployModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); + workflowProcessSorter.validateGraph(sortedProcessNodes); + } + + public void testFailedGraphValidation() { // Create Register Model workflow node with missing connector_id field WorkflowNode registerModel = new WorkflowNode( @@ -214,7 +251,7 @@ public void testValidateGraph() { ); WorkflowNode deployModel = new WorkflowNode( "workflow_step_2", - RegisterModelStep.NAME, + DeployModelStep.NAME, Map.ofEntries(Map.entry("workflow_step_1", "model_id")), Map.of() ); From 67e993f265fb81cfe31de9a5b6f83d27557425dd Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 31 Oct 2023 00:19:13 +0000 Subject: [PATCH 06/19] Adding javadocs Signed-off-by: Joshua Palis --- .../opensearch/flowframework/indices/FlowFrameworkIndex.java | 3 +++ .../opensearch/flowframework/model/ProvisioningProgress.java | 3 +++ src/main/java/org/opensearch/flowframework/model/State.java | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java index e23b9ddf0..4b005e45d 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java @@ -29,6 +29,9 @@ public enum FlowFrameworkIndex { ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getGlobalContextMappings), GLOBAL_CONTEXT_INDEX_VERSION ), + /** + * Workflow State Index + */ WORKFLOW_STATE( WORKFLOW_STATE_INDEX, ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getWorkflowStateMappings), diff --git a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java index 1aefecb4b..d5a2a5734 100644 --- a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java +++ b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java @@ -13,7 +13,10 @@ */ // TODO: transfer this to more detailed array for each step public enum ProvisioningProgress { + /** Not Started State */ NOT_STARTED, + /** In Progress State */ IN_PROGRESS, + /** Done State */ DONE } diff --git a/src/main/java/org/opensearch/flowframework/model/State.java b/src/main/java/org/opensearch/flowframework/model/State.java index 3288ed4ab..bb9540c52 100644 --- a/src/main/java/org/opensearch/flowframework/model/State.java +++ b/src/main/java/org/opensearch/flowframework/model/State.java @@ -12,8 +12,12 @@ * Enum relating to the state of a workflow */ public enum State { + /** Not Started state */ NOT_STARTED, + /** Provisioning state */ PROVISIONING, + /** Failed state */ FAILED, + /** Completed state */ COMPLETED } From ddf946f74a087c776574ef3db3b76de638b0d56f Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 31 Oct 2023 00:31:40 +0000 Subject: 
[PATCH 07/19] Moving validation prior to updating workflow state to provisioning Signed-off-by: Joshua Palis --- .../transport/ProvisionWorkflowTransportAction.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 550f2cfd2..22ac414e5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -107,6 +107,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow); + workflowProcessSorter.validateGraph(provisionProcessSequence); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( WORKFLOW_STATE_INDEX, workflowId, @@ -123,11 +128,6 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) ); - // Sort and validate graph - Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW); - List provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow); - workflowProcessSorter.validateGraph(provisionProcessSequence); - // Respond to rest action then execute provisioning workflow async listener.onResponse(new WorkflowResponse(workflowId)); executeWorkflowAsync(workflowId, provisionProcessSequence); From 63524faa00f89f0c516cd86df1ca21a9c44a338b Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 31 Oct 2023 17:13:46 +0000 Subject: [PATCH 08/19] Addressing PR comments Part 1 Signed-off-by: Joshua Palis --- .../model/WorkflowStepValidator.java | 4 ++-- .../flowframework/model/WorkflowValidator.java | 15 +++------------ .../workflow/WorkflowProcessSorter.java | 4 +++- .../model/WorkflowValidatorTests.java | 4 ++-- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java index e0c21ce0e..e49d7d68a 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java @@ -78,7 +78,7 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept * @return the inputs */ public List getInputs() { - return inputs; + return List.copyOf(inputs); } /** @@ -86,6 +86,6 @@ public List getInputs() { * @return the outputs */ public List getOutputs() { - return outputs; + return List.copyOf(outputs); } } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java index af45ce77f..506b73ab8 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java @@ -10,10 +10,8 @@ import com.google.common.base.Charsets; import com.google.common.io.Resources; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import 
org.opensearch.flowframework.util.ParseUtils; import java.io.IOException; import java.net.URL; @@ -64,16 +62,9 @@ public static WorkflowValidator parse(XContentParser parser) throws IOException * @throws IOException on failure to read and parse the json file */ public static WorkflowValidator parse(String file) throws IOException { - URL url = WorkflowValidator.class.getClassLoader().getResource(file); String json = Resources.toString(url, Charsets.UTF_8); - XContentParser parser = JsonXContent.jsonXContent.createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - json - ); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - return parse(parser); + return parse(ParseUtils.jsonToParser(json)); } /** @@ -81,7 +72,7 @@ public static WorkflowValidator parse(String file) throws IOException { * @return the map of WorkflowStepValidators */ public Map getWorkflowStepValidators() { - return this.workflowStepValidators; + return Map.copyOf(this.workflowStepValidators); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 233539ef5..745de5921 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -125,7 +125,9 @@ public void validateGraph(List processNodes) throws Exception { .collect(Collectors.toList()); // Retrieve list of required inputs from the current process node and compare - List expectedInputs = validator.getWorkflowStepValidators().get(processNode.workflowStep().getName()).getInputs(); + List expectedInputs = new ArrayList( + validator.getWorkflowStepValidators().get(processNode.workflowStep().getName()).getInputs() + ); if (!allInputs.containsAll(expectedInputs)) { expectedInputs.removeAll(allInputs); diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 61d5238d4..6c474a11e 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -80,8 +80,8 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { List registeredWorkflowStepTypes = new ArrayList(workflowStepFactory.getStepMap().keySet()); // Check if each registered step has a corresponding validator definition - assertEquals(registeredWorkflowValidatorTypes, registeredWorkflowStepTypes); - + assertTrue(registeredWorkflowStepTypes.containsAll(registeredWorkflowValidatorTypes)); + assertTrue(registeredWorkflowValidatorTypes.containsAll(registeredWorkflowStepTypes)); } } From 56865d208cfa695f75e3135bb0b1878fdfa788ba Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 31 Oct 2023 17:32:09 +0000 Subject: [PATCH 09/19] Addressing PR comments Part 2 : Moving field names to common value class and using constants Signed-off-by: Joshua Palis --- .../flowframework/common/CommonValue.java | 16 +++++++ .../workflow/CreateIndexStep.java | 9 ++-- .../workflow/CreateIngestPipelineStep.java | 48 +++++++++---------- 3 files changed, 45 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 32acc9a68..ecce8ec50 100644 --- 
a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -61,6 +61,22 @@ private CommonValue() {} /** The provision workflow thread pool name */ public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision"; + /** Index name field */ + public static final String INDEX_NAME = "index_name"; + /** Type field */ + public static final String TYPE = "type"; + /** ID Field */ + public static final String ID = "id"; + /** Pipeline Id field */ + public static final String PIPELINE_ID = "pipeline_id"; + /** Processors field */ + public static final String PROCESSORS = "processors"; + /** Field map field */ + public static final String FIELD_MAP = "field_map"; + /** Input Field Name field */ + public static final String INPUT_FIELD_NAME = "input_field_name"; + /** Output Field Name field */ + public static final String OUTPUT_FIELD_NAME = "output_field_name"; /** Model Id field */ public static final String MODEL_ID = "model_id"; /** Function Name field */ diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index b1ecb0321..5fe47b2b0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -25,6 +25,9 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; +import static org.opensearch.flowframework.common.CommonValue.INDEX_NAME; +import static org.opensearch.flowframework.common.CommonValue.TYPE; + /** * Step to create an index */ @@ -58,7 +61,7 @@ public CompletableFuture execute(List data) { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of("index_name", createIndexResponse.index()))); + future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()))); } @Override @@ -74,8 +77,8 @@ public void onFailure(Exception e) { for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); - index = (String) content.get("index_name"); - type = (String) content.get("type"); + index = (String) content.get(INDEX_NAME); + type = (String) content.get(TYPE); if (index != null && type != null && settings != null) { break; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 2c4f20b9b..b8cc83651 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -26,6 +26,16 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.FIELD_MAP; +import static org.opensearch.flowframework.common.CommonValue.ID; +import static org.opensearch.flowframework.common.CommonValue.INPUT_FIELD_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.OUTPUT_FIELD_NAME; +import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; +import static 
org.opensearch.flowframework.common.CommonValue.PROCESSORS; +import static org.opensearch.flowframework.common.CommonValue.TYPE; + /** * Workflow step to create an ingest pipeline */ @@ -36,18 +46,6 @@ public class CreateIngestPipelineStep implements WorkflowStep { /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ static final String NAME = "create_ingest_pipeline"; - // Common pipeline configuration fields - private static final String PIPELINE_ID_FIELD = "id"; - private static final String DESCRIPTION_FIELD = "description"; - private static final String PROCESSORS_FIELD = "processors"; - private static final String TYPE_FIELD = "type"; - - // Temporary text embedding processor fields - private static final String FIELD_MAP = "field_map"; - private static final String MODEL_ID_FIELD = "model_id"; - private static final String INPUT_FIELD = "input_field_name"; - private static final String OUTPUT_FIELD = "output_field_name"; - // Client to store a pipeline in the cluster state private final ClusterAdminClient clusterAdminClient; @@ -80,23 +78,23 @@ public CompletableFuture execute(List data) { for (Entry entry : content.entrySet()) { switch (entry.getKey()) { - case PIPELINE_ID_FIELD: - pipelineId = (String) content.get(PIPELINE_ID_FIELD); + case ID: + pipelineId = (String) content.get(ID); break; case DESCRIPTION_FIELD: description = (String) content.get(DESCRIPTION_FIELD); break; - case TYPE_FIELD: - type = (String) content.get(TYPE_FIELD); + case TYPE: + type = (String) content.get(TYPE); break; - case MODEL_ID_FIELD: - modelId = (String) content.get(MODEL_ID_FIELD); + case MODEL_ID: + modelId = (String) content.get(MODEL_ID); break; - case INPUT_FIELD: - inputFieldName = (String) content.get(INPUT_FIELD); + case INPUT_FIELD_NAME: + inputFieldName = (String) content.get(INPUT_FIELD_NAME); break; - case OUTPUT_FIELD: - outputFieldName = (String) content.get(OUTPUT_FIELD); + case OUTPUT_FIELD_NAME: + outputFieldName = (String) content.get(OUTPUT_FIELD_NAME); break; default: break; @@ -127,7 +125,7 @@ public CompletableFuture execute(List data) { logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead - createIngestPipelineFuture.complete(new WorkflowData(Map.of("pipeline_id", putPipelineRequest.getId()))); + createIngestPipelineFuture.complete(new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()))); // TODO : Use node client to index response data to global context (pending global context index implementation) @@ -178,10 +176,10 @@ private XContentBuilder buildIngestPipelineRequestContent( return XContentFactory.jsonBuilder() .startObject() .field(DESCRIPTION_FIELD, description) - .startArray(PROCESSORS_FIELD) + .startArray(PROCESSORS) .startObject() .startObject(type) - .field(MODEL_ID_FIELD, modelId) + .field(MODEL_ID, modelId) .startObject(FIELD_MAP) .field(inputFieldName, outputFieldName) .endObject() From eb428ded27e739ebc43b0d23e1d56c30c03f9ef6 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 31 Oct 2023 19:32:48 +0000 Subject: [PATCH 10/19] Adding definition for noop workflow step Signed-off-by: Joshua Palis --- src/main/resources/mappings/workflow-steps.json | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index f836067b2..23eb81c00 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ 
b/src/main/resources/mappings/workflow-steps.json @@ -1,4 +1,8 @@ { + "noop": { + "inputs":[], + "outputs":[] + }, "create_index": { "inputs":[ "index_name", From 8ba9d0b269ad65779d4c27e72a9ef924b5a0550c Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 31 Oct 2023 20:01:57 +0000 Subject: [PATCH 11/19] Addressing PR comments Part 3 Signed-off-by: Joshua Palis --- .../opensearch/flowframework/workflow/WorkflowStepFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index ece942a04..c30bdf87c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -63,6 +63,6 @@ public WorkflowStep createStep(String type) { * @return the step map */ public Map getStepMap() { - return this.stepMap; + return Map.copyOf(this.stepMap); } } From 81636551ab916b015128d2653d3d34952b7b0398 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 31 Oct 2023 22:18:01 +0000 Subject: [PATCH 12/19] Modifies rest actions to throw flow framework exceptions, transport actions to create flow framework exceptions Signed-off-by: Joshua Palis --- .../exception/FlowFrameworkException.java | 11 ++++- .../indices/FlowFrameworkIndicesHandler.java | 29 ++++++----- .../flowframework/model/Template.java | 2 +- .../rest/RestCreateWorkflowAction.java | 37 +++++++++++--- .../rest/RestProvisionWorkflowAction.java | 49 +++++++++++++------ .../CreateWorkflowTransportAction.java | 26 ++++++++-- .../ProvisionWorkflowTransportAction.java | 4 +- .../workflow/WorkflowProcessSorter.java | 22 ++++++--- .../rest/RestCreateWorkflowActionTests.java | 12 +++-- .../RestProvisionWorkflowActionTests.java | 31 ++++++------ .../workflow/WorkflowProcessSorterTests.java | 27 ++++++---- 11 files changed, 173 insertions(+), 77 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java b/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java index f3cb55950..52dad9f44 100644 --- a/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java +++ b/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java @@ -9,11 +9,15 @@ package org.opensearch.flowframework.exception; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; /** * Representation of Flow Framework Exceptions */ -public class FlowFrameworkException extends RuntimeException { +public class FlowFrameworkException extends RuntimeException implements ToXContentObject { private static final long serialVersionUID = 1L; @@ -60,4 +64,9 @@ public FlowFrameworkException(String message, Throwable cause, RestStatus restSt public RestStatus getRestStatus() { return restStatus; } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field("error", "Request failed with exception: [" + this.getMessage() + "]").endObject(); + } } diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 04a3fac5b..6456885de 100644 --- 
a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -29,6 +29,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -148,7 +149,7 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } }, e -> { logger.error("Failed to create index " + indexName, e); - internalListener.onFailure(e); + internalListener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); }); CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping).settings(indexSettings); client.admin().indices().create(request, actionListener); @@ -182,7 +183,9 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } }, exception -> { logger.error("Failed to update index setting for: " + indexName, exception); - internalListener.onFailure(exception); + internalListener.onFailure( + new FlowFrameworkException(exception.getMessage(), INTERNAL_SERVER_ERROR) + ); })); } else { internalListener.onFailure( @@ -191,7 +194,9 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } }, exception -> { logger.error("Failed to update index " + indexName, exception); - internalListener.onFailure(exception); + internalListener.onFailure( + new FlowFrameworkException(exception.getMessage(), INTERNAL_SERVER_ERROR) + ); }) ); } else { @@ -201,7 +206,7 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } }, e -> { logger.error("Failed to update index mapping", e); - internalListener.onFailure(e); + internalListener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); })); } else { // No need to update index if it's already updated. 
@@ -210,7 +215,7 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } } catch (Exception e) { logger.error("Failed to init index " + indexName, e); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); } } @@ -273,7 +278,7 @@ public void putTemplateToGlobalContext(Template template, ActionListener context.restore())); } catch (Exception e) { logger.error("Failed to index global_context index"); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); } }, e -> { logger.error("Failed to create global_context index", e); @@ -311,12 +316,12 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to put state index document", e); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); } }, e -> { logger.error("Failed to create global_context index", e); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); })); } @@ -332,7 +337,7 @@ public void updateTemplateInGlobalContext(String documentId, Template template, + documentId + ", global_context index does not exist."; logger.error(exceptionMessage); - listener.onFailure(new Exception(exceptionMessage)); + listener.onFailure(new FlowFrameworkException(exceptionMessage, RestStatus.BAD_REQUEST)); } else { IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId); try ( @@ -344,7 +349,7 @@ public void updateTemplateInGlobalContext(String documentId, Template template, client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage()); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); } } } @@ -365,7 +370,7 @@ public void updateFlowFrameworkSystemIndexDoc( if (!doesIndexExist(indexName)) { String exceptionMessage = "Failed to update document for given workflow due to missing " + indexName + " index"; logger.error(exceptionMessage); - listener.onFailure(new Exception(exceptionMessage)); + listener.onFailure(new FlowFrameworkException(exceptionMessage, RestStatus.BAD_REQUEST)); } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { UpdateRequest updateRequest = new UpdateRequest(indexName, documentId); @@ -377,7 +382,7 @@ public void updateFlowFrameworkSystemIndexDoc( client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to update {} entry : {}. 
{}", indexName, documentId, e.getMessage()); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); } } } diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index a05c374d8..fbafb8fd0 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -177,7 +177,7 @@ public static Template parse(XContentParser parser) throws IOException { } } if (name == null) { - throw new IOException("An template object requires a name."); + throw new IOException("A template object requires a name."); } return new Template(name, description, useCase, templateVersion, compatibilityVersion, workflows, user); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index ace440f75..4717adfc6 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -9,13 +9,20 @@ package org.opensearch.flowframework.rest; import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.action.RestToXContentListener; import java.io.IOException; import java.util.List; @@ -29,6 +36,7 @@ */ public class RestCreateWorkflowAction extends BaseRestHandler { + private static final Logger logger = LogManager.getLogger(RestCreateWorkflowAction.class); private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action"; /** @@ -53,11 +61,28 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - - String workflowId = request.param(WORKFLOW_ID); - Template template = Template.parse(request.content().utf8ToString()); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template); - return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); + try { + String workflowId = request.param(WORKFLOW_ID); + Template template = Template.parse(request.content().utf8ToString()); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template); + return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.CREATED, builder)); + }, exception -> { + try { + FlowFrameworkException ex = (FlowFrameworkException) exception; + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), 
ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { + logger.error("Failed to send back create workflow exception", e); + } + })); + } catch (Exception e) { + FlowFrameworkException ex = new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST); + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } } } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 89471ee00..1bd07eaf0 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -9,14 +9,19 @@ package org.opensearch.flowframework.rest; import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.action.RestToXContentListener; import java.io.IOException; import java.util.List; @@ -30,6 +35,8 @@ */ public class RestProvisionWorkflowAction extends BaseRestHandler { + private static final Logger logger = LogManager.getLogger(RestProvisionWorkflowAction.class); + private static final String PROVISION_WORKFLOW_ACTION = "provision_workflow_action"; /** @@ -52,21 +59,35 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - - // Validate content - if (request.hasContent()) { - throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); - } - - // Validate params String workflowId = request.param(WORKFLOW_ID); - if (workflowId == null) { - throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + try { + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + } + // Validate params + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + // Create request and provision + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = (FlowFrameworkException) exception; + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { + 
logger.error("Failed to send back provision workflow exception", e); + } + })); + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); } - - // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); - return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index c0baccc21..232c2f126 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -83,12 +83,21 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.error("Failed to save workflow state : {}", exception.getMessage()); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + } }) ); }, exception -> { logger.error("Failed to save use case template : {}", exception.getMessage()); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } + })); } else { // Update existing entry, full document replacement @@ -105,12 +114,21 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.error("Failed to update workflow state : {}", exception.getMessage()); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } }) ); }, exception -> { logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), exception.getMessage()); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } + }) ); } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 22ac414e5..91beefdb0 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -133,9 +133,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - if (exception instanceof IllegalArgumentException) { + if (exception instanceof FlowFrameworkException) { logger.error("Workflow validation failed for workflow : " + workflowId); - 
listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + listener.onFailure(exception); } else { logger.error("Failed to retrieve template from global context.", exception); listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 745de5921..10a038cbb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -11,6 +11,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; @@ -131,7 +133,10 @@ public void validateGraph(List processNodes) throws Exception { if (!allInputs.containsAll(expectedInputs)) { expectedInputs.removeAll(allInputs); - throw new IllegalArgumentException("Invalid graph, missing the following required inputs : " + expectedInputs.toString()); + throw new FlowFrameworkException( + "Invalid graph, missing the following required inputs : " + expectedInputs.toString(), + RestStatus.BAD_REQUEST + ); } } @@ -142,8 +147,9 @@ private TimeValue parseTimeout(WorkflowNode node) { String fieldName = String.join(".", node.id(), USER_INPUTS_FIELD, NODE_TIMEOUT_FIELD); TimeValue timeValue = TimeValue.parseTimeValue(timeoutValue, fieldName); if (timeValue.millis() < 0) { - throw new IllegalArgumentException( - "Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. Must be positive" + throw new FlowFrameworkException( + "Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. 
Must be positive", + RestStatus.BAD_REQUEST ); } return timeValue; @@ -155,14 +161,14 @@ private static List topologicalSort(List workflowNod for (WorkflowEdge edge : workflowEdges) { String source = edge.source(); if (!nodeIds.contains(source)) { - throw new IllegalArgumentException("Edge source " + source + " does not correspond to a node."); + throw new FlowFrameworkException("Edge source " + source + " does not correspond to a node.", RestStatus.BAD_REQUEST); } String dest = edge.destination(); if (!nodeIds.contains(dest)) { - throw new IllegalArgumentException("Edge destination " + dest + " does not correspond to a node."); + throw new FlowFrameworkException("Edge destination " + dest + " does not correspond to a node.", RestStatus.BAD_REQUEST); } if (source.equals(dest)) { - throw new IllegalArgumentException("Edge connects node " + source + " to itself."); + throw new FlowFrameworkException("Edge connects node " + source + " to itself.", RestStatus.BAD_REQUEST); } } @@ -185,7 +191,7 @@ private static List topologicalSort(List workflowNod Queue sourceNodes = new ArrayDeque<>(); workflowNodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n)); if (sourceNodes.isEmpty()) { - throw new IllegalArgumentException("No start node detected: all nodes have a predecessor."); + throw new FlowFrameworkException("No start node detected: all nodes have a predecessor.", RestStatus.BAD_REQUEST); } logger.debug("Start node(s): {}", sourceNodes); @@ -208,7 +214,7 @@ private static List topologicalSort(List workflowNod } } if (!graph.isEmpty()) { - throw new IllegalArgumentException("Cycle detected: " + graph); + throw new FlowFrameworkException("Cycle detected: " + graph, RestStatus.BAD_REQUEST); } logger.debug("Execution sequence: {}", sortedNodes); return sortedNodes; diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index d897c6756..3daaa4536 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -11,6 +11,7 @@ import org.opensearch.Version; import org.opensearch.client.node.NodeClient; import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.model.Template; @@ -20,9 +21,9 @@ import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; import org.opensearch.test.rest.FakeRestRequest; -import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.Map; @@ -84,13 +85,14 @@ public void testRestCreateWorkflowActionRoutes() { } - public void testInvalidCreateWorkflowRequest() throws IOException { + public void testInvalidCreateWorkflowRequest() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) .withContent(new BytesArray(invalidTemplate), MediaTypeRegistry.JSON) .build(); - - IOException ex = expectThrows(IOException.class, () -> { createWorkflowRestAction.prepareRequest(request, nodeClient); }); - assertEquals("Unable to parse field [invalid] in a template object.", 
ex.getMessage()); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("Unable to parse field [invalid] in a template object.")); } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index a44817cec..4d9ef22e4 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -12,13 +12,12 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; import org.opensearch.test.rest.FakeRestRequest; -import java.io.IOException; import java.util.List; import java.util.Locale; @@ -51,31 +50,35 @@ public void testRestProvisiionWorkflowActionRoutes() { assertEquals(this.provisionWorkflowPath, routes.get(0).getPath()); } - public void testNullWorkflowIdAndTemplate() throws IOException { + public void testNullWorkflowId() throws Exception { - // Request with no content or params + // Request with no params RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.provisionWorkflowPath) .build(); - FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> { - provisionWorkflowRestAction.prepareRequest(request, nodeClient); - }); - assertEquals("workflow_id cannot be null", ex.getMessage()); - assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); } - public void testInvalidRequestWithContent() throws IOException { + public void testInvalidRequestWithContent() { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.provisionWorkflowPath) .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) .build(); - FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> { - provisionWorkflowRestAction.prepareRequest(request, nodeClient); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); }); - assertEquals("Invalid request format", ex.getMessage()); - assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + assertEquals( + "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_provision] does not support having a body", + ex.getMessage() + ); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java 
b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 9f629ff9e..5ac4af17e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -141,31 +141,35 @@ public void testOrdering() throws IOException { public void testCycles() { Exception ex; - ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A")), List.of(edge("A", "A"))))); + ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A")), List.of(edge("A", "A"))))); assertEquals("Edge connects node A to itself.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "B")))) ); assertEquals("Edge connects node B to itself.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "A")))) ); assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("A", "B"), edge("B", "C"), edge("C", "B")))) ); assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); assertTrue(ex.getMessage().contains("B->C")); assertTrue(ex.getMessage().contains("C->B")); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse( workflow( List.of(node("A"), node("B"), node("C"), node("D")), @@ -177,6 +181,7 @@ public void testCycles() { assertTrue(ex.getMessage().contains("B->C")); assertTrue(ex.getMessage().contains("C->D")); assertTrue(ex.getMessage().contains("D->B")); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); } public void testNoEdges() throws IOException { @@ -196,13 +201,15 @@ public void testNoEdges() throws IOException { public void testExceptions() throws IOException { Exception ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("C", "B")))) ); assertEquals("Edge source C does not correspond to a node.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); - ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "C"))))); + ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "C"))))); assertEquals("Edge destination C does not correspond to a node.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( FlowFrameworkException.class, @@ -268,11 +275,11 @@ public void testFailedGraphValidation() { Workflow workflow = new Workflow(Map.of(), List.of(registerModel, deployModel), List.of(edge)); List sortedProcessNodes = 
workflowProcessSorter.sortProcessNodes(workflow); - IllegalArgumentException ex = expectThrows( - IllegalArgumentException.class, + FlowFrameworkException ex = expectThrows( + FlowFrameworkException.class, () -> workflowProcessSorter.validateGraph(sortedProcessNodes) ); assertEquals("Invalid graph, missing the following required inputs : [connector_id]", ex.getMessage()); - + assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); } } From 98ab9801b479a16a24b4334994a59fef06ca0c84 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Wed, 1 Nov 2023 20:12:48 +0000 Subject: [PATCH 13/19] Fixing credentials field in workflow-step json Signed-off-by: Joshua Palis --- src/main/resources/mappings/workflow-steps.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 23eb81c00..241a8ecbc 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -32,7 +32,7 @@ "version", "protocol", "parameters", - "credentials", + "credential", "actions" ], "outputs":[ From b3336b2bc35df4397317015a3128675c76e43b3b Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Wed, 1 Nov 2023 21:34:48 +0000 Subject: [PATCH 14/19] Fixing test Signed-off-by: Joshua Palis --- .../flowframework/workflow/WorkflowProcessSorterTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 5ac4af17e..65fccbb7e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -230,7 +230,7 @@ public void testSuccessfulGraphValidation() throws Exception { Map.entry("version", ""), Map.entry("protocol", ""), Map.entry("parameters", ""), - Map.entry("credentials", ""), + Map.entry("credential", ""), Map.entry("actions", "") ) ); From 515507ac5c9fca8aac6f0694589e73ac1c65626e Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Thu, 2 Nov 2023 19:01:05 +0000 Subject: [PATCH 15/19] Using ExceptionsHelper.status() to determine the rest status code based on exceptions thrown by the transport client Signed-off-by: Joshua Palis --- .../indices/FlowFrameworkIndicesHandler.java | 24 +++++++++++-------- .../CreateWorkflowTransportAction.java | 7 +++--- .../ProvisionWorkflowTransportAction.java | 9 +++---- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 6456885de..077d399eb 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -12,6 +12,7 @@ import com.google.common.io.Resources; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; @@ -149,7 +150,7 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } }, e -> { logger.error("Failed to 
create index " + indexName, e); - internalListener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); + internalListener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); }); CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping).settings(indexSettings); client.admin().indices().create(request, actionListener); @@ -184,7 +185,10 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe }, exception -> { logger.error("Failed to update index setting for: " + indexName, exception); internalListener.onFailure( - new FlowFrameworkException(exception.getMessage(), INTERNAL_SERVER_ERROR) + new FlowFrameworkException( + exception.getMessage(), + ExceptionsHelper.status(exception) + ) ); })); } else { @@ -195,7 +199,7 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe }, exception -> { logger.error("Failed to update index " + indexName, exception); internalListener.onFailure( - new FlowFrameworkException(exception.getMessage(), INTERNAL_SERVER_ERROR) + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); }) ); @@ -206,7 +210,7 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } }, e -> { logger.error("Failed to update index mapping", e); - internalListener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); + internalListener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); })); } else { // No need to update index if it's already updated. @@ -215,7 +219,7 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } } catch (Exception e) { logger.error("Failed to init index " + indexName, e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } @@ -278,7 +282,7 @@ public void putTemplateToGlobalContext(Template template, ActionListener context.restore())); } catch (Exception e) { logger.error("Failed to index global_context index"); - listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }, e -> { logger.error("Failed to create global_context index", e); @@ -316,12 +320,12 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to put state index document", e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }, e -> { logger.error("Failed to create global_context index", e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); })); } @@ -349,7 +353,7 @@ public void updateTemplateInGlobalContext(String documentId, Template template, client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to update global_context entry : {}. 
{}", documentId, e.getMessage()); - listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } } @@ -382,7 +386,7 @@ public void updateFlowFrameworkSystemIndexDoc( client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); - listener.onFailure(new FlowFrameworkException(e.getMessage(), INTERNAL_SERVER_ERROR)); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } } diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 232c2f126..7e9ecbc51 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -11,6 +11,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -95,7 +96,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener workflowS try { threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, provisionWorkflowListener); }); } catch (Exception exception) { - provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); } } @@ -206,7 +207,7 @@ private void executeWorkflow(List workflowSequence, ActionListener< } catch (IllegalArgumentException e) { workflowListener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST)); } catch (Exception ex) { - workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), ExceptionsHelper.status(ex))); } } From eefe60227f7be6ba1b9d583d198935da49ece47f Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Thu, 2 Nov 2023 21:44:40 +0000 Subject: [PATCH 16/19] Adding dryrun param to create workflow API, allows for validation before saving Signed-off-by: Joshua Palis --- .../flowframework/common/CommonValue.java | 2 ++ .../rest/RestCreateWorkflowAction.java | 7 +++- .../CreateWorkflowTransportAction.java | 32 +++++++++++++++++++ .../transport/WorkflowRequest.java | 27 +++++++++++++++- .../CreateWorkflowTransportActionTests.java | 2 ++ 5 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ecce8ec50..5a849fd89 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -53,6 +53,8 @@ private CommonValue() {} public static final String WORKFLOW_URI = FLOW_FRAMEWORK_BASE_URI + "/workflow"; /** Field name for workflow Id, the document Id of the indexed use case template */ public 
static final String WORKFLOW_ID = "workflow_id"; + /** Field name for dry run, the flag to indicate if validation is necessary */ + public static final String DRY_RUN = "dryrun"; /** The field name for provision workflow within a use case template*/ public static final String PROVISION_WORKFLOW = "provision"; diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 4717adfc6..b5400e247 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Locale; +import static org.opensearch.flowframework.common.CommonValue.DRY_RUN; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; @@ -62,9 +63,13 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { try { + String workflowId = request.param(WORKFLOW_ID); Template template = Template.parse(request.content().utf8ToString()); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template); + boolean dryRun = request.paramAsBoolean(DRY_RUN, false); + + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun); + return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.CREATED, builder)); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 7e9ecbc51..a6b809fc8 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -24,9 +24,14 @@ import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import java.util.List; + import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; @@ -39,6 +44,7 @@ public class CreateWorkflowTransportAction extends HandledTransportAction listener) { + User user = getUserContext(client); Template templateWithUser = new Template( request.getTemplate().name(), @@ -73,6 +83,21 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { @@ -135,4 +160,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener sortedNodes = workflowProcessSorter.sortProcessNodes(workflow); + workflowProcessSorter.validateGraph(sortedNodes); + } + } + } diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java 
b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 0b105552f..2d2046329 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -32,15 +32,30 @@ public class WorkflowRequest extends ActionRequest { */ @Nullable private Template template; + /** + * Validation flag + */ + private boolean dryRun; /** - * Instantiates a new WorkflowRequest + * Instantiates a new WorkflowRequest and defaults dry run to false * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { + this(workflowId, template, false); + } + + /** + * Instantiates a new WorkflowRequest + * @param workflowId the documentId of the workflow + * @param template the use case template which describes the workflow + * @param dryRun flag to indicate if validation is necessary + */ + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, boolean dryRun) { this.workflowId = workflowId; this.template = template; + this.dryRun = dryRun; } /** @@ -53,6 +68,7 @@ public WorkflowRequest(StreamInput in) throws IOException { this.workflowId = in.readOptionalString(); String templateJson = in.readOptionalString(); this.template = templateJson == null ? null : Template.parse(templateJson); + this.dryRun = in.readBoolean(); } /** @@ -73,11 +89,20 @@ public Template getTemplate() { return this.template; } + /** + * Gets the dry run validation flag + * @return the dry run boolean + */ + public boolean isDryRun() { + return this.dryRun; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(workflowId); out.writeOptionalString(template == null ? 
null : template.toJson()); + out.writeBoolean(dryRun); } @Override diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index b6f7bea2d..ef28166d0 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -56,6 +57,7 @@ public void setUp() throws Exception { this.createWorkflowTransportAction = new CreateWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), + mock(WorkflowProcessSorter.class), flowFrameworkIndicesHandler, client ); From 3021b3d9cbd3077e8ff98caf5702ae1e5c67658b Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Thu, 2 Nov 2023 22:11:40 +0000 Subject: [PATCH 17/19] concatenating log message with exception message on failure Signed-off-by: Joshua Palis --- .../indices/FlowFrameworkIndicesHandler.java | 50 ++++++++++++------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 077d399eb..1b0f7c9d7 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -183,10 +183,11 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe ); } }, exception -> { - logger.error("Failed to update index setting for: " + indexName, exception); + String errorMessage = "Failed to update index setting for: " + indexName; + logger.error(errorMessage, exception); internalListener.onFailure( new FlowFrameworkException( - exception.getMessage(), + errorMessage + " : " + exception.getMessage(), ExceptionsHelper.status(exception) ) ); @@ -197,9 +198,13 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe ); } }, exception -> { - logger.error("Failed to update index " + indexName, exception); + String errorMessage = "Failed to update index " + indexName; + logger.error(errorMessage, exception); internalListener.onFailure( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + new FlowFrameworkException( + errorMessage + " : " + exception.getMessage(), + ExceptionsHelper.status(exception) + ) ); }) ); @@ -209,8 +214,11 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe internalListener.onResponse(true); } }, e -> { - logger.error("Failed to update index mapping", e); - internalListener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + String errorMessage = "Failed to update index mapping"; + logger.error(errorMessage, e); + internalListener.onFailure( + new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e)) + ); })); } else { // No need to update index if it's already updated. 
@@ -218,8 +226,9 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } } } catch (Exception e) { - logger.error("Failed to init index " + indexName, e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + String errorMessage = "Failed to init index " + indexName; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } } @@ -281,8 +290,9 @@ public void putTemplateToGlobalContext(Template template, ActionListener context.restore())); } catch (Exception e) { - logger.error("Failed to index global_context index"); - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + String errorMessage = "Failed to index global_context index"; + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } }, e -> { logger.error("Failed to create global_context index", e); @@ -319,13 +329,15 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL request.id(workflowId); client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { - logger.error("Failed to put state index document", e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + String errorMessage = "Failed to put state index document"; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } }, e -> { - logger.error("Failed to create global_context index", e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + String errorMessage = "Failed to create global_context index"; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); })); } @@ -352,8 +364,9 @@ public void updateTemplateInGlobalContext(String documentId, Template template, .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { - logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage()); - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + String errorMessage = "Failed to update global_context entry : " + documentId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } } } @@ -385,8 +398,9 @@ public void updateFlowFrameworkSystemIndexDoc( // TODO: decide what condition can be considered as an update conflict and add retry strategy client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { - logger.error("Failed to update {} entry : {}. 
{}", indexName, documentId, e.getMessage()); - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + String errorMessage = "Failed to update " + indexName + " entry : " + documentId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } } } From a72904a0a1bbf6ebec8f3fefadb9696bf1897e3a Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Thu, 2 Nov 2023 22:33:40 +0000 Subject: [PATCH 18/19] Adding dry run test Signed-off-by: Joshua Palis --- .../CreateWorkflowTransportActionTests.java | 68 ++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index ef28166d0..fbec8a034 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -23,6 +23,7 @@ import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -44,6 +45,7 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private CreateWorkflowTransportAction createWorkflowTransportAction; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private WorkflowProcessSorter workflowProcessSorter; private Template template; private Client client = mock(Client.class); private ThreadPool threadPool; @@ -53,15 +55,16 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { @Override public void setUp() throws Exception { super.setUp(); + threadPool = mock(ThreadPool.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool); this.createWorkflowTransportAction = new CreateWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), - mock(WorkflowProcessSorter.class), + workflowProcessSorter, flowFrameworkIndicesHandler, client ); - threadPool = mock(ThreadPool.class); // client = mock(Client.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); // threadContext = mock(ThreadContext.class); @@ -90,6 +93,67 @@ public void setUp() throws Exception { ); } + public void testFailedDryRunValidation() { + + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + "create_connector", + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + "register_model", + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + "deploy_model", + 
Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + WorkflowEdge cyclicalEdge = new WorkflowEdge(deployModel.id(), createConnector.id()); + + Workflow workflow = new Workflow( + Map.of(), + List.of(createConnector, registerModel, deployModel), + List.of(edge1, edge2, cyclicalEdge) + ); + + Template cyclicalTemplate = new Template( + "test", + "description", + "use case", + Version.fromString("1.0.0"), + List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), + Map.of("workflow", workflow), + TestHelpers.randomUser() + ); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true); + + createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("No start node detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage()); + } + public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); From 6b364c6611a5b66eca48ce6b8145b7b8f6598117 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 3 Nov 2023 15:52:42 +0000 Subject: [PATCH 19/19] Simplifying FlowFrameworkException::toXContent Signed-off-by: Joshua Palis --- .../flowframework/exception/FlowFrameworkException.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java b/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java index 52dad9f44..7e8aefc15 100644 --- a/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java +++ b/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java @@ -67,6 +67,6 @@ public RestStatus getRestStatus() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject().field("error", "Request failed with exception: [" + this.getMessage() + "]").endObject(); + return builder.startObject().field("error", this.getMessage()).endObject(); } }
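
Note on the dryrun parameter (PATCH 16/19): the sketch below is illustrative and is not part of the patch series. The DryRunSketch class, the templateJson string, and the externally supplied WorkflowProcessSorter are assumptions; Template.parse, WorkflowRequest, isDryRun, template.workflows(), sortProcessNodes, and validateGraph are the calls introduced or used in the diffs above. On the REST side, RestCreateWorkflowAction reads the flag with request.paramAsBoolean("dryrun", false), so a client would send something like POST /_plugins/_flow_framework/workflow?dryrun=true with the use case template as the request body to validate the workflow graphs without saving anything.

import java.util.List;

import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;

public class DryRunSketch {

    // workflowProcessSorter is normally wired up by the plugin at startup; it is passed in here for illustration.
    static void dryRunValidate(String templateJson, WorkflowProcessSorter workflowProcessSorter) throws Exception {
        // Same parse call RestCreateWorkflowAction applies to the request body
        Template template = Template.parse(templateJson);
        // dryRun = true requests validation only, mirroring request.paramAsBoolean("dryrun", false)
        WorkflowRequest request = new WorkflowRequest(null, template, true);
        if (request.isDryRun()) {
            // Mirrors CreateWorkflowTransportAction#validateWorkflows: topologically sort each workflow
            // graph and check required inputs before anything is written to the global context index.
            for (Workflow workflow : template.workflows().values()) {
                List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow);
                workflowProcessSorter.validateGraph(sortedNodes);
            }
        }
    }
}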
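
With FlowFrameworkException carrying a RestStatus, ExceptionsHelper.status() mapping transport failures, and the simplified toXContent from PATCH 19/19, a failed validation is expected to reach REST clients as a body of the form {"error": "<exception message>"} with the HTTP status taken from the exception, for example HTTP 400 and "No start node detected: all nodes have a predecessor." for the cyclic template exercised in testFailedDryRunValidation. This reading follows from the toXContent implementation and the FakeRestChannel assertions above rather than from a captured response.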