Skip to content

Commit

Permalink
batch ingest API rest and transport actions
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt committed Aug 22, 2024
1 parent 2a33c65 commit 1907a48
Show file tree
Hide file tree
Showing 15 changed files with 836 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ public enum MLTaskType {
@Deprecated
LOAD_MODEL,
REGISTER_MODEL,
DEPLOY_MODEL
DEPLOY_MODEL,
BATCH_INGEST
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.opensearch.ml.common.transport.batch;

import org.opensearch.action.ActionType;

public class MLBatchIngestionAction extends ActionType<MLBatchIngestionResponse> {
public static MLBatchIngestionAction INSTANCE = new MLBatchIngestionAction();
public static final String NAME = "cluster:admin/opensearch/ml/batch_ingestion";

private MLBatchIngestionAction() {
super(NAME, MLBatchIngestionResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.batch;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getOrderedMap;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import lombok.Builder;
import lombok.Getter;

/**
* ML batch ingestion data: index, field mapping and input and out files.
*/
public class MLBatchIngestionInput implements ToXContentObject, Writeable {

public static final String INDEX_NAME_FIELD = "index_name";
public static final String TEXT_EMBEDDING_FIELD_MAP_FIELD = "text_embedding_field_map";
public static final String DATA_SOURCE_FIELD = "data_source";
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
@Getter
private String indexName;
@Getter
private Map<String, String> fieldMapping;
@Getter
private Map<String, String> dataSources;
@Getter
private Map<String, String> credential;

@Builder(toBuilder = true)
public MLBatchIngestionInput(
String indexName,
Map<String, String> fieldMapping,
Map<String, String> dataSources,
Map<String, String> credential
) {
this.indexName = indexName;
this.fieldMapping = fieldMapping;
this.dataSources = dataSources;
this.credential = credential;
}

public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
String indexName = null;
Map<String, String> fieldMapping = null;
Map<String, String> dataSources = null;
Map<String, String> credential = new HashMap<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case INDEX_NAME_FIELD:
indexName = parser.text();
break;
case TEXT_EMBEDDING_FIELD_MAP_FIELD:
fieldMapping = getOrderedMap(parser.mapOrdered());
break;
case CONNECTOR_CREDENTIAL_FIELD:
credential = parser.mapStrings();
break;
case DATA_SOURCE_FIELD:
dataSources = parser.mapStrings();
break;
default:
parser.skipChildren();
break;
}
}
return new MLBatchIngestionInput(indexName, fieldMapping, dataSources, credential);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (indexName != null) {
builder.field(INDEX_NAME_FIELD, indexName);
}
if (fieldMapping != null) {
builder.field(TEXT_EMBEDDING_FIELD_MAP_FIELD, fieldMapping);
}
if (dataSources != null) {
builder.field(DATA_SOURCE_FIELD, dataSources);
}
if (credential != null) {
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput output) throws IOException {
output.writeOptionalString(indexName);
if (fieldMapping != null) {
output.writeBoolean(true);
output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeString);
} else {
output.writeBoolean(false);
}

if (dataSources != null) {
output.writeBoolean(true);
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeString);
} else {
output.writeBoolean(false);
}

if (credential != null) {
output.writeBoolean(true);
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString);
} else {
output.writeBoolean(false);
}
}

public MLBatchIngestionInput(StreamInput input) throws IOException {
indexName = input.readOptionalString();
if (input.readBoolean()) {
fieldMapping = input.readMap(s -> s.readString(), s -> s.readString());
}
if (input.readBoolean()) {
dataSources = input.readMap(s -> s.readString(), s -> s.readString());
}
if (input.readBoolean()) {
credential = input.readMap(s -> s.readString(), s -> s.readString());
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.opensearch.ml.common.transport.batch;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLBatchIngestionRequest extends ActionRequest {

private MLBatchIngestionInput mlBatchIngestionInput;

@Builder
public MLBatchIngestionRequest(MLBatchIngestionInput mlBatchIngestionInput) {
this.mlBatchIngestionInput = mlBatchIngestionInput;
}

public MLBatchIngestionRequest(StreamInput in) throws IOException {
super(in);
this.mlBatchIngestionInput = new MLBatchIngestionInput(in);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (mlBatchIngestionInput == null) {
exception = addValidationError("ML batch ingestion input can't be null", exception);
}
if (mlBatchIngestionInput.getCredential() == null) {
exception = addValidationError("ML batch ingestion credentials can't be null", exception);
}
if (mlBatchIngestionInput.getDataSources() == null) {
exception = addValidationError("ML batch ingestion data sources can't be null", exception);
}

return exception;
}

public static MLBatchIngestionRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLBatchIngestionRequest) {
return (MLBatchIngestionRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLBatchIngestionRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLBatchIngestionRequest", e);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package org.opensearch.ml.common.transport.batch;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

public class MLBatchIngestionResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String TASK_TYPE_FIELD = "task_type";
public static final String STATUS_FIELD = "status";

private String taskId;
private MLTaskType taskType;
private String status;

public MLBatchIngestionResponse(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
this.taskType = in.readEnum(MLTaskType.class);
this.status = in.readString();
}

public MLBatchIngestionResponse(String taskId, MLTaskType mlTaskType, String status) {
this.taskId = taskId;
this.taskType = mlTaskType;
this.status = status;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(taskId);
out.writeEnum(taskType);
out.writeString(status);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(TASK_ID_FIELD, taskId);
if (taskType != null) {
builder.field(TASK_TYPE_FIELD, taskType);
}
builder.field(STATUS_FIELD, status);
builder.endObject();
return builder;
}

public static MLBatchIngestionResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLBatchIngestionResponse) {
return (MLBatchIngestionResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLBatchIngestionResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLBatchIngestionResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -142,6 +143,27 @@ public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs)
return parameters;
}

@SuppressWarnings("removal")
public static LinkedHashMap<String, String> getOrderedMap(Map<String, ?> parameterObjs) {
LinkedHashMap<String, String> parameters = new LinkedHashMap<>();
for (String key : parameterObjs.keySet()) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String) value);
} else {
parameters.put(key, gson.toJson(value));
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;
}

@SuppressWarnings("removal")
public static String toJson(Object value) {
try {
Expand Down
6 changes: 5 additions & 1 deletion ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.gradle.nativeplatform.platform.internal.DefaultNativePlatform

plugins {
id 'java'
id 'java-library'
id 'jacoco'
id "io.freefair.lombok"
id 'com.diffplug.spotless' version '6.25.0'
Expand Down Expand Up @@ -62,9 +63,12 @@ dependencies {
}

implementation platform('software.amazon.awssdk:bom:2.25.40')
implementation 'software.amazon.awssdk:auth'
api 'software.amazon.awssdk:auth:2.25.40'
implementation 'software.amazon.awssdk:apache-client'
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.25.40'
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.25.40'
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.25.40'
implementation 'com.jayway.jsonpath:json-path:2.9.0'
implementation group: 'org.json', name: 'json', version: '20231013'
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.25.40'
Expand Down
Loading

0 comments on commit 1907a48

Please sign in to comment.