Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added register a group model step #118

Merged
merged 6 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,10 @@ private CommonValue() {}
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
/** Backend roles for the model */
public static final String BACKEND_ROLES_FIELD = "backend_roles";
/** Access mode for the model */
public static final String MODEL_ACCESS_MODE = "access_mode";
/** Add all backend roles */
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles";
}
14 changes: 7 additions & 7 deletions src/main/java/org/opensearch/flowframework/model/Template.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@
public class Template implements ToXContentObject {

/** The template field name for template name */
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
public static final String NAME_FIELD = "name";
private static final String NAME_FIELD = "name";
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
/** The template field name for template description */
public static final String DESCRIPTION_FIELD = "description";
private static final String DESCRIPTION_FIELD = "description";
/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
private static final String USE_CASE_FIELD = "use_case";
/** The template field name for template version information */
public static final String VERSION_FIELD = "version";
private static final String VERSION_FIELD = "version";
/** The template field name for template version */
public static final String TEMPLATE_FIELD = "template";
private static final String TEMPLATE_FIELD = "template";
/** The template field name for template compatibility with OpenSearch versions */
public static final String COMPATIBILITY_FIELD = "compatibility";
private static final String COMPATIBILITY_FIELD = "compatibility";
/** The template field name for template workflows */
public static final String WORKFLOWS_FIELD = "workflows";
private static final String WORKFLOWS_FIELD = "workflows";

private final String name;
private final String description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -85,7 +86,7 @@ public void onFailure(Exception e) {
String protocol = null;
Map<String, String> parameters = new HashMap<>();
Map<String, String> credentials = new HashMap<>();
List<ConnectorAction> actions = null;
List<ConnectorAction> actions = new ArrayList<>();

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput.MLRegisterModelGroupInputBuilder;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES;
import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;

/**
* Step to register a model group
*/
public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "model_group";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public ModelGroupStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {

CompletableFuture<WorkflowData> registerModelGroupFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelGroupResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) {
logger.info("Model group registration successful");
registerModelGroupFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()),
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus())
amitgalitz marked this conversation as resolved.
Show resolved Hide resolved
)
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register model group");
registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String modelGroupName = null;
String description = null;
List<String> backendRoles = new ArrayList<>();
AccessMode modelAccessMode = null;
Boolean isAddAllBackendRoles = false;
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();
amitgalitz marked this conversation as resolved.
Show resolved Hide resolved

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelGroupName = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
case BACKEND_ROLES_FIELD:
backendRoles = (List<String>) content.get(BACKEND_ROLES_FIELD);
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
case MODEL_ACCESS_MODE:
modelAccessMode = (AccessMode) content.get(MODEL_ACCESS_MODE);
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = (Boolean) content.get(ADD_ALL_BACKEND_ROLES);
default:
break;
}
}
}

if (modelGroupName == null) {
registerModelGroupFuture.completeExceptionally(

Check warning on line 111 in src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java#L111

Added line #L111 was not covered by tests
new FlowFrameworkException("Model group name is not provided", RestStatus.BAD_REQUEST)
);
}

MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder();
builder.name(modelGroupName);
if (description != null) {
builder.description(description);
}
if (backendRoles != null && backendRoles.size() > 0) {
builder.backendRoles(backendRoles);
}
if (modelAccessMode != null) {
builder.modelAccessMode(modelAccessMode);
}
if (isAddAllBackendRoles != null) {
builder.isAddAllBackendRoles(isAddAllBackendRoles);
}
MLRegisterModelGroupInput mlInput = builder.build();

mlClient.registerModelGroup(mlInput, actionListener);

return registerModelGroupFuture;
}

@Override
public String getName() {
return NAME;

Check warning on line 139 in src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java#L139

Added line #L139 was not covered by tests
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));

// TODO: These are from the demo class as placeholders, remove when demos are deleted
stepMap.put("demo_delay_3", new DemoWorkflowStep(3000));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -34,9 +35,6 @@
public class CreateConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
ActionListener<MLCreateConnectorResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand All @@ -49,6 +47,10 @@ public void setUp() throws Exception {

MockitoAnnotations.openMocks(this);

ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = "https://test.com";
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved

inputData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "test"),
Expand All @@ -57,7 +59,20 @@ public void setUp() throws Exception {
Map.entry("protocol", "test"),
Map.entry("params", params),
Map.entry("credentials", credentials),
Map.entry("actions", List.of("actions"))
Map.entry(
"actions",
List.of(
new ConnectorAction(
actionType,
method,
url,
null,
"{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }",
null,
null
)
)
)
)
);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import com.google.common.collect.ImmutableList;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

public class ModelGroupStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

MockitoAnnotations.openMocks(this);
inputData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "test"),
Map.entry("description", "description"),
Map.entry("backend_roles", ImmutableList.of("role-1")),
Map.entry("access_mode", AccessMode.PUBLIC),
Map.entry("add_all_backend_roles", false)
)
);
}

public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException {
String modelGroupId = "model_group_id";
String status = MLTaskState.CREATED.name();

ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient);

ArgumentCaptor<ActionListener<MLRegisterModelGroupResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved

doAnswer(invocation -> {
ActionListener<MLRegisterModelGroupResponse> actionListener = invocation.getArgument(1);
MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = modelGroupStep.execute(List.of(inputData));

verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

assertTrue(future.isDone());
assertEquals(modelGroupId, future.get().getContent().get("model_group_id"));
assertEquals(status, future.get().getContent().get("model_group_status"));

}

public void testRegisterModelGroupFailure() throws ExecutionException, InterruptedException, IOException {
ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient);

ArgumentCaptor<ActionListener<MLRegisterModelGroupResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved

doAnswer(invocation -> {
ActionListener<MLRegisterModelGroupResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to register model group", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = modelGroupStep.execute(List.of(inputData));

verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to register model group", ex.getCause().getMessage());

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@
public class RegisterModelStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
ActionListener<MLRegisterModelResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand Down