From 33a7c967d1af63d8137ec80514b4d12573a9aa07 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Wed, 4 Sep 2024 10:28:15 -0700 Subject: [PATCH] offline batch ingestion API actions and data ingesters (#2844) * batch ingest API rest and transport actions Signed-off-by: Xun Zhang * add openAI ingester Signed-off-by: Xun Zhang * update batch ingestion field mapping interphase and address comments Signed-off-by: Xun Zhang * support multiple data sources as ingestion inputs Signed-off-by: Xun Zhang * use dedicated thread pool for ingestion Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang --- .../org/opensearch/ml/common/MLTaskType.java | 3 +- .../batch/MLBatchIngestionAction.java | 18 ++ .../batch/MLBatchIngestionInput.java | 152 ++++++++++ .../batch/MLBatchIngestionRequest.java | 82 ++++++ .../batch/MLBatchIngestionResponse.java | 81 ++++++ .../ml/common/utils/StringUtils.java | 14 + .../common/input/nlp/TextDocsMLInputTest.java | 1 - .../batch/MLBatchIngestionInputTests.java | 138 +++++++++ .../batch/MLBatchIngestionRequestTests.java | 113 ++++++++ .../batch/MLBatchIngestionResponseTests.java | 83 ++++++ .../ml/common/utils/StringUtilsTest.java | 34 +++ .../memory/index/InteractionsIndexTests.java | 1 - ml-algorithms/build.gradle | 6 +- .../ml/engine/MLEngineClassLoader.java | 21 +- .../ml/engine/annotation/Ingester.java | 17 ++ .../ml/engine/ingest/AbstractIngestion.java | 215 ++++++++++++++ .../ml/engine/ingest/Ingestable.java | 19 ++ .../ml/engine/ingest/OpenAIDataIngestion.java | 131 +++++++++ .../ml/engine/ingest/S3DataIngestion.java | 206 ++++++++++++++ .../MLSdkAsyncHttpResponseHandlerTest.java | 2 - .../engine/ingest/AbstractIngestionTests.java | 262 ++++++++++++++++++ .../engine/ingest/S3DataIngestionTests.java | 62 +++++ .../batch/TransportBatchIngestionAction.java | 178 ++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 10 +- .../ml/rest/RestMLBatchIngestAction.java | 65 +++++ .../ml/rest/RestMLPredictionAction.java | 2 - 
.../opensearch/ml/utils/RestActionUtils.java | 1 - .../plugin-metadata/plugin-security.policy | 6 + .../TransportBatchIngestionActionTests.java | 254 +++++++++++++++++ .../rest/RestMLBatchIngestionActionTests.java | 126 +++++++++ .../org/opensearch/ml/utils/TestHelper.java | 23 ++ 31 files changed, 2314 insertions(+), 12 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponse.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInputTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequestTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponseTests.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/annotation/Ingester.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/Ingestable.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/S3DataIngestionTests.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java create mode 100644 
plugin/src/main/java/org/opensearch/ml/rest/RestMLBatchIngestAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java index db2f67f369..e17b36a4dd 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java @@ -15,5 +15,6 @@ public enum MLTaskType { @Deprecated LOAD_MODEL, REGISTER_MODEL, - DEPLOY_MODEL + DEPLOY_MODEL, + BATCH_INGEST } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionAction.java b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionAction.java new file mode 100644 index 0000000000..3e0d39a692 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.batch; + +import org.opensearch.action.ActionType; + +public class MLBatchIngestionAction extends ActionType { + public static MLBatchIngestionAction INSTANCE = new MLBatchIngestionAction(); + public static final String NAME = "cluster:admin/opensearch/ml/batch_ingestion"; + + private MLBatchIngestionAction() { + super(NAME, MLBatchIngestionResponse::new); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java new file mode 100644 index 0000000000..e7050f0bd2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.batch; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; + +/** + * ML batch ingestion data: index, field mapping and input and out files. + */ +public class MLBatchIngestionInput implements ToXContentObject, Writeable { + + public static final String INDEX_NAME_FIELD = "index_name"; + public static final String FIELD_MAP_FIELD = "field_map"; + public static final String DATA_SOURCE_FIELD = "data_source"; + public static final String CONNECTOR_CREDENTIAL_FIELD = "credential"; + @Getter + private String indexName; + @Getter + private Map fieldMapping; + @Getter + private Map dataSources; + @Getter + private Map credential; + + @Builder(toBuilder = true) + public MLBatchIngestionInput( + String indexName, + Map fieldMapping, + Map dataSources, + Map credential + ) { + if (indexName == null) { + throw new IllegalArgumentException( + "The index name for data ingestion is missing. Please provide a valid index name to proceed." + ); + } + if (dataSources == null) { + throw new IllegalArgumentException( + "No data sources were provided for ingestion. Please specify at least one valid data source to proceed." 
+ ); + } + this.indexName = indexName; + this.fieldMapping = fieldMapping; + this.dataSources = dataSources; + this.credential = credential; + } + + public static MLBatchIngestionInput parse(XContentParser parser) throws IOException { + String indexName = null; + Map fieldMapping = null; + Map dataSources = null; + Map credential = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case INDEX_NAME_FIELD: + indexName = parser.text(); + break; + case FIELD_MAP_FIELD: + fieldMapping = parser.map(); + break; + case CONNECTOR_CREDENTIAL_FIELD: + credential = parser.mapStrings(); + break; + case DATA_SOURCE_FIELD: + dataSources = parser.map(); + break; + default: + parser.skipChildren(); + break; + } + } + return new MLBatchIngestionInput(indexName, fieldMapping, dataSources, credential); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (indexName != null) { + builder.field(INDEX_NAME_FIELD, indexName); + } + if (fieldMapping != null) { + builder.field(FIELD_MAP_FIELD, fieldMapping); + } + if (credential != null) { + builder.field(CONNECTOR_CREDENTIAL_FIELD, credential); + } + if (dataSources != null) { + builder.field(DATA_SOURCE_FIELD, dataSources); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeOptionalString(indexName); + if (fieldMapping != null) { + output.writeBoolean(true); + output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeGenericValue); + } else { + output.writeBoolean(false); + } + if (credential != null) { + output.writeBoolean(true); + output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString); + } else { + 
output.writeBoolean(false); + } + if (dataSources != null) { + output.writeBoolean(true); + output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue); + } else { + output.writeBoolean(false); + } + } + + public MLBatchIngestionInput(StreamInput input) throws IOException { + indexName = input.readOptionalString(); + if (input.readBoolean()) { + fieldMapping = input.readMap(s -> s.readString(), s -> s.readGenericValue()); + } + if (input.readBoolean()) { + credential = input.readMap(s -> s.readString(), s -> s.readString()); + } + if (input.readBoolean()) { + dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue()); + } + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequest.java new file mode 100644 index 0000000000..c6bf0de6d6 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequest.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.batch; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = 
AccessLevel.PRIVATE) +@ToString +public class MLBatchIngestionRequest extends ActionRequest { + + private MLBatchIngestionInput mlBatchIngestionInput; + + @Builder + public MLBatchIngestionRequest(MLBatchIngestionInput mlBatchIngestionInput) { + this.mlBatchIngestionInput = mlBatchIngestionInput; + } + + public MLBatchIngestionRequest(StreamInput in) throws IOException { + super(in); + this.mlBatchIngestionInput = new MLBatchIngestionInput(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + this.mlBatchIngestionInput.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (mlBatchIngestionInput == null) { + exception = addValidationError("The input for ML batch ingestion cannot be null.", exception); + } + if (mlBatchIngestionInput != null && mlBatchIngestionInput.getCredential() == null) { + exception = addValidationError("The credential for ML batch ingestion cannot be null", exception); + } + if (mlBatchIngestionInput != null && mlBatchIngestionInput.getDataSources() == null) { + exception = addValidationError("The data sources for ML batch ingestion cannot be null", exception); + } + + return exception; + } + + public static MLBatchIngestionRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLBatchIngestionRequest) { + return (MLBatchIngestionRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLBatchIngestionRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLBatchIngestionRequest", e); + } + + } +} diff --git 
a/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponse.java new file mode 100644 index 0000000000..42ae6857b9 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponse.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.batch; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLTaskType; + +import lombok.Getter; + +@Getter +public class MLBatchIngestionResponse extends ActionResponse implements ToXContentObject { + public static final String TASK_ID_FIELD = "task_id"; + public static final String TASK_TYPE_FIELD = "task_type"; + public static final String STATUS_FIELD = "status"; + + private String taskId; + private MLTaskType taskType; + private String status; + + public MLBatchIngestionResponse(StreamInput in) throws IOException { + super(in); + this.taskId = in.readString(); + this.taskType = in.readEnum(MLTaskType.class); + this.status = in.readString(); + } + + public MLBatchIngestionResponse(String taskId, MLTaskType mlTaskType, String status) { + this.taskId = taskId; + this.taskType = mlTaskType; + this.status = status; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { 
+ out.writeString(taskId); + out.writeEnum(taskType); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(TASK_ID_FIELD, taskId); + if (taskType != null) { + builder.field(TASK_TYPE_FIELD, taskType); + } + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } + + public static MLBatchIngestionResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLBatchIngestionResponse) { + return (MLBatchIngestionResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLBatchIngestionResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLBatchIngestionResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 57c24c22fd..4bf74de3a9 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -279,4 +279,18 @@ public static Map parseParameters(Map parameters return parameters; } + public static String obtainFieldNameFromJsonPath(String jsonPath) { + String[] parts = jsonPath.split("\\."); + + // Get the last part which is the field name + return parts[parts.length - 1]; + } + + public static String getJsonPath(String jsonPathWithSource) { + // Find the index of the first occurrence of "$." 
+ int startIndex = jsonPathWithSource.indexOf("$."); + + // Extract the substring from the startIndex to the end of the input string + return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource; + } } diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java index 5631071835..4b0947ba15 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java @@ -57,7 +57,6 @@ public void parseTextDocsMLInput() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); input.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - System.out.println(jsonStr); parseMLInput(jsonStr, 2); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInputTests.java new file mode 100644 index 0000000000..abfef0c6f9 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInputTests.java @@ -0,0 +1,138 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.batch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import 
org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class MLBatchIngestionInputTests { + + private MLBatchIngestionInput mlBatchIngestionInput; + + private Map dataSource; + + @Rule + public final ExpectedException exceptionRule = ExpectedException.none(); + + private final String expectedInputStr = "{" + + "\"index_name\":\"test index\"," + + "\"field_map\":{" + + "\"chapter\":\"chapter_embedding\"" + + "}," + + "\"credential\":{" + + "\"region\":\"test region\"" + + "}," + + "\"data_source\":{" + + "\"source\":[\"s3://samplebucket/output/sampleresults.json.out\"]," + + "\"type\":\"s3\"" + + "}" + + "}"; + + @Before + public void setUp() { + dataSource = new HashMap<>(); + dataSource.put("type", "s3"); + dataSource.put("source", Arrays.asList("s3://samplebucket/output/sampleresults.json.out")); + + Map credentials = Map.of("region", "test region"); + Map fieldMapping = Map.of("chapter", "chapter_embedding"); + + mlBatchIngestionInput = MLBatchIngestionInput + .builder() + .indexName("test index") + .credential(credentials) + .fieldMapping(fieldMapping) + .dataSources(dataSource) + .build(); + } + + @Test + public void constructorMLBatchIngestionInput_NullName() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The index name for data ingestion is missing. 
Please provide a valid index name to proceed."); + + MLBatchIngestionInput.builder().indexName(null).dataSources(dataSource).build(); + } + + @Test + public void constructorMLBatchIngestionInput_NullSource() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule + .expectMessage("No data sources were provided for ingestion. Please specify at least one valid data source to proceed."); + MLBatchIngestionInput.builder().indexName("test index").dataSources(null).build(); + } + + @Test + public void testToXContent_FullFields() throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder(); + mlBatchIngestionInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals(expectedInputStr, jsonStr); + } + + @Test + public void testParse() throws Exception { + testParseFromJsonString(expectedInputStr, parsedInput -> { + assertEquals("test index", parsedInput.getIndexName()); + assertEquals("test region", parsedInput.getCredential().get("region")); + assertEquals("chapter_embedding", parsedInput.getFieldMapping().get("chapter")); + assertEquals("s3", parsedInput.getDataSources().get("type")); + }); + } + + private void testParseFromJsonString(String expectedInputString, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputString + ); + parser.nextToken(); + MLBatchIngestionInput parsedInput = MLBatchIngestionInput.parse(parser); + verify.accept(parsedInput); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream( + mlBatchIngestionInput, + parsedInput -> assertEquals(mlBatchIngestionInput.getIndexName(), parsedInput.getIndexName()) + ); + } + + private void readInputStream(MLBatchIngestionInput input, Consumer verify) 
throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLBatchIngestionInput parsedInput = new MLBatchIngestionInput(streamInput); + verify.accept(parsedInput); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequestTests.java new file mode 100644 index 0000000000..ccbc0477c8 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequestTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.batch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLBatchIngestionRequestTests { + private MLBatchIngestionInput mlBatchIngestionInput; + private MLBatchIngestionRequest mlBatchIngestionRequest; + + @Before + public void setUp() { + mlBatchIngestionInput = MLBatchIngestionInput + .builder() + .indexName("test_index_name") + .credential(Map.of("region", "test region")) + .fieldMapping(Map.of("chapter", "chapter_embedding")) + .dataSources(Map.of("type", "s3")) + .build(); + mlBatchIngestionRequest = MLBatchIngestionRequest.builder().mlBatchIngestionInput(mlBatchIngestionInput).build(); + } + + @Test + public void writeToSuccess() 
throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + mlBatchIngestionRequest.writeTo(output); + MLBatchIngestionRequest parsedRequest = new MLBatchIngestionRequest(output.bytes().streamInput()); + assertEquals( + mlBatchIngestionRequest.getMlBatchIngestionInput().getIndexName(), + parsedRequest.getMlBatchIngestionInput().getIndexName() + ); + assertEquals( + mlBatchIngestionRequest.getMlBatchIngestionInput().getCredential(), + parsedRequest.getMlBatchIngestionInput().getCredential() + ); + assertEquals( + mlBatchIngestionRequest.getMlBatchIngestionInput().getFieldMapping(), + parsedRequest.getMlBatchIngestionInput().getFieldMapping() + ); + assertEquals( + mlBatchIngestionRequest.getMlBatchIngestionInput().getDataSources(), + parsedRequest.getMlBatchIngestionInput().getDataSources() + ); + } + + @Test + public void validateSuccess() { + assertNull(mlBatchIngestionRequest.validate()); + } + + @Test + public void validateWithNullInputException() { + MLBatchIngestionRequest mlBatchIngestionRequest1 = MLBatchIngestionRequest.builder().build(); + ActionRequestValidationException exception = mlBatchIngestionRequest1.validate(); + assertEquals("Validation Failed: 1: The input for ML batch ingestion cannot be null.;", exception.getMessage()); + } + + @Test + public void fromActionRequestWithBatchRequestSuccess() { + assertSame(MLBatchIngestionRequest.fromActionRequest(mlBatchIngestionRequest), mlBatchIngestionRequest); + } + + @Test + public void fromActionRequestWithNonRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlBatchIngestionRequest.writeTo(out); + } + }; + MLBatchIngestionRequest result = MLBatchIngestionRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlBatchIngestionRequest); + 
assertEquals(mlBatchIngestionRequest.getMlBatchIngestionInput().getIndexName(), result.getMlBatchIngestionInput().getIndexName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLBatchIngestionRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponseTests.java new file mode 100644 index 0000000000..b0b61f04e0 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponseTests.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.batch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.TestHelper; + +public class MLBatchIngestionResponseTests { + + MLBatchIngestionResponse mlBatchIngestionResponse; + + @Before + public void setUp() { + mlBatchIngestionResponse = new MLBatchIngestionResponse("testId", MLTaskType.BATCH_INGEST, 
"Created"); + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlBatchIngestionResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"task_id\":\"testId\",\"task_type\":\"BATCH_INGEST\",\"status\":\"Created\"}", content); + } + + @Test + public void readFromStream() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + mlBatchIngestionResponse.writeTo(output); + + MLBatchIngestionResponse response2 = new MLBatchIngestionResponse(output.bytes().streamInput()); + assertEquals("testId", response2.getTaskId()); + assertEquals("Created", response2.getStatus()); + } + + @Test + public void fromActionResponseWithMLBatchIngestionResponseSuccess() { + MLBatchIngestionResponse responseFromActionResponse = MLBatchIngestionResponse.fromActionResponse(mlBatchIngestionResponse); + assertSame(mlBatchIngestionResponse, responseFromActionResponse); + assertEquals(mlBatchIngestionResponse.getTaskType(), responseFromActionResponse.getTaskType()); + } + + @Test + public void fromActionResponseSuccess() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + mlBatchIngestionResponse.writeTo(out); + } + }; + MLBatchIngestionResponse responseFromActionResponse = MLBatchIngestionResponse.fromActionResponse(actionResponse); + assertNotSame(mlBatchIngestionResponse, responseFromActionResponse); + assertEquals(mlBatchIngestionResponse.getTaskType(), responseFromActionResponse.getTaskType()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponseIOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLBatchIngestionResponse.fromActionResponse(actionResponse); + 
} +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index a4b1460f39..aed76c5658 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -8,6 +8,8 @@ import static org.junit.Assert.assertEquals; import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME; import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes; +import static org.opensearch.ml.common.utils.StringUtils.getJsonPath; +import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath; import static org.opensearch.ml.common.utils.StringUtils.parseParameters; import static org.opensearch.ml.common.utils.StringUtils.toJson; @@ -423,4 +425,36 @@ public void testParseParametersNestedMapToString() { "{\"prompt\": \"answer question based on context: {\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}" ); } + + @Test + public void testObtainFieldNameFromJsonPath_ValidJsonPath() { + // Test with a typical JSONPath + String jsonPath = "$.response.body.data[*].embedding"; + String fieldName = obtainFieldNameFromJsonPath(jsonPath); + assertEquals("embedding", fieldName); + } + + @Test + public void testObtainFieldNameFromJsonPath_WithPrefix() { + // Test with JSONPath that has a prefix + String jsonPath = "source[1].$.response.body.data[*].embedding"; + String fieldName = obtainFieldNameFromJsonPath(jsonPath); + assertEquals("embedding", fieldName); + } + + @Test + public void testGetJsonPath_ValidJsonPathWithSource() { + // Test with a JSONPath that includes a source prefix + String input = "source[1].$.response.body.data[*].embedding"; + String result = getJsonPath(input); + assertEquals("$.response.body.data[*].embedding", 
result); + } + + @Test + public void testGetJsonPath_ValidJsonPathWithoutSource() { + // Test with a JSONPath that does not include a source prefix + String input = "$.response.body.data[*].embedding"; + String result = getJsonPath(input); + assertEquals("$.response.body.data[*].embedding", result); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 4da9f9d68e..042a4a3a91 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -750,7 +750,6 @@ public void testGetSg_NoIndex_ThenFail() { interactionsIndex.getInteraction("iid", getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); - System.out.println(argCaptor.getValue().getMessage()); assert (argCaptor .getValue() .getMessage() diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index e97849b019..5b7d146da7 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -7,6 +7,7 @@ import org.gradle.nativeplatform.platform.internal.DefaultNativePlatform plugins { id 'java' + id 'java-library' id 'jacoco' id "io.freefair.lombok" id 'com.diffplug.spotless' version '6.25.0' @@ -62,9 +63,12 @@ dependencies { } implementation platform('software.amazon.awssdk:bom:2.25.40') - implementation 'software.amazon.awssdk:auth' + api 'software.amazon.awssdk:auth:2.25.40' implementation 'software.amazon.awssdk:apache-client' implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' + implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.25.40' + implementation group: 'software.amazon.awssdk', name: 's3', version: '2.25.40' + implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.25.40' implementation 
'com.jayway.jsonpath:json-path:2.9.0' implementation group: 'org.json', name: 'json', version: '20231013' implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.25.40' diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngineClassLoader.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngineClassLoader.java index 9add9a4f9e..7205883e7f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngineClassLoader.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngineClassLoader.java @@ -20,6 +20,7 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.engine.annotation.ConnectorExecutor; import org.opensearch.ml.engine.annotation.Function; +import org.opensearch.ml.engine.annotation.Ingester; import org.reflections.Reflections; @SuppressWarnings("removal") @@ -31,6 +32,7 @@ public class MLEngineClassLoader { */ private static Map, Class> mlAlgoClassMap = new HashMap<>(); private static Map> connectorExecutorMap = new HashMap<>(); + private static Map> ingesterMap = new HashMap<>(); /** * This map contains pre-created thread-safe ML objects. 
@@ -41,6 +43,7 @@ public class MLEngineClassLoader { try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { loadClassMapping(); + loadIngestClassMapping(); return null; }); } catch (PrivilegedActionException e) { @@ -69,7 +72,7 @@ public static Object deregister(Enum functionName) { return mlObjects.remove(functionName); } - public static void loadClassMapping() { + private static void loadClassMapping() { Reflections reflections = new Reflections("org.opensearch.ml.engine.algorithms"); Set> classes = reflections.getTypesAnnotatedWith(Function.class); @@ -93,6 +96,19 @@ public static void loadClassMapping() { } } + private static void loadIngestClassMapping() { + Reflections reflections = new Reflections("org.opensearch.ml.engine.ingest"); + Set> ingesterClasses = reflections.getTypesAnnotatedWith(Ingester.class); + // Load ingester class + for (Class clazz : ingesterClasses) { + Ingester ingester = clazz.getAnnotation(Ingester.class); + String ingesterSource = ingester.value(); + if (ingesterSource != null) { + ingesterMap.put(ingesterSource, clazz); + } + } + } + @SuppressWarnings("unchecked") public static S initInstance(T type, I in, Class constructorParamClass) { return initInstance(type, in, constructorParamClass, null); @@ -120,6 +136,9 @@ public static S initInstance(T type, I in, Class con if (clazz == null) { clazz = connectorExecutorMap.get(type); } + if (clazz == null) { + clazz = ingesterMap.get(type); + } if (clazz == null) { throw new IllegalArgumentException("Can't find class for type " + type); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/annotation/Ingester.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/annotation/Ingester.java new file mode 100644 index 0000000000..6bacf7c76d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/annotation/Ingester.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.engine.ingest;

import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.common.utils.StringUtils;

import com.jayway.jsonpath.JsonPath;

import lombok.extern.log4j.Log4j2;

/**
 * Shared plumbing for offline batch ingestion. Concrete sources (S3, OpenAI files, ...)
 * stream JSON lines to {@link #batchIngest}; this class maps each line onto an index
 * document using the JsonPath field mapping carried by {@link MLBatchIngestionInput}
 * and bulk-writes the documents through the supplied client.
 */
@Log4j2
public class AbstractIngestion implements Ingestable {
    public static final String OUTPUT = "output";
    public static final String INPUT = "input";
    public static final String OUTPUT_FIELD_NAMES = "output_names";
    public static final String INPUT_FIELD_NAMES = "input_names";
    public static final String INGEST_FIELDS = "ingest_fields";
    public static final String ID_FIELD = "id_field";

    private final Client client;

    public AbstractIngestion(Client client) {
        this.client = client;
    }

    /**
     * Builds a listener that records the outcome of one bulk request in the given
     * counters and completes {@code future} so callers can join on all in-flight batches.
     *
     * @param successfulBatches counter incremented on a fully successful bulk response
     * @param failedBatches     counter incremented when the bulk response has failures or errors out
     * @param future            completed (possibly exceptionally) when this batch finishes
     */
    protected ActionListener<BulkResponse> getBulkResponseListener(
        AtomicInteger successfulBatches,
        AtomicInteger failedBatches,
        CompletableFuture<Void> future
    ) {
        return ActionListener.wrap(bulkResponse -> {
            if (bulkResponse.hasFailures()) {
                failedBatches.incrementAndGet();
                // Propagate the aggregated per-item failure message to whoever joins the future.
                future.completeExceptionally(new RuntimeException(bulkResponse.buildFailureMessage()));
                return;
            }
            log.debug("Batch ingested successfully");
            successfulBatches.incrementAndGet();
            future.complete(null);
        }, e -> {
            log.error("Failed to bulk ingest batch", e);
            failedBatches.incrementAndGet();
            future.completeExceptionally(e);
        });
    }

    /**
     * Reduces the per-source success rates to a single overall rate. The minimum is
     * reported so a partially failed source is not hidden by fully successful ones.
     *
     * @throws OpenSearchStatusException if no per-source rate was produced at all
     */
    protected double calculateSuccessRate(List<Double> successRates) {
        return successRates
            .stream()
            .min(Double::compare)
            .orElseThrow(
                () -> new OpenSearchStatusException(
                    "Failed to batch ingest data as not success rate is returned",
                    RestStatus.INTERNAL_SERVER_ERROR
                )
            );
    }

    /**
     * Filters fields in the map where the value contains the specified source index as a prefix.
     * Mapping values reference sources with a 1-based "source[n]" prefix, so the 0-based
     * {@code index} is shifted by one before matching.
     *
     * @param mlBatchIngestionInput The MLBatchIngestionInput holding the full field mapping.
     * @param index                 The 0-based source index to filter by.
     * @return A new map with only the entries (or list elements) that match the specified source.
     */
    protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIngestionInput, int index) {
        Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
        int indexInFieldMap = index + 1;
        String prefix = "source[" + indexInFieldMap + "]";

        Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
            Object value = entry.getValue();
            if (value instanceof String) {
                return ((String) value).contains(prefix);
            } else if (value instanceof List) {
                return ((List<String>) value).stream().anyMatch(val -> val.contains(prefix));
            }
            return false;
        }).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
            Object value = entry.getValue();
            if (value instanceof String) {
                return value;
            } else if (value instanceof List) {
                return ((List<String>) value).stream().filter(val -> val.contains(prefix)).collect(Collectors.toList());
            }
            return null;
        }));

        // Field-name lists carry no source prefix themselves; copy them over whenever the
        // corresponding input/output JsonPath survived the filter.
        if (filteredFieldMap.containsKey(OUTPUT)) {
            filteredFieldMap.put(OUTPUT_FIELD_NAMES, fieldMap.get(OUTPUT_FIELD_NAMES));
        }
        if (filteredFieldMap.containsKey(INPUT)) {
            filteredFieldMap.put(INPUT_FIELD_NAMES, fieldMap.get(INPUT_FIELD_NAMES));
        }
        return filteredFieldMap;
    }

    /**
     * Produce the source as a Map to be ingested in to OpenSearch.
     *
     * @param jsonStr      One line of source data, a JSON document.
     * @param fieldMapping The field mapping that includes all the field name and JsonPath for the data.
     * @return A new map that contains all the fields and data for ingestion (with "_id" set when configured).
     * @throws IllegalArgumentException if the id mapping does not contain exactly one JsonPath
     */
    protected Map<String, Object> processFieldMapping(String jsonStr, Map<String, Object> fieldMapping) {
        String inputJsonPath = fieldMapping.containsKey(INPUT) ? getJsonPath((String) fieldMapping.get(INPUT)) : null;
        List<Object> remoteModelInput = inputJsonPath != null ? JsonPath.read(jsonStr, inputJsonPath) : null;
        List<String> inputFieldNames = inputJsonPath != null ? (List<String>) fieldMapping.get(INPUT_FIELD_NAMES) : null;

        String outputJsonPath = fieldMapping.containsKey(OUTPUT) ? getJsonPath((String) fieldMapping.get(OUTPUT)) : null;
        List<Object> remoteModelOutput = outputJsonPath != null ? JsonPath.read(jsonStr, outputJsonPath) : null;
        List<String> outputFieldNames = outputJsonPath != null ? (List<String>) fieldMapping.get(OUTPUT_FIELD_NAMES) : null;

        List<String> ingestFieldsJsonPath = Optional
            .ofNullable((List<String>) fieldMapping.get(INGEST_FIELDS))
            .stream()
            .flatMap(Collection::stream)
            .map(StringUtils::getJsonPath)
            .collect(Collectors.toList());

        Map<String, Object> jsonMap = new HashMap<>();

        populateJsonMap(jsonMap, inputFieldNames, remoteModelInput);
        populateJsonMap(jsonMap, outputFieldNames, remoteModelOutput);

        for (String fieldPath : ingestFieldsJsonPath) {
            jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
        }

        if (fieldMapping.containsKey(ID_FIELD)) {
            List<String> docIdJsonPath = Optional
                .ofNullable((List<String>) fieldMapping.get(ID_FIELD))
                .stream()
                .flatMap(Collection::stream)
                .map(StringUtils::getJsonPath)
                .collect(Collectors.toList());
            if (docIdJsonPath.size() != 1) {
                throw new IllegalArgumentException("The Id field must contains only 1 jsonPath for each source");
            }
            jsonMap.put("_id", JsonPath.read(jsonStr, docIdJsonPath.get(0)));
        }
        return jsonMap;
    }

    /**
     * Maps each JSON line to a document and issues a single bulk request. The first source
     * (or a sole source) creates documents; later sources upsert into the partially
     * ingested documents and therefore require an id field to match on.
     *
     * @param sourceLines          JSON documents, one per line
     * @param mlBatchIngestionInput the ingestion input (index name, field mapping)
     * @param bulkResponseListener  receives the outcome of the bulk request
     * @param sourceIndex           0-based index of the source these lines came from
     * @param isSoleSource          true when only one source is configured
     */
    protected void batchIngest(
        List<String> sourceLines,
        MLBatchIngestionInput mlBatchIngestionInput,
        ActionListener<BulkResponse> bulkResponseListener,
        int sourceIndex,
        boolean isSoleSource
    ) {
        BulkRequest bulkRequest = new BulkRequest();
        sourceLines.stream().forEach(jsonStr -> {
            Map<String, Object> filteredMapping = isSoleSource
                ? mlBatchIngestionInput.getFieldMapping()
                : filterFieldMapping(mlBatchIngestionInput, sourceIndex);
            Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
            if (isSoleSource || sourceIndex == 0) {
                IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName());
                if (jsonMap.containsKey("_id")) {
                    // String.valueOf: the id JsonPath may resolve to a number; a plain
                    // (String) cast would throw ClassCastException for numeric ids.
                    String id = String.valueOf(jsonMap.remove("_id"));
                    indexRequest.id(id);
                }
                indexRequest.source(jsonMap);
                bulkRequest.add(indexRequest);
            } else {
                // Bulk update docs as they were partially ingested by an earlier source.
                if (!jsonMap.containsKey("_id")) {
                    throw new IllegalArgumentException("The id filed must be provided to match documents for multiple sources");
                }
                String id = String.valueOf(jsonMap.remove("_id"));
                UpdateRequest updateRequest = new UpdateRequest(mlBatchIngestionInput.getIndexName(), id).doc(jsonMap).upsert(jsonMap);
                bulkRequest.add(updateRequest);
            }
        });
        client.bulk(bulkRequest, bulkResponseListener);
    }

    /**
     * Copies model values into the document under their configured field names.
     *
     * @throws IllegalArgumentException if no field names were configured for the data,
     *                                  or the number of names and values differ
     */
    private void populateJsonMap(Map<String, Object> jsonMap, List<String> fieldNames, List<?> modelData) {
        if (modelData != null) {
            // Null-check first: a mapping that supplies input/output data without the matching
            // name list previously failed with an NPE instead of a clear error.
            if (fieldNames == null || modelData.size() != fieldNames.size()) {
                throw new IllegalArgumentException("The fieldMapping and source data do not match");
            }

            for (int index = 0; index < modelData.size(); index++) {
                jsonMap.put(fieldNames.get(index), modelData.get(index));
            }
        }
    }
}
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.engine.ingest;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.engine.annotation.Ingester;

