Skip to content

Commit

Permalink
offline batch ingestion API actions and data ingesters (#2844)
Browse files Browse the repository at this point in the history
* batch ingest API rest and transport actions

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* add openAI ingester

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* update batch ingestion field mapping interphase and address comments

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* support multiple data sources as ingestion inputs

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* use dedicated thread pool for ingestion

Signed-off-by: Xun Zhang <xunzh@amazon.com>

---------

Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt authored Sep 4, 2024
1 parent cc402b3 commit 33a7c96
Show file tree
Hide file tree
Showing 31 changed files with 2,314 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ public enum MLTaskType {
@Deprecated
LOAD_MODEL,
REGISTER_MODEL,
DEPLOY_MODEL
DEPLOY_MODEL,
BATCH_INGEST
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.batch;

import org.opensearch.action.ActionType;

public class MLBatchIngestionAction extends ActionType<MLBatchIngestionResponse> {
public static MLBatchIngestionAction INSTANCE = new MLBatchIngestionAction();
public static final String NAME = "cluster:admin/opensearch/ml/batch_ingestion";

private MLBatchIngestionAction() {
super(NAME, MLBatchIngestionResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.batch;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import lombok.Builder;
import lombok.Getter;

/**
* ML batch ingestion data: index, field mapping and input and out files.
*/
public class MLBatchIngestionInput implements ToXContentObject, Writeable {

public static final String INDEX_NAME_FIELD = "index_name";
public static final String FIELD_MAP_FIELD = "field_map";
public static final String DATA_SOURCE_FIELD = "data_source";
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
@Getter
private String indexName;
@Getter
private Map<String, Object> fieldMapping;
@Getter
private Map<String, Object> dataSources;
@Getter
private Map<String, String> credential;

@Builder(toBuilder = true)
public MLBatchIngestionInput(
String indexName,
Map<String, Object> fieldMapping,
Map<String, Object> dataSources,
Map<String, String> credential
) {
if (indexName == null) {
throw new IllegalArgumentException(
"The index name for data ingestion is missing. Please provide a valid index name to proceed."
);
}
if (dataSources == null) {
throw new IllegalArgumentException(
"No data sources were provided for ingestion. Please specify at least one valid data source to proceed."
);
}
this.indexName = indexName;
this.fieldMapping = fieldMapping;
this.dataSources = dataSources;
this.credential = credential;
}

public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
String indexName = null;
Map<String, Object> fieldMapping = null;
Map<String, Object> dataSources = null;
Map<String, String> credential = new HashMap<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case INDEX_NAME_FIELD:
indexName = parser.text();
break;
case FIELD_MAP_FIELD:
fieldMapping = parser.map();
break;
case CONNECTOR_CREDENTIAL_FIELD:
credential = parser.mapStrings();
break;
case DATA_SOURCE_FIELD:
dataSources = parser.map();
break;
default:
parser.skipChildren();
break;
}
}
return new MLBatchIngestionInput(indexName, fieldMapping, dataSources, credential);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (indexName != null) {
builder.field(INDEX_NAME_FIELD, indexName);
}
if (fieldMapping != null) {
builder.field(FIELD_MAP_FIELD, fieldMapping);
}
if (credential != null) {
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
}
if (dataSources != null) {
builder.field(DATA_SOURCE_FIELD, dataSources);
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput output) throws IOException {
output.writeOptionalString(indexName);
if (fieldMapping != null) {
output.writeBoolean(true);
output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeGenericValue);
} else {
output.writeBoolean(false);
}
if (credential != null) {
output.writeBoolean(true);
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString);
} else {
output.writeBoolean(false);
}
if (dataSources != null) {
output.writeBoolean(true);
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue);
} else {
output.writeBoolean(false);
}
}

public MLBatchIngestionInput(StreamInput input) throws IOException {
indexName = input.readOptionalString();
if (input.readBoolean()) {
fieldMapping = input.readMap(s -> s.readString(), s -> s.readGenericValue());
}
if (input.readBoolean()) {
credential = input.readMap(s -> s.readString(), s -> s.readString());
}
if (input.readBoolean()) {
dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue());
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.batch;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLBatchIngestionRequest extends ActionRequest {

private MLBatchIngestionInput mlBatchIngestionInput;

@Builder
public MLBatchIngestionRequest(MLBatchIngestionInput mlBatchIngestionInput) {
this.mlBatchIngestionInput = mlBatchIngestionInput;
}

public MLBatchIngestionRequest(StreamInput in) throws IOException {
super(in);
this.mlBatchIngestionInput = new MLBatchIngestionInput(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.mlBatchIngestionInput.writeTo(out);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (mlBatchIngestionInput == null) {
exception = addValidationError("The input for ML batch ingestion cannot be null.", exception);
}
if (mlBatchIngestionInput != null && mlBatchIngestionInput.getCredential() == null) {
exception = addValidationError("The credential for ML batch ingestion cannot be null", exception);
}
if (mlBatchIngestionInput != null && mlBatchIngestionInput.getDataSources() == null) {
exception = addValidationError("The data sources for ML batch ingestion cannot be null", exception);
}

return exception;
}

public static MLBatchIngestionRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLBatchIngestionRequest) {
return (MLBatchIngestionRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLBatchIngestionRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLBatchIngestionRequest", e);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.batch;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

import lombok.Getter;

@Getter
public class MLBatchIngestionResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String TASK_TYPE_FIELD = "task_type";
public static final String STATUS_FIELD = "status";

private String taskId;
private MLTaskType taskType;
private String status;

public MLBatchIngestionResponse(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
this.taskType = in.readEnum(MLTaskType.class);
this.status = in.readString();
}

public MLBatchIngestionResponse(String taskId, MLTaskType mlTaskType, String status) {
this.taskId = taskId;
this.taskType = mlTaskType;
this.status = status;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(taskId);
out.writeEnum(taskType);
out.writeString(status);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(TASK_ID_FIELD, taskId);
if (taskType != null) {
builder.field(TASK_TYPE_FIELD, taskType);
}
builder.field(STATUS_FIELD, status);
builder.endObject();
return builder;
}

public static MLBatchIngestionResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLBatchIngestionResponse) {
return (MLBatchIngestionResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLBatchIngestionResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLBatchIngestionResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,18 @@ public static Map<String, String> parseParameters(Map<String, String> parameters
return parameters;
}

public static String obtainFieldNameFromJsonPath(String jsonPath) {
String[] parts = jsonPath.split("\\.");

// Get the last part which is the field name
return parts[parts.length - 1];
}

public static String getJsonPath(String jsonPathWithSource) {
// Find the index of the first occurrence of "$."
int startIndex = jsonPathWithSource.indexOf("$.");

// Extract the substring from the startIndex to the end of the input string
return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ public void parseTextDocsMLInput() throws IOException {
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonStr = builder.toString();
System.out.println(jsonStr);
parseMLInput(jsonStr, 2);
}

Expand Down
Loading

0 comments on commit 33a7c96

Please sign in to comment.