Skip to content

Commit

Permalink
Add multi modal default preprocess function (#2500)
Browse files Browse the repository at this point in the history
* Add multi modal default preprocess function

Signed-off-by: zane-neo <zaniu@amazon.com>

* Address comments

Signed-off-by: zane-neo <zaniu@amazon.com>

* address comments

Signed-off-by: zane-neo <zaniu@amazon.com>

* add IT

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix IT

Signed-off-by: zane-neo <zaniu@amazon.com>

* Update common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java

Co-authored-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: zane-neo <zaniu@amazon.com>

* fix test

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* Add more ITs

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix failure ITs

Signed-off-by: zane-neo <zaniu@amazon.com>

* fix failure IT

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix failure ITs

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add error response to make it esay to figure out the failure root cause

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* rebase main

Signed-off-by: zane-neo <zaniu@amazon.com>

---------

Signed-off-by: zane-neo <zaniu@amazon.com>
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Co-authored-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
zane-neo and ylwu-amzn authored Jul 17, 2024
1 parent 2058d76 commit 0e89c17
Show file tree
Hide file tree
Showing 7 changed files with 427 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -22,6 +23,7 @@ public class MLPreProcessFunction {
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.multimodal.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";
Expand All @@ -34,7 +36,9 @@ public class MLPreProcessFunction {
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,35 @@
import org.opensearch.script.TemplateScript;

import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod;

/**
* This abstract class represents a pre-processing function for a connector.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link MLInput}, and the pre-processing function can be customized by implementing the {@link #validate(MLInput)} and {@link #process(MLInput)} methods.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true.
*/
@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {

/**
* This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet.
* If this is true and the input is already of type RemoteInferenceInputDataSet, it will be returned directly, otherwise it will be processed.
*/
protected boolean returnDirectlyForRemoteInferenceInput;

/**
* Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet.
*
* @param mlInput the MLInput object to be processed
* @return the RemoteInferenceInputDataSet resulting from the pre-processing function
* @throws IllegalArgumentException if the input MLInput object is null
*/
@Override
public RemoteInferenceInputDataSet apply(MLInput mlInput) {
if (mlInput == null) {
Expand All @@ -42,9 +61,17 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

/**
* Validates the input of a pre-process function for text documents.
*
* @param mlInput the input data to be validated
* @throws IllegalArgumentException if the input dataset is not an instance of TextDocsInputDataSet
* or if there is no input text or image provided
*/
public void validateTextDocsInput(MLInput mlInput) {
if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName()));
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet which including a list of string with key 'text_docs'");
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

/**
* This class provides a pre-processing function for multi-modal input data.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link TextDocsInputDataSet}, with the first document representing text input and the second document representing an image input.
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
public class MultiModalConnectorPreProcessFunction extends ConnectorPreProcessFunction {

public MultiModalConnectorPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) {
throw new IllegalArgumentException("No input text or image provided");
}
}

/**
* @param mlInput The input data to be processed.
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
* The inputText will always show up in the first document, even it's null.
*/
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, String> parametersMap = new HashMap<>();
parametersMap.put("inputText", inputData.getDocs().get(0));
if (inputData.getDocs().size() > 1) {
parametersMap.put("inputImage", inputData.getDocs().get(1));
}
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))).build();

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;

public class MultiModalConnectorPreProcessFunctionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

MultiModalConnectorPreProcessFunction function;

TextSimilarityInputDataSet textSimilarityInputDataSet;
TextDocsInputDataSet textDocsInputDataSet;
RemoteInferenceInputDataSet remoteInferenceInputDataSet;

MLInput textEmbeddingInput;
MLInput textSimilarityInput;
MLInput remoteInferenceInput;

@Before
public void setUp() {
function = new MultiModalConnectorPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build();

textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build();
}

@Test
public void testProcess_whenNullInput_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Preprocess function input can't be null");
function.apply(null);
}

@Test
public void testProcess_whenWrongInput_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
function.apply(textSimilarityInput);
}

@Test
public void testProcess_whenCorrectInput_expectCorrectOutput() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
assertEquals(2, dataSet.getParameters().size());
assertEquals("hello", dataSet.getParameters().get("inputText"));
assertEquals("world", dataSet.getParameters().get("inputImage"));
}

@Test
public void testProcess_whenInputTextOnly_expectInputTextShowUp() {
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
assertEquals(1, dataSet.getParameters().size());
assertEquals("hello", dataSet.getParameters().get("inputText"));
}

@Test
public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("No input text or image provided");
List<String> docs = new ArrayList<>();
docs.add(null);
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
}

@Test
public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() {
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
assertEquals(remoteInferenceInputDataSet, dataSet);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.client.RestClient;
import org.opensearch.client.RestClientBuilder;
import org.opensearch.common.io.PathUtils;
Expand Down Expand Up @@ -103,6 +104,9 @@
import com.google.gson.Gson;
import com.google.gson.JsonArray;

import lombok.extern.log4j.Log4j2;

@Log4j2
public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase {
protected Gson gson = new Gson();
public static long CUSTOM_MODEL_TIMEOUT = 20_000; // 20 seconds
Expand Down Expand Up @@ -915,8 +919,14 @@ public Map predictTextEmbedding(String modelId) throws IOException {

public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException {
String requestBody = TestHelper.toJsonString(input);
Response response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null);
Response response = null;
try {
response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null);
} catch (ResponseException e) {
log.error(e.getMessage(), e);
response = e.getResponse();
}
return parseResponseToMap(response);
}

Expand Down
Loading

0 comments on commit 0e89c17

Please sign in to comment.