import lombok.extern.log4j.Log4j2;

/**
 * Ingests data from the OpenAI files API. Each configured source is an OpenAI file id
 * whose content is streamed line by line (one JSON document per line) and bulk-ingested
 * in batches of {@value #BATCH_SIZE} lines.
 */
@Log4j2
@Ingester("openai")
public class OpenAIDataIngestion extends AbstractIngestion {
    private static final String API_KEY = "openAI_key";
    private static final String API_URL = "https://api.openai.com/v1/files/";
    public static final String SOURCE = "source";
    // Number of JSON lines grouped into one bulk request.
    private static final int BATCH_SIZE = 100;

    public OpenAIDataIngestion(Client client) {
        super(client);
    }

    /**
     * Ingests every configured OpenAI file id; an empty/absent source list counts as full success.
     *
     * @return overall success rate (0 - 100), the minimum across sources
     */
    @Override
    public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
        List<String> sources = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
        if (Objects.isNull(sources) || sources.isEmpty()) {
            return 100;
        }

        boolean isSoleSource = sources.size() == 1;
        List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
        for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
            successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
        }

        return calculateSuccessRate(successRates);
    }

    /**
     * Streams one OpenAI file and returns the percentage of batches that ingested successfully.
     */
    private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, boolean isSoleSource) {
        double successRate = 0;
        HttpURLConnection connection = null;
        try {
            String apiKey = mlBatchIngestionInput.getCredential().get(API_KEY);
            URL url = new URL(API_URL + fileId + "/content");

            connection = (HttpURLConnection) url.openConnection();
            connection.setRequestMethod("GET");
            connection.setRequestProperty("Authorization", "Bearer " + apiKey);

            final HttpURLConnection conn = connection;
            try (
                // Opening the stream performs network I/O, hence the privileged block.
                // UTF-8 is specified explicitly instead of relying on the platform charset.
                InputStreamReader inputStreamReader = AccessController
                    .doPrivileged(
                        (PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(
                            conn.getInputStream(),
                            StandardCharsets.UTF_8
                        )
                    );
                BufferedReader reader = new BufferedReader(inputStreamReader)
            ) {
                List<String> linesBuffer = new ArrayList<>();
                String line;
                int lineCount = 0;
                // Atomic counters for tracking success and failure across async bulk calls.
                AtomicInteger successfulBatches = new AtomicInteger(0);
                AtomicInteger failedBatches = new AtomicInteger(0);
                // One future per dispatched batch so we can wait for all of them.
                List<CompletableFuture<Void>> futures = new ArrayList<>();

                while ((line = reader.readLine()) != null) {
                    linesBuffer.add(line);
                    lineCount++;

                    if (lineCount % BATCH_SIZE == 0) {
                        dispatchBatch(linesBuffer, mlBatchIngestionInput, sourceIndex, isSoleSource, successfulBatches, failedBatches, futures);
                    }
                }
                // Flush any remaining lines in the buffer.
                if (!linesBuffer.isEmpty()) {
                    dispatchBatch(linesBuffer, mlBatchIngestionInput, sourceIndex, isSoleSource, successfulBatches, failedBatches, futures);
                }

                // Wait for every batch to complete before computing the rate.
                CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
                int totalBatches = successfulBatches.get() + failedBatches.get();
                successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100;
            }
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to read from OpenAI file API: ", e);
        } catch (Exception e) {
            log.error(e.getMessage());
            throw new OpenSearchStatusException("Failed to batch ingest: " + e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR);
        } finally {
            if (connection != null) {
                // Previously the connection was never released.
                connection.disconnect();
            }
        }

        return successRate;
    }

    /**
     * Sends the buffered lines as one bulk request, tracks its future, and clears the buffer.
     */
    private void dispatchBatch(
        List<String> linesBuffer,
        MLBatchIngestionInput mlBatchIngestionInput,
        int sourceIndex,
        boolean isSoleSource,
        AtomicInteger successfulBatches,
        AtomicInteger failedBatches,
        List<CompletableFuture<Void>> futures
    ) {
        CompletableFuture<Void> future = new CompletableFuture<>();
        // Copy the buffer: batchIngest holds the list while the async bulk call is in flight.
        batchIngest(
            new ArrayList<>(linesBuffer),
            mlBatchIngestionInput,
            getBulkResponseListener(successfulBatches, failedBatches, future),
            sourceIndex,
            isSoleSource
        );
        futures.add(future);
        linesBuffer.clear();
    }
}
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.engine.ingest;

