-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
batch ingest API rest and transport actions
Signed-off-by: Xun Zhang <xunzh@amazon.com>
- Loading branch information
1 parent
2a33c65
commit 1907a48
Showing
15 changed files
with
836 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,5 +15,6 @@ public enum MLTaskType { | |
@Deprecated | ||
LOAD_MODEL, | ||
REGISTER_MODEL, | ||
DEPLOY_MODEL | ||
DEPLOY_MODEL, | ||
BATCH_INGEST | ||
} |
13 changes: 13 additions & 0 deletions
13
common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package org.opensearch.ml.common.transport.batch; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
public class MLBatchIngestionAction extends ActionType<MLBatchIngestionResponse> { | ||
public static MLBatchIngestionAction INSTANCE = new MLBatchIngestionAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/batch_ingestion"; | ||
|
||
private MLBatchIngestionAction() { | ||
super(NAME, MLBatchIngestionResponse::new); | ||
} | ||
|
||
} |
145 changes: 145 additions & 0 deletions
145
common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.batch; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.utils.StringUtils.getOrderedMap; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.common.io.stream.Writeable; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
|
||
/** | ||
* ML batch ingestion data: index, field mapping and input and out files. | ||
*/ | ||
public class MLBatchIngestionInput implements ToXContentObject, Writeable { | ||
|
||
public static final String INDEX_NAME_FIELD = "index_name"; | ||
public static final String TEXT_EMBEDDING_FIELD_MAP_FIELD = "text_embedding_field_map"; | ||
public static final String DATA_SOURCE_FIELD = "data_source"; | ||
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential"; | ||
@Getter | ||
private String indexName; | ||
@Getter | ||
private Map<String, String> fieldMapping; | ||
@Getter | ||
private Map<String, String> dataSources; | ||
@Getter | ||
private Map<String, String> credential; | ||
|
||
@Builder(toBuilder = true) | ||
public MLBatchIngestionInput( | ||
String indexName, | ||
Map<String, String> fieldMapping, | ||
Map<String, String> dataSources, | ||
Map<String, String> credential | ||
) { | ||
this.indexName = indexName; | ||
this.fieldMapping = fieldMapping; | ||
this.dataSources = dataSources; | ||
this.credential = credential; | ||
} | ||
|
||
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException { | ||
String indexName = null; | ||
Map<String, String> fieldMapping = null; | ||
Map<String, String> dataSources = null; | ||
Map<String, String> credential = new HashMap<>(); | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case INDEX_NAME_FIELD: | ||
indexName = parser.text(); | ||
break; | ||
case TEXT_EMBEDDING_FIELD_MAP_FIELD: | ||
fieldMapping = getOrderedMap(parser.mapOrdered()); | ||
break; | ||
case CONNECTOR_CREDENTIAL_FIELD: | ||
credential = parser.mapStrings(); | ||
break; | ||
case DATA_SOURCE_FIELD: | ||
dataSources = parser.mapStrings(); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
return new MLBatchIngestionInput(indexName, fieldMapping, dataSources, credential); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
if (indexName != null) { | ||
builder.field(INDEX_NAME_FIELD, indexName); | ||
} | ||
if (fieldMapping != null) { | ||
builder.field(TEXT_EMBEDDING_FIELD_MAP_FIELD, fieldMapping); | ||
} | ||
if (dataSources != null) { | ||
builder.field(DATA_SOURCE_FIELD, dataSources); | ||
} | ||
if (credential != null) { | ||
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput output) throws IOException { | ||
output.writeOptionalString(indexName); | ||
if (fieldMapping != null) { | ||
output.writeBoolean(true); | ||
output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeString); | ||
} else { | ||
output.writeBoolean(false); | ||
} | ||
|
||
if (dataSources != null) { | ||
output.writeBoolean(true); | ||
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeString); | ||
} else { | ||
output.writeBoolean(false); | ||
} | ||
|
||
if (credential != null) { | ||
output.writeBoolean(true); | ||
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString); | ||
} else { | ||
output.writeBoolean(false); | ||
} | ||
} | ||
|
||
public MLBatchIngestionInput(StreamInput input) throws IOException { | ||
indexName = input.readOptionalString(); | ||
if (input.readBoolean()) { | ||
fieldMapping = input.readMap(s -> s.readString(), s -> s.readString()); | ||
} | ||
if (input.readBoolean()) { | ||
dataSources = input.readMap(s -> s.readString(), s -> s.readString()); | ||
} | ||
if (input.readBoolean()) { | ||
credential = input.readMap(s -> s.readString(), s -> s.readString()); | ||
} | ||
} | ||
|
||
} |
70 changes: 70 additions & 0 deletions
70
common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
package org.opensearch.ml.common.transport.batch; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.ToString; | ||
import lombok.experimental.FieldDefaults; | ||
|
||
@Getter | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
@ToString | ||
public class MLBatchIngestionRequest extends ActionRequest { | ||
|
||
private MLBatchIngestionInput mlBatchIngestionInput; | ||
|
||
@Builder | ||
public MLBatchIngestionRequest(MLBatchIngestionInput mlBatchIngestionInput) { | ||
this.mlBatchIngestionInput = mlBatchIngestionInput; | ||
} | ||
|
||
public MLBatchIngestionRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.mlBatchIngestionInput = new MLBatchIngestionInput(in); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (mlBatchIngestionInput == null) { | ||
exception = addValidationError("ML batch ingestion input can't be null", exception); | ||
} | ||
if (mlBatchIngestionInput.getCredential() == null) { | ||
exception = addValidationError("ML batch ingestion credentials can't be null", exception); | ||
} | ||
if (mlBatchIngestionInput.getDataSources() == null) { | ||
exception = addValidationError("ML batch ingestion data sources can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
public static MLBatchIngestionRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof MLBatchIngestionRequest) { | ||
return (MLBatchIngestionRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLBatchIngestionRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into MLBatchIngestionRequest", e); | ||
} | ||
|
||
} | ||
} |
73 changes: 73 additions & 0 deletions
73
common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package org.opensearch.ml.common.transport.batch; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import org.opensearch.core.action.ActionResponse; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContent; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.ml.common.MLTaskType; | ||
|
||
public class MLBatchIngestionResponse extends ActionResponse implements ToXContentObject { | ||
public static final String TASK_ID_FIELD = "task_id"; | ||
public static final String TASK_TYPE_FIELD = "task_type"; | ||
public static final String STATUS_FIELD = "status"; | ||
|
||
private String taskId; | ||
private MLTaskType taskType; | ||
private String status; | ||
|
||
public MLBatchIngestionResponse(StreamInput in) throws IOException { | ||
super(in); | ||
this.taskId = in.readString(); | ||
this.taskType = in.readEnum(MLTaskType.class); | ||
this.status = in.readString(); | ||
} | ||
|
||
public MLBatchIngestionResponse(String taskId, MLTaskType mlTaskType, String status) { | ||
this.taskId = taskId; | ||
this.taskType = mlTaskType; | ||
this.status = status; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(taskId); | ||
out.writeEnum(taskType); | ||
out.writeString(status); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(TASK_ID_FIELD, taskId); | ||
if (taskType != null) { | ||
builder.field(TASK_TYPE_FIELD, taskType); | ||
} | ||
builder.field(STATUS_FIELD, status); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
public static MLBatchIngestionResponse fromActionResponse(ActionResponse actionResponse) { | ||
if (actionResponse instanceof MLBatchIngestionResponse) { | ||
return (MLBatchIngestionResponse) actionResponse; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionResponse.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLBatchIngestionResponse(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionResponse into MLBatchIngestionResponse", e); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.