Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Create Connector Step #107

Merged
merged 6 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/main/java/demo/Demo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
Expand Down Expand Up @@ -59,7 +60,8 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient);

ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/demo/TemplateParseDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
Expand Down Expand Up @@ -55,7 +56,8 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient);
ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.flowframework.workflow.CreateIndexStep;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.repositories.RepositoriesService;
Expand Down Expand Up @@ -76,7 +77,8 @@ public Collection<Object> createComponents(
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

// TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep
Expand Down
34 changes: 0 additions & 34 deletions src/main/java/org/opensearch/flowframework/client/MLClient.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ private CommonValue() {}
public static final String MODEL_ID = "model_id";
/** Function Name field */
public static final String FUNCTION_NAME = "function_name";
/** Model Name field */
public static final String MODEL_NAME = "name";
/** Name field */
public static final String NAME_FIELD = "name";
/** Model Version field */
public static final String MODEL_VERSION = "model_version";
/** Model Group Id field */
Expand All @@ -62,4 +62,14 @@ private CommonValue() {}
public static final String MODEL_FORMAT = "model_format";
/** Model config field */
public static final String MODEL_CONFIG = "model_config";
/** Version field */
public static final String VERSION_FIELD = "version";
/** Connector protocol field */
public static final String PROTOCOL_FIELD = "protocol";
/** Connector parameters field */
public static final String PARAMETERS_FIELD = "parameters";
/** Connector credentials field */
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;

/**
* Step to create a connector for a remote model
*/
public class CreateConnectorStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "create_connector";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public CreateConnectorStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {
CompletableFuture<WorkflowData> createConnectorFuture = new CompletableFuture<>();

ActionListener<MLCreateConnectorResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
logger.info("Created connector successfully");
// TODO Add the response to Global Context
createConnectorFuture.complete(
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())))
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to create connector");
createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String name = null;
String description = null;
String version = null;
String protocol = null;
Map<String, String> parameters = new HashMap<>();
Map<String, String> credentials = new HashMap<>();
List<ConnectorAction> actions = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
case VERSION_FIELD:
version = (String) content.get(VERSION_FIELD);
break;
case PROTOCOL_FIELD:
protocol = (String) content.get(PROTOCOL_FIELD);
break;
case PARAMETERS_FIELD:
parameters = getParameterMap((Map<String, String>) content.get(PARAMETERS_FIELD));
break;

Check warning on line 109 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L108-L109

Added lines #L108 - L109 were not covered by tests
case CREDENTIALS_FIELD:
credentials = (Map<String, String>) content.get(CREDENTIALS_FIELD);
break;
case ACTIONS_FIELD:
actions = (List<ConnectorAction>) content.get(ACTIONS_FIELD);
break;
}

}
}

if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) {
MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder()
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
.name(name)
.description(description)
.version(version)
.protocol(protocol)
.parameters(parameters)
.credential(credentials)
.actions(actions)
.build();

mlClient.createConnector(mlInput, actionListener);
} else {
createConnectorFuture.completeExceptionally(

Check warning on line 134 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L134

Added line #L134 was not covered by tests
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
}

return createConnectorFuture;
}

@Override
public String getName() {
return NAME;

Check warning on line 144 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L144

Added line #L144 was not covered by tests
}

private static Map<String, String> getParameterMap(Map<String, String> params) {

Map<String, String> parameters = new HashMap<>();

Check warning on line 149 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L149

Added line #L149 was not covered by tests
for (String key : params.keySet()) {
String value = params.get(key);

Check warning on line 151 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L151

Added line #L151 was not covered by tests
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
parameters.put(key, value);
return null;

Check warning on line 155 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L153-L155

Added lines #L153 - L155 were not covered by tests
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;

Check warning on line 161 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L157-L161

Added lines #L157 - L161 were not covered by tests
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

Expand All @@ -28,24 +29,22 @@
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private Client client;
private MachineLearningNodeClient mlClient;
static final String NAME = "deploy_model";

/**
* Instantiate this class
* @param client client to instantiate MLClient
* @param mlClient client to instantiate MLClient
*/
public DeployModelStep(Client client) {
this.client = client;
public DeployModelStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> deployModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLDeployModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
Expand All @@ -57,8 +56,8 @@

@Override
public void onFailure(Exception e) {
logger.error("Model deployment failed");
deployModelFuture.completeExceptionally(e);
logger.error("Failed to deploy model");
deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

Expand All @@ -70,7 +69,13 @@
break;
}
}
machineLearningNodeClient.deploy(modelId, actionListener);

if (modelId != null) {
mlClient.deploy(modelId, actionListener);
} else {
deployModelFuture.completeExceptionally(new FlowFrameworkException("Model ID is not provided", RestStatus.BAD_REQUEST));

Check warning on line 76 in src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java#L76

Added line #L76 was not covered by tests
}

return deployModelFuture;
}

Expand Down
Loading
Loading