import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.engine.annotation.Ingester;

import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.S3Exception;

/**
 * Ingests data from S3. Each configured source is an "s3://bucket/key" URI whose object
 * is streamed line by line (one JSON document per line) and bulk-ingested in batches of
 * {@value #BATCH_SIZE} lines.
 */
@Log4j2
@Ingester("s3")
public class S3DataIngestion extends AbstractIngestion {
    public static final String SOURCE = "source";
    private static final String S3_URI_PREFIX = "s3://";
    // Number of JSON lines grouped into one bulk request.
    private static final int BATCH_SIZE = 100;

    public S3DataIngestion(Client client) {
        super(client);
    }

    /**
     * Ingests every configured S3 URI; an empty/absent source list counts as full success.
     *
     * @return overall success rate (0 - 100), the minimum across sources
     */
    @Override
    public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
        S3Client s3 = initS3Client(mlBatchIngestionInput);
        try {
            List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
            if (Objects.isNull(s3Uris) || s3Uris.isEmpty()) {
                return 100;
            }
            boolean isSoleSource = s3Uris.size() == 1;
            List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
            for (int sourceIndex = 0; sourceIndex < s3Uris.size(); sourceIndex++) {
                successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
            }

            return calculateSuccessRate(successRates);
        } finally {
            // Close the client once, after all sources. Closing it inside
            // ingestSingleSource would break every source after the first
            // when multiple S3 URIs are configured.
            s3.close();
        }
    }

    /**
     * Streams one S3 object and returns the percentage of batches that ingested successfully.
     * The caller owns the {@code s3} client and is responsible for closing it.
     */
    public double ingestSingleSource(
        S3Client s3,
        String s3Uri,
        MLBatchIngestionInput mlBatchIngestionInput,
        int sourceIndex,
        boolean isSoleSource
    ) {
        String bucketName = getS3BucketName(s3Uri);
        String keyName = getS3KeyName(s3Uri);
        GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
        double successRate = 0;

        try (
            // getObject performs network I/O, hence the privileged block.
            ResponseInputStream<GetObjectResponse> s3is = AccessController
                .doPrivileged((PrivilegedExceptionAction<ResponseInputStream<GetObjectResponse>>) () -> s3.getObject(getObjectRequest));
            BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8))
        ) {
            List<String> linesBuffer = new ArrayList<>();
            String line;
            int lineCount = 0;
            // Atomic counters for tracking success and failure across async bulk calls.
            AtomicInteger successfulBatches = new AtomicInteger(0);
            AtomicInteger failedBatches = new AtomicInteger(0);
            // One future per dispatched batch so we can wait for all of them.
            List<CompletableFuture<Void>> futures = new ArrayList<>();

            while ((line = reader.readLine()) != null) {
                linesBuffer.add(line);
                lineCount++;

                if (lineCount % BATCH_SIZE == 0) {
                    dispatchBatch(linesBuffer, mlBatchIngestionInput, sourceIndex, isSoleSource, successfulBatches, failedBatches, futures);
                }
            }
            // Flush any remaining lines in the buffer.
            if (!linesBuffer.isEmpty()) {
                dispatchBatch(linesBuffer, mlBatchIngestionInput, sourceIndex, isSoleSource, successfulBatches, failedBatches, futures);
            }

            // Wait for every batch to complete before computing the rate.
            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();

            int totalBatches = successfulBatches.get() + failedBatches.get();
            successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100;
        } catch (S3Exception e) {
            log.error("Error reading from S3: " + e.awsErrorDetails().errorMessage());
            throw e;
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to get S3 Object: ", e);
        } catch (Exception e) {
            log.error(e.getMessage());
            throw new OpenSearchStatusException("Failed to batch ingest: " + e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR);
        }

        return successRate;
    }

    /**
     * Sends the buffered lines as one bulk request, tracks its future, and clears the buffer.
     */
    private void dispatchBatch(
        List<String> linesBuffer,
        MLBatchIngestionInput mlBatchIngestionInput,
        int sourceIndex,
        boolean isSoleSource,
        AtomicInteger successfulBatches,
        AtomicInteger failedBatches,
        List<CompletableFuture<Void>> futures
    ) {
        CompletableFuture<Void> future = new CompletableFuture<>();
        // Copy the buffer: batchIngest holds the list while the async bulk call is in flight.
        batchIngest(
            new ArrayList<>(linesBuffer),
            mlBatchIngestionInput,
            getBulkResponseListener(successfulBatches, failedBatches, future),
            sourceIndex,
            isSoleSource
        );
        futures.add(future);
        linesBuffer.clear();
    }

    /** Returns the bucket portion of an "s3://bucket/key" URI. */
    private String getS3BucketName(String s3Uri) {
        String uriWithoutPrefix = stripS3Prefix(s3Uri);
        int slashIndex = uriWithoutPrefix.indexOf('/');
        // No slash means the whole remainder is the bucket name.
        return slashIndex == -1 ? uriWithoutPrefix : uriWithoutPrefix.substring(0, slashIndex);
    }

    /** Returns the key portion of an "s3://bucket/key" URI, or "" when there is no key. */
    private String getS3KeyName(String s3Uri) {
        String uriWithoutPrefix = stripS3Prefix(s3Uri);
        int slashIndex = uriWithoutPrefix.indexOf('/');
        return slashIndex == -1 ? "" : uriWithoutPrefix.substring(slashIndex + 1);
    }

    /** Drops the leading "s3://" scheme; the URI is assumed to carry it — TODO confirm callers validate. */
    private String stripS3Prefix(String s3Uri) {
        return s3Uri.substring(S3_URI_PREFIX.length());
    }

    /**
     * Builds an S3 client from the credentials carried by the input: session credentials
     * when a session token is supplied, otherwise long-term basic credentials.
     */
    @VisibleForTesting
    public S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) {
        String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
        String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
        String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
        String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);

        AwsCredentials credentials = sessionToken == null
            ? AwsBasicCredentials.create(accessKey, secretKey)
            : AwsSessionCredentials.create(accessKey, secretKey, sessionToken);

        try {
            // Client construction reads system properties/files, hence the privileged block.
            return AccessController
                .doPrivileged(
                    (PrivilegedExceptionAction<S3Client>) () -> S3Client
                        .builder()
                        .region(Region.of(region))
                        .credentialsProvider(StaticCredentialsProvider.create(credentials))
                        .build()
                );
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load credentials", e);
        }
    }
}
System.out.println(captor.getValue().getMessage()); assert captor.getValue().getMessage().contains(REMOTE_SERVICE_ERROR); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java new file mode 100644 index 0000000000..d2f66dacbc --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java @@ -0,0 +1,262 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.ingest; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.ID_FIELD; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.INGEST_FIELDS; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.INPUT; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.INPUT_FIELD_NAMES; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.OUTPUT; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.OUTPUT_FIELD_NAMES; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import 
org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; + +public class AbstractIngestionTests { + @Mock + Client client; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + S3DataIngestion s3DataIngestion = new S3DataIngestion(client); + + Map fieldMap; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + s3DataIngestion = new S3DataIngestion(client); + + fieldMap = new HashMap<>(); + fieldMap.put(INPUT, "source[1].$.content"); + fieldMap.put(OUTPUT, "source[1].$.SageMakerOutput"); + fieldMap.put(INPUT_FIELD_NAMES, Arrays.asList("chapter", "title")); + fieldMap.put(OUTPUT_FIELD_NAMES, Arrays.asList("chapter_embedding", "title_embedding")); + fieldMap.put(INGEST_FIELDS, Arrays.asList("source[1].$.id")); + } + + @Test + public void testBulkResponseListener_Success() { + // Arrange + AtomicInteger successfulBatches = new AtomicInteger(0); + AtomicInteger failedBatches = new AtomicInteger(0); + CompletableFuture future = new CompletableFuture<>(); + + // Mock BulkResponse + BulkResponse bulkResponse = mock(BulkResponse.class); + when(bulkResponse.hasFailures()).thenReturn(false); + + S3DataIngestion instance = new S3DataIngestion(client); + + // Act + ActionListener listener = instance.getBulkResponseListener(successfulBatches, failedBatches, future); + listener.onResponse(bulkResponse); + + // Assert + assertFalse(future.isCompletedExceptionally()); + assertEquals(1, successfulBatches.get()); + assertEquals(0, failedBatches.get()); + } + + @Test + public void testBulkResponseListener_Failure() { + // Arrange + AtomicInteger successfulBatches = new AtomicInteger(0); + AtomicInteger failedBatches = new AtomicInteger(0); + CompletableFuture future = new CompletableFuture<>(); + + // Mock BulkResponse + BulkResponse 
bulkResponse = mock(BulkResponse.class); + when(bulkResponse.hasFailures()).thenReturn(true); + when(bulkResponse.buildFailureMessage()).thenReturn("Failure message"); + + S3DataIngestion instance = new S3DataIngestion(client); + + // Act + ActionListener listener = instance.getBulkResponseListener(successfulBatches, failedBatches, future); + listener.onResponse(bulkResponse); + + // Assert + assertTrue(future.isCompletedExceptionally()); + assertEquals(0, successfulBatches.get()); + assertEquals(1, failedBatches.get()); + } + + @Test + public void testBulkResponseListener_Exception() { + // Arrange + AtomicInteger successfulBatches = new AtomicInteger(0); + AtomicInteger failedBatches = new AtomicInteger(0); + CompletableFuture future = new CompletableFuture<>(); + + // Create an exception + RuntimeException exception = new RuntimeException("Test exception"); + + S3DataIngestion instance = new S3DataIngestion(client); + + // Act + ActionListener listener = instance.getBulkResponseListener(successfulBatches, failedBatches, future); + listener.onFailure(exception); + + // Assert + assertTrue(future.isCompletedExceptionally()); + assertEquals(0, successfulBatches.get()); + assertEquals(1, failedBatches.get()); + assertThrows(Exception.class, () -> future.join()); // Ensure that future throws exception + } + + @Test + public void testCalculateSuccessRate_MultipleValues() { + // Arrange + List successRates = Arrays.asList(90.0, 85.5, 92.0, 88.0); + + // Act + double result = s3DataIngestion.calculateSuccessRate(successRates); + + // Assert + assertEquals(85.5, result, 0.0001); + } + + @Test + public void testCalculateSuccessRate_SingleValue() { + // Arrange + List successRates = Collections.singletonList(99.9); + + // Act + double result = s3DataIngestion.calculateSuccessRate(successRates); + + // Assert + assertEquals(99.9, result, 0.0001); + } + + @Test + public void testFilterFieldMapping_ValidInput_MatchingPrefix() { + // Arrange + MLBatchIngestionInput 
mlBatchIngestionInput = new MLBatchIngestionInput("indexName", fieldMap, new HashMap<>(), new HashMap<>()); + Map result = s3DataIngestion.filterFieldMapping(mlBatchIngestionInput, 0); + + // Assert + assertEquals(5, result.size()); + assertEquals("source[1].$.content", result.get(INPUT)); + assertEquals("source[1].$.SageMakerOutput", result.get(OUTPUT)); + assertEquals(Arrays.asList("chapter", "title"), result.get(INPUT_FIELD_NAMES)); + assertEquals(Arrays.asList("chapter_embedding", "title_embedding"), result.get(OUTPUT_FIELD_NAMES)); + assertEquals(Arrays.asList("source[1].$.id"), result.get(INGEST_FIELDS)); + } + + @Test + public void testFilterFieldMapping_NoMatchingPrefix() { + // Arrange + Map fieldMap = new HashMap<>(); + fieldMap.put("field1", "source[3].$.response.body.data[*].embedding"); + fieldMap.put("field2", "source[4].$.body.input"); + + MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput("indexName", fieldMap, new HashMap<>(), new HashMap<>()); + + // Act + Map result = s3DataIngestion.filterFieldMapping(mlBatchIngestionInput, 0); + + // Assert + assertTrue(result.isEmpty()); + } + + @Test + public void testProcessFieldMapping_ValidInput() { + String jsonStr = + "{\"SageMakerOutput\":[[-0.017166402, 0.055771016],[-0.004301484,-0.042826906]],\"content\":[\"this is chapter 1\",\"harry potter\"],\"id\":1}"; + // Arrange + + // Act + Map processedFieldMapping = s3DataIngestion.processFieldMapping(jsonStr, fieldMap); + + // Assert + assertEquals("this is chapter 1", processedFieldMapping.get("chapter")); + assertEquals("harry potter", processedFieldMapping.get("title")); + assertEquals(1, processedFieldMapping.get("id")); + } + + @Test + public void testProcessFieldMapping_NoIdFieldInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The Id field must contains only 1 jsonPath for each source"); + + String jsonStr = + "{\"SageMakerOutput\":[[-0.017166402, 
0.055771016],[-0.004301484,-0.042826906]],\"content\":[\"this is chapter 1\",\"harry potter\"],\"id\":1}"; + // Arrange + fieldMap.put(ID_FIELD, null); + + // Act + s3DataIngestion.processFieldMapping(jsonStr, fieldMap); + } + + @Test + public void testBatchIngestSuccess_SoleSource() { + doAnswer(invocation -> { + ActionListener bulkResponseListener = invocation.getArgument(1); + bulkResponseListener.onResponse(mock(BulkResponse.class)); + return null; + }).when(client).bulk(any(), any()); + + List sourceLines = Arrays + .asList( + "{\"SageMakerOutput\":[[-0.017166402, 0.055771016],[-0.004301484,-0.042826906]],\"content\":[\"this is chapter 1\",\"harry potter\"],\"id\":1}" + ); + MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput("indexName", fieldMap, new HashMap<>(), new HashMap<>()); + ActionListener bulkResponseListener = mock(ActionListener.class); + s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, true); + + verify(client).bulk(isA(BulkRequest.class), isA(ActionListener.class)); + verify(bulkResponseListener).onResponse(isA(BulkResponse.class)); + } + + @Test + public void testBatchIngestSuccess_NoIdError() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The id filed must be provided to match documents for multiple sources"); + + doAnswer(invocation -> { + ActionListener bulkResponseListener = invocation.getArgument(1); + bulkResponseListener.onResponse(mock(BulkResponse.class)); + return null; + }).when(client).bulk(any(), any()); + + List sourceLines = Arrays + .asList( + "{\"SageMakerOutput\":[[-0.017166402, 0.055771016],[-0.004301484,-0.042826906]],\"content\":[\"this is chapter 1\",\"harry potter\"],\"id\":1}" + ); + MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput("indexName", fieldMap, new HashMap<>(), new HashMap<>()); + ActionListener bulkResponseListener = mock(ActionListener.class); + s3DataIngestion.batchIngest(sourceLines, 
mlBatchIngestionInput, bulkResponseListener, 1, false); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/S3DataIngestionTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/S3DataIngestionTests.java new file mode 100644 index 0000000000..4bf2ffd58f --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/S3DataIngestionTests.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.ingest; + +import static org.opensearch.ml.engine.ingest.AbstractIngestion.INGEST_FIELDS; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.INPUT_FIELD_NAMES; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.OUTPUT_FIELD_NAMES; +import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; + +import software.amazon.awssdk.services.s3.S3Client; + +public class S3DataIngestionTests { + + private MLBatchIngestionInput mlBatchIngestionInput; + private S3DataIngestion s3DataIngestion; + + @Mock + Client client; + + @Mock + S3Client s3Client; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + s3DataIngestion = new S3DataIngestion(client); + + Map fieldMap = new HashMap<>(); + fieldMap.put("input", "$.content"); + fieldMap.put("output", "$.SageMakerOutput"); + fieldMap.put(INPUT_FIELD_NAMES, Arrays.asList("chapter", "title")); + fieldMap.put(OUTPUT_FIELD_NAMES, Arrays.asList("chapter_embedding", "title_embedding")); + fieldMap.put(INGEST_FIELDS, Arrays.asList("$.id")); + + Map credential = Map + .of("region", "us-east-1", "access_key", "some accesskey", "secret_key", "some secret", 
"session_token", "some token"); + Map dataSource = new HashMap<>(); + dataSource.put("type", "s3"); + dataSource.put(SOURCE, Arrays.asList("s3://offlinebatch/output/sagemaker_djl_batch_input.json.out")); + + mlBatchIngestionInput = MLBatchIngestionInput + .builder() + .indexName("testIndex") + .fieldMapping(fieldMap) + .credential(credential) + .dataSources(dataSource) + .build(); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java new file mode 100644 index 0000000000..cf03d0f11a --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java @@ -0,0 +1,178 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.batch; + +import static org.opensearch.ml.common.MLTask.ERROR_FIELD; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTaskState.COMPLETED; +import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL; +import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import 
org.opensearch.ml.common.transport.batch.MLBatchIngestionAction; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.ingest.Ingestable; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.ml.utils.MLExceptionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class TransportBatchIngestionAction extends HandledTransportAction { + private static final String S3_URI_REGEX = "^s3://([a-zA-Z0-9.-]+)(/.*)?$"; + private static final Pattern S3_URI_PATTERN = Pattern.compile(S3_URI_REGEX); + public static final String TYPE = "type"; + public static final String SOURCE = "source"; + TransportService transportService; + MLTaskManager mlTaskManager; + private final Client client; + private ThreadPool threadPool; + + @Inject + public TransportBatchIngestionAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + MLTaskManager mlTaskManager, + ThreadPool threadPool + ) { + super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new); + this.transportService = transportService; + this.client = client; + this.mlTaskManager = mlTaskManager; + this.threadPool = threadPool; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + MLBatchIngestionRequest mlBatchIngestionRequest = MLBatchIngestionRequest.fromActionRequest(request); + MLBatchIngestionInput mlBatchIngestionInput = mlBatchIngestionRequest.getMlBatchIngestionInput(); + try { + validateBatchIngestInput(mlBatchIngestionInput); + MLTask mlTask = MLTask + .builder() + .async(true) + 
.taskType(MLTaskType.BATCH_INGEST) + .createTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .state(MLTaskState.CREATED) + .build(); + + mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { + String taskId = response.getId(); + try { + mlTask.setTaskId(taskId); + mlTaskManager.add(mlTask); + listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name())); + String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE); + Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class); + threadPool.executor(TRAIN_THREAD_POOL).execute(() -> { + double successRate = ingestable.ingest(mlBatchIngestionInput); + handleSuccessRate(successRate, taskId); + }); + } catch (Exception ex) { + log.error("Failed in batch ingestion", ex); + mlTaskManager + .updateMLTask( + taskId, + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), + TASK_SEMAPHORE_TIMEOUT, + true + ); + listener.onFailure(ex); + } + }, exception -> { + log.error("Failed to create batch ingestion task", exception); + listener.onFailure(exception); + })); + } catch (IllegalArgumentException e) { + log.error(e.getMessage()); + listener + .onFailure( + new OpenSearchStatusException( + "IllegalArgumentException in the batch ingestion input: " + e.getMessage(), + RestStatus.BAD_REQUEST + ) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void handleSuccessRate(double successRate, String taskId) { + if (successRate == 100) { + mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, COMPLETED), 5000, true); + } else if (successRate > 0) { + mlTaskManager + .updateMLTask( + taskId, + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "batch ingestion successful rate is " + successRate), + TASK_SEMAPHORE_TIMEOUT, + true + ); + } else { + mlTaskManager + .updateMLTask( + taskId, + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "batch ingestion successful 
rate is 0"), + TASK_SEMAPHORE_TIMEOUT, + true + ); + } + } + + private void validateBatchIngestInput(MLBatchIngestionInput mlBatchIngestionInput) { + if (mlBatchIngestionInput == null + || mlBatchIngestionInput.getDataSources() == null + || mlBatchIngestionInput.getDataSources().isEmpty()) { + throw new IllegalArgumentException("The batch ingest input data source cannot be null"); + } + Map dataSources = mlBatchIngestionInput.getDataSources(); + if (dataSources.get(TYPE) == null || dataSources.get(SOURCE) == null) { + throw new IllegalArgumentException("The batch ingest input data source is missing data type or source"); + } + if (((String) dataSources.get(TYPE)).toLowerCase() == "s3") { + List s3Uris = (List) dataSources.get(SOURCE); + if (s3Uris == null || s3Uris.isEmpty()) { + throw new IllegalArgumentException("The batch ingest input s3Uris is empty"); + } + + // Partition the list into valid and invalid URIs + Map> partitionedUris = s3Uris + .stream() + .collect(Collectors.partitioningBy(uri -> S3_URI_PATTERN.matcher(uri).matches())); + + List invalidUris = partitionedUris.get(false); + + if (!invalidUris.isEmpty()) { + throw new IllegalArgumentException("The following batch ingest input S3 URIs are invalid: " + invalidUris); + } + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 7e9ab6d940..b4abd328c7 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -48,6 +48,7 @@ import org.opensearch.ml.action.agents.GetAgentTransportAction; import org.opensearch.ml.action.agents.TransportRegisterAgentAction; import org.opensearch.ml.action.agents.TransportSearchAgentAction; +import org.opensearch.ml.action.batch.TransportBatchIngestionAction; import org.opensearch.ml.action.config.GetConfigTransportAction; import 
org.opensearch.ml.action.connector.DeleteConnectorTransportAction; import org.opensearch.ml.action.connector.ExecuteConnectorTransportAction; @@ -120,6 +121,7 @@ import org.opensearch.ml.common.transport.agent.MLAgentGetAction; import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; import org.opensearch.ml.common.transport.agent.MLSearchAgentAction; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionAction; import org.opensearch.ml.common.transport.config.MLConfigGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; @@ -216,6 +218,7 @@ import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; +import org.opensearch.ml.rest.RestMLBatchIngestAction; import org.opensearch.ml.rest.RestMLCreateConnectorAction; import org.opensearch.ml.rest.RestMLCreateControllerAction; import org.opensearch.ml.rest.RestMLDeleteAgentAction; @@ -440,7 +443,8 @@ public MachineLearningPlugin(Settings settings) { new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class), new ActionHandler<>(MLListToolsAction.INSTANCE, ListToolsTransportAction.class), new ActionHandler<>(MLGetToolAction.INSTANCE, GetToolTransportAction.class), - new ActionHandler<>(MLConfigGetAction.INSTANCE, GetConfigTransportAction.class) + new ActionHandler<>(MLConfigGetAction.INSTANCE, GetConfigTransportAction.class), + new ActionHandler<>(MLBatchIngestionAction.INSTANCE, TransportBatchIngestionAction.class) ); } @@ -759,6 +763,7 @@ public List getRestHandlers( RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories); RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories); RestMLGetConfigAction restMLGetConfigAction = new RestMLGetConfigAction(); + 
RestMLBatchIngestAction restMLBatchIngestAction = new RestMLBatchIngestAction(); return ImmutableList .of( restMLStatsAction, @@ -811,7 +816,8 @@ public List getRestHandlers( restMLSearchAgentAction, restMLListToolsAction, restMLGetToolAction, - restMLGetConfigAction + restMLGetConfigAction, + restMLBatchIngestAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLBatchIngestAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLBatchIngestAction.java new file mode 100644 index 0000000000..cdf2380985 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLBatchIngestAction.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionAction; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class RestMLBatchIngestAction extends BaseRestHandler { + private static final String ML_BATCH_INGESTION_ACTION = "ml_batch_ingestion_action"; + + @Override + public String getName() { + return ML_BATCH_INGESTION_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, 
"%s/_batch_ingestion", ML_BASE_URI))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLBatchIngestionRequest mlBatchIngestTaskRequest = getRequest(request); + return channel -> client.execute(MLBatchIngestionAction.INSTANCE, mlBatchIngestTaskRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLBatchIngestTaskRequest from a RestRequest + * + * @param request RestRequest + * @return MLBatchIngestTaskRequest + */ + @VisibleForTesting + MLBatchIngestionRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new IOException("Batch Ingestion request has empty body"); + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLBatchIngestionInput mlBatchIngestionInput = MLBatchIngestionInput.parse(parser); + return new MLBatchIngestionRequest(mlBatchIngestionInput); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index 82c72e11a2..72b841eb7b 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -127,13 +127,11 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client @VisibleForTesting MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException { ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request)); - System.out.println("actionType is " + actionType); if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && 
!mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } else if (!ActionType.isValidActionInModelPrediction(actionType)) { - System.out.println(actionType.toString()); throw new IllegalArgumentException("Wrong action type in the rest request path!"); } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 5f5f567eb8..962200c5d0 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -319,7 +319,6 @@ public static void wrapListenerToHandleSearchIndexNotFound(Exception e, ActionLi */ public static String getActionTypeFromRestRequest(RestRequest request) { String path = request.path(); - System.out.println("path is " + path); String[] segments = path.split("/"); String methodName = segments[segments.length - 1]; methodName = methodName.startsWith("_") ? 
methodName.substring(1) : methodName; diff --git a/plugin/src/main/plugin-metadata/plugin-security.policy b/plugin/src/main/plugin-metadata/plugin-security.policy index 99cf437d24..1914fd5eb2 100644 --- a/plugin/src/main/plugin-metadata/plugin-security.policy +++ b/plugin/src/main/plugin-metadata/plugin-security.policy @@ -24,4 +24,10 @@ grant { // Circuit Breaker permission java.lang.RuntimePermission "getFileSystemAttributes"; + + // s3 client opens socket connections for to access repository + permission java.net.SocketPermission "*", "connect,resolve"; + + // aws credential file access + permission java.io.FilePermission "<>", "read"; }; diff --git a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java new file mode 100644 index 0000000000..7b3766dadf --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java @@ -0,0 +1,254 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.batch; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.MLTask.ERROR_FIELD; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTaskState.COMPLETED; +import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.INGEST_FIELDS; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.INPUT_FIELD_NAMES; +import static org.opensearch.ml.engine.ingest.AbstractIngestion.OUTPUT_FIELD_NAMES; +import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE; 
+import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class TransportBatchIngestionActionTests extends OpenSearchTestCase { + @Mock + private Client client; + @Mock + private TransportService transportService; + @Mock + private MLTaskManager mlTaskManager; + @Mock + private ActionFilters actionFilters; + @Mock + private MLBatchIngestionRequest mlBatchIngestionRequest; + @Mock + private Task task; + @Mock + ActionListener actionListener; + @Mock + ThreadPool threadPool; + + private TransportBatchIngestionAction batchAction; + private MLBatchIngestionInput batchInput; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool); + + Map fieldMap = new HashMap<>(); + fieldMap.put("input", "$.content"); + fieldMap.put("output", 
"$.SageMakerOutput"); + fieldMap.put(INPUT_FIELD_NAMES, Arrays.asList("chapter", "title")); + fieldMap.put(OUTPUT_FIELD_NAMES, Arrays.asList("chapter_embedding", "title_embedding")); + fieldMap.put(INGEST_FIELDS, Arrays.asList("$.id")); + + Map credential = Map + .of("region", "us-east-1", "access_key", "some accesskey", "secret_key", "some secret", "session_token", "some token"); + Map dataSource = new HashMap<>(); + dataSource.put("type", "s3"); + dataSource.put(SOURCE, Arrays.asList("s3://offlinebatch/output/sagemaker_djl_batch_input.json.out")); + + batchInput = MLBatchIngestionInput + .builder() + .indexName("testIndex") + .fieldMapping(fieldMap) + .credential(credential) + .dataSources(dataSource) + .build(); + when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput); + } + + public void test_doExecute_success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + IndexResponse indexResponse = new IndexResponse(shardId, "taskId", 1, 1, 1, true); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class)); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); + + verify(actionListener).onResponse(any(MLBatchIngestionResponse.class)); + } + + public void test_doExecute_handleSuccessRate100() { + batchAction.handleSuccessRate(100, "taskid"); + verify(mlTaskManager).updateMLTask("taskid", Map.of(STATE_FIELD, COMPLETED), 5000, true); + } + + public void test_doExecute_handleSuccessRate50() { + batchAction.handleSuccessRate(50, "taskid"); + verify(mlTaskManager) + .updateMLTask( + "taskid", + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "batch ingestion successful rate is 50.0"), + TASK_SEMAPHORE_TIMEOUT, + true + ); + } + + public void test_doExecute_handleSuccessRate0() { + batchAction.handleSuccessRate(0, "taskid"); + verify(mlTaskManager) + 
.updateMLTask( + "taskid", + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "batch ingestion successful rate is 0"), + TASK_SEMAPHORE_TIMEOUT, + true + ); + } + + public void test_doExecute_noDataSource() { + MLBatchIngestionInput batchInput = MLBatchIngestionInput + .builder() + .indexName("testIndex") + .fieldMapping(new HashMap<>()) + .credential(new HashMap<>()) + .dataSources(new HashMap<>()) + .build(); + when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "IllegalArgumentException in the batch ingestion input: The batch ingest input data source cannot be null", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_doExecute_noTypeInDataSource() { + MLBatchIngestionInput batchInput = MLBatchIngestionInput + .builder() + .indexName("testIndex") + .fieldMapping(new HashMap<>()) + .credential(new HashMap<>()) + .dataSources(Map.of("source", "some url")) + .build(); + when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "IllegalArgumentException in the batch ingestion input: The batch ingest input data source is missing data type or source", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_doExecute_invalidS3DataSource() { + Map dataSource = new HashMap<>(); + dataSource.put("type", "s3"); + dataSource.put(SOURCE, Arrays.asList("s3://offlinebatch/output/sagemaker_djl_batch_input.json.out", "invalid s3")); + + MLBatchIngestionInput batchInput = MLBatchIngestionInput + .builder() + 
.indexName("testIndex") + .fieldMapping(new HashMap<>()) + .credential(new HashMap<>()) + .dataSources(dataSource) + .build(); + when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "IllegalArgumentException in the batch ingestion input: The following batch ingest input S3 URIs are invalid: [invalid s3]", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_doExecute_emptyS3DataSource() { + Map dataSource = new HashMap<>(); + dataSource.put("type", "s3"); + dataSource.put(SOURCE, new ArrayList<>()); + + MLBatchIngestionInput batchInput = MLBatchIngestionInput + .builder() + .indexName("testIndex") + .fieldMapping(new HashMap<>()) + .credential(new HashMap<>()) + .dataSources(dataSource) + .build(); + when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "IllegalArgumentException in the batch ingestion input: The batch ingest input s3Uris is empty", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_doExecute_mlTaskCreateException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Failed to create ML Task")); + return null; + }).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class)); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + 
verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to create ML Task", argumentCaptor.getValue().getMessage()); + } + + public void test_doExecute_batchIngestionFailed() { + doAnswer(invocation -> { + ActionListener<IndexResponse> listener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + IndexResponse indexResponse = new IndexResponse(shardId, "taskId", 1, 1, 1, true); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class)); + + doThrow(new OpenSearchStatusException("some error", RestStatus.INTERNAL_SERVER_ERROR)).when(mlTaskManager).add(isA(MLTask.class)); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); + + ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("some error", argumentCaptor.getValue().getMessage()); + verify(mlTaskManager).updateMLTask("taskId", Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "some error"), TASK_SEMAPHORE_TIMEOUT, true); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java new file mode 100644 index 0000000000..9b6a00c8d7 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.TestHelper.getBatchIngestionRestRequest; + +import java.io.IOException; 
+import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionAction; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest; +import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLBatchIngestionActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private RestMLBatchIngestAction restMLBatchIngestAction; + private ThreadPool threadPool; + NodeClient client; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + restMLBatchIngestAction = new RestMLBatchIngestAction(); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + doAnswer(invocation -> { + ActionListener<MLBatchIngestionResponse> actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLBatchIngestionAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + 
client.close(); + } + + public void testConstructor() { + RestMLBatchIngestAction mlBatchIngestAction = new RestMLBatchIngestAction(); + assertNotNull(mlBatchIngestAction); + } + + public void testGetName() { + String actionName = restMLBatchIngestAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_batch_ingestion_action", actionName); + } + + public void testRoutes() { + List<RestHandler.Route> routes = restMLBatchIngestAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/_batch_ingestion", route.getPath()); + } + + public void testGetRequest() throws IOException { + RestRequest request = getBatchIngestionRestRequest(); + MLBatchIngestionRequest mlBatchIngestionRequest = restMLBatchIngestAction.getRequest(request); + + MLBatchIngestionInput mlBatchIngestionInput = mlBatchIngestionRequest.getMlBatchIngestionInput(); + assertEquals("test batch index", mlBatchIngestionInput.getIndexName()); + assertEquals("$.content", mlBatchIngestionInput.getFieldMapping().get("input")); + assertNotNull(mlBatchIngestionInput.getDataSources().get("source")); + assertNotNull(mlBatchIngestionInput.getCredential()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getBatchIngestionRestRequest(); + restMLBatchIngestAction.handleRequest(request, channel, client); + + ArgumentCaptor<MLBatchIngestionRequest> argumentCaptor = ArgumentCaptor.forClass(MLBatchIngestionRequest.class); + verify(client, times(1)).execute(eq(MLBatchIngestionAction.INSTANCE), argumentCaptor.capture(), any()); + MLBatchIngestionInput mlBatchIngestionInput = argumentCaptor.getValue().getMlBatchIngestionInput(); + assertEquals("test batch index", mlBatchIngestionInput.getIndexName()); + assertEquals("$.content", mlBatchIngestionInput.getFieldMapping().get("input")); + assertNotNull(mlBatchIngestionInput.getDataSources().get("source")); + 
assertNotNull(mlBatchIngestionInput.getCredential()); + } + + public void testPrepareRequest_EmptyContent() throws Exception { + thrown.expect(IOException.class); + thrown.expectMessage("Batch Ingestion request has empty body"); + Map<String, String> params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + restMLBatchIngestAction.handleRequest(request, channel, client); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index ca5046fa0b..26f0afa8fb 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -523,4 +523,27 @@ public static void copyFile(String sourceFile, String destFile) throws IOExcepti FileUtils.copyFile(new File(sourceFile), new File(destFile)); } + public static RestRequest getBatchIngestionRestRequest() { + final String requestContent = "{\n" + + " \"index_name\": \"test batch index\",\n" + + " \"field_map\": {\n" + + " \"input\": \"$.content\",\n" + + " \"output\": \"$.SageMakerOutput\",\n" + + " \"input_names\": [\"chapter\", \"title\"],\n" + + " \"output_names\": [\"chapter_embedding\", \"title_embedding\"],\n" + + " \"ingest_fields\": [\"$.id\"]\n" + + " },\n" + + " \"credential\": {\n" + + " \"region\": \"xxxxxxxx\"\n" + + " },\n" + + " \"data_source\": {\n" + + " \"type\": \"s3\",\n" + + " \"source\": [\"s3://offlinebatch/output/sagemaker_djl_batch_input.json.out\"]\n" + + " }\n" + + "}"; + RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } }