Skip to content

Commit

Permalink
Support one_to_one in ML Inference Search Response Processor (#2801)
Browse files Browse the repository at this point in the history
* add one document to one prediction support

Signed-off-by: Mingshi Liu <mingshl@amazon.com>

* rephrase javadoc

Signed-off-by: Mingshi Liu <mingshl@amazon.com>

* use OpenSearchStatusException in error handling

Signed-off-by: Mingshi Liu <mingshl@amazon.com>

* fix message

Signed-off-by: Mingshi Liu <mingshl@amazon.com>

* add more tests

Signed-off-by: Mingshi Liu <mingshl@amazon.com>

* handle different exceptions properly

Signed-off-by: Mingshi Liu <mingshl@amazon.com>

---------

Signed-off-by: Mingshi Liu <mingshl@amazon.com>
  • Loading branch information
mingshl authored Aug 21, 2024
1 parent b51f47f commit 2a33c65
Show file tree
Hide file tree
Showing 3 changed files with 2,211 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -33,16 +35,19 @@
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.utils.MapUtils;
import org.opensearch.ml.utils.SearchResponseUtil;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.PipelineProcessingContext;
Expand Down Expand Up @@ -125,9 +130,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
/**
* Processes the search response asynchronously by rewriting the documents with the inference results.
*
* @param request the search request
* @param response the search response
* @param responseContext the pipeline processing context
* By default, it processes multiple documents in a single prediction through the rewriteResponseDocuments method.
* However, when processing one document per inference, it separates the N-hits search response into N one-hit search responses,
* executes the same rewriteResponseDocument method for each one-hit search response,
* and after receiving N one-hit search responses with inference results,
* it combines them back into a single N-hits search response.
*
* @param request the search request
* @param response the search response
* @param responseContext the pipeline processing context
* @param responseListener the listener to be notified when the response is processed
*/
@Override
Expand All @@ -144,20 +155,130 @@ public void processResponseAsync(
responseListener.onResponse(response);
return;
}
rewriteResponseDocuments(response, responseListener);

// if many to one, run rewriteResponseDocuments
if (!oneToOne) {
rewriteResponseDocuments(response, responseListener);
} else {
// if one to one, make one hit search response and run rewriteResponseDocuments
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
response,
responseListener,
hits
);
AtomicBoolean isOneHitListenerFailed = new AtomicBoolean(false);
;
for (SearchHit hit : hits) {
SearchHit[] newHits = new SearchHit[1];
newHits[0] = hit;
SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response);
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed);
rewriteResponseDocuments(oneHitResponse, oneHitListener);
// if any OneHitListener failure, try stop the rest of the predictions
if (isOneHitListenerFailed.get()) {
break;
}
}
}

} catch (Exception e) {
if (ignoreFailure) {
responseListener.onResponse(response);
} else {
responseListener.onFailure(e);
if (e instanceof OpenSearchStatusException) {
responseListener
.onFailure(
new OpenSearchStatusException(
"Failed to process response: " + e.getMessage(),
RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus())
)
);
} else if (e instanceof MLResourceNotFoundException) {
responseListener
.onFailure(new OpenSearchStatusException("Failed to process response: " + e.getMessage(), RestStatus.NOT_FOUND));
} else {
responseListener.onFailure(e);
}
}
}
}

/**
* Creates an ActionListener for a single SearchResponse that delegates its
* onResponse and onFailure callbacks to a GroupedActionListener.
*
* @param combineResponseListener The GroupedActionListener to which the
* onResponse and onFailure callbacks will be
* delegated.
* @param isOneHitListenerFailed
* @return An ActionListener that delegates its callbacks to the provided
* GroupedActionListener.
*/
private static ActionListener<SearchResponse> getOneHitListener(
GroupedActionListener<SearchResponse> combineResponseListener,
AtomicBoolean isOneHitListenerFailed
) {
ActionListener<SearchResponse> oneHitListener = new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
combineResponseListener.onResponse(response);
}

@Override
public void onFailure(Exception e) {
// if any OneHitListener failure, try stop the rest of the predictions and return
isOneHitListenerFailed.compareAndSet(false, true);
combineResponseListener.onFailure(e);
}
};
return oneHitListener;
}

/**
* Creates a GroupedActionListener that combines the SearchResponses from individual hits
* and constructs a new SearchResponse with the combined hits.
*
* @param response The original SearchResponse containing the hits to be processed.
* @param responseListener The ActionListener to be notified with the combined SearchResponse.
* @param hits The array of SearchHits to be processed.
* @return A GroupedActionListener that combines the SearchResponses and constructs a new SearchResponse.
*/
private GroupedActionListener<SearchResponse> getCombineResponseGroupedActionListener(
SearchResponse response,
ActionListener<SearchResponse> responseListener,
SearchHit[] hits
) {
GroupedActionListener<SearchResponse> combineResponseListener = new GroupedActionListener<>(new ActionListener<>() {
@Override
public void onResponse(Collection<SearchResponse> responseMapCollection) {
SearchHit[] combinedHits = new SearchHit[hits.length];
int i = 0;
for (SearchResponse OneHitResponseAfterInference : responseMapCollection) {
SearchHit[] hitsAfterInference = OneHitResponseAfterInference.getHits().getHits();
combinedHits[i] = hitsAfterInference[0];
i++;
}
SearchResponse oneToOneInferenceSearchResponse = SearchResponseUtil.replaceHits(combinedHits, response);
responseListener.onResponse(oneToOneInferenceSearchResponse);
}

@Override
public void onFailure(Exception e) {
if (ignoreFailure) {
responseListener.onResponse(response);
} else {
responseListener.onFailure(e);
}
}
}, hits.length);
return combineResponseListener;
}

/**
* Rewrite the documents in the search response with the inference results.
*
* @param response the search response
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @throws IOException if an I/O error occurs during the rewriting process
*/
Expand All @@ -168,27 +289,23 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se

// hitCountInPredictions keeps track of the count of hit that have the required input fields for each round of prediction
Map<Integer, Integer> hitCountInPredictions = new HashMap<>();
if (!oneToOne) {
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListenerManyToOne(
response,
responseListener,
processInputMap,
processOutputMap,
hitCountInPredictions
);

GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListenerManyToOne(
rewriteResponseListener,
inputMapSize
);
SearchHit[] hits = response.getHits().getHits();
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
processPredictionsManyToOne(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
}
} else {
responseListener.onFailure(new IllegalArgumentException("one to one prediction is not supported yet."));
}
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListener(
response,
responseListener,
processInputMap,
processOutputMap,
hitCountInPredictions
);

GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListener(
rewriteResponseListener,
inputMapSize
);
SearchHit[] hits = response.getHits().getHits();
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
}
}

/**
Expand All @@ -201,7 +318,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @throws IOException if an I/O error occurs during the prediction process
*/
private void processPredictionsManyToOne(
private void processPredictions(
SearchHit[] hits,
List<Map<String, String>> processInputMap,
int inputMapIndex,
Expand Down Expand Up @@ -242,7 +359,7 @@ private void processPredictionsManyToOne(
Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName);
if (documentValue != null) {
// when not existed in the map, add into the modelInputParameters map
updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue);
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
}
}
} else { // when document does not contain the documentFieldName, skip when ignoreMissing
Expand All @@ -263,8 +380,7 @@ private void processPredictionsManyToOne(
Object documentValue = entry.getValue();

// when not existed in the map, add into the modelInputParameters map
updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue);

updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
}
}
}
Expand Down Expand Up @@ -306,18 +422,28 @@ public void onFailure(Exception e) {
});
}

private void updateModelInputParametersManyToOne(
Map<String, Object> modelInputParameters,
String modelInputFieldName,
Object documentValue
) {
if (!modelInputParameters.containsKey(modelInputFieldName)) {
List<Object> documentValueList = new ArrayList<>();
documentValueList.add(documentValue);
modelInputParameters.put(modelInputFieldName, documentValueList);
/**
* Updates the model input parameters map with the given document value.
* If the setting is one-to-one,
* simply put the document value in the map
* If the setting is many-to-one,
* create a new list and add the document value
* @param modelInputParameters The map containing the model input parameters.
* @param modelInputFieldName The name of the model input field.
* @param documentValue The value from the document that needs to be added to the model input parameters.
*/
private void updateModelInputParameters(Map<String, Object> modelInputParameters, String modelInputFieldName, Object documentValue) {
if (!this.oneToOne) {
if (!modelInputParameters.containsKey(modelInputFieldName)) {
List<Object> documentValueList = new ArrayList<>();
documentValueList.add(documentValue);
modelInputParameters.put(modelInputFieldName, documentValueList);
} else {
List<Object> valueList = ((List) modelInputParameters.get(modelInputFieldName));
valueList.add(documentValue);
}
} else {
List<Object> valueList = ((List) modelInputParameters.get(modelInputFieldName));
valueList.add(documentValue);
modelInputParameters.put(modelInputFieldName, documentValue);
}
}

Expand All @@ -328,7 +454,7 @@ private void updateModelInputParametersManyToOne(
* @param inputMapSize the size of the input map
* @return a grouped action listener for batch predictions
*/
private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListenerManyToOne(
private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListener(
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener,
int inputMapSize
) {
Expand All @@ -353,14 +479,14 @@ public void onFailure(Exception e) {
/**
* Creates an action listener for rewriting the response with the inference results.
*
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @param processInputMap the list of input mappings
* @param processOutputMap the list of output mappings
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @param processInputMap the list of input mappings
* @param processOutputMap the list of output mappings
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @return an action listener for rewriting the response with the inference results
*/
private ActionListener<Map<Integer, MLOutput>> createRewriteResponseListenerManyToOne(
private ActionListener<Map<Integer, MLOutput>> createRewriteResponseListener(
SearchResponse response,
ActionListener<SearchResponse> responseListener,
List<Map<String, String>> processInputMap,
Expand Down Expand Up @@ -392,7 +518,7 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
Map<String, String> outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap);

boolean isModelInputMissing = false;
if (processInputMap != null) {
if (processInputMap != null && !processInputMap.isEmpty()) {
isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
}
if (!isModelInputMissing) {
Expand Down Expand Up @@ -499,10 +625,10 @@ private boolean checkIsModelInputMissing(Map<String, Object> document, Map<Strin
* <p>If the processOutputMap is not null and not empty, the mapping at the specified mappingIndex
* is returned.
*
* @param mappingIndex the index of the mapping to retrieve from the processOutputMap
* @param mappingIndex the index of the mapping to retrieve from the processOutputMap
* @param processOutputMap the list of output mappings, can be null or empty
* @return a Map containing the output mapping, either the default mapping or the mapping at the
* specified index
* specified index
*/
private static Map<String, String> getDefaultOutputMapping(Integer mappingIndex, List<Map<String, String>> processOutputMap) {
Map<String, String> outputMapping;
Expand All @@ -524,11 +650,11 @@ private static Map<String, String> getDefaultOutputMapping(Integer mappingIndex,
* <p>If the processInputMap is not null and not empty, the mapping at the specified mappingIndex
* is returned.
*
* @param sourceAsMap the source map containing the input data
* @param mappingIndex the index of the mapping to retrieve from the processInputMap
* @param sourceAsMap the source map containing the input data
* @param mappingIndex the index of the mapping to retrieve from the processInputMap
* @param processInputMap the list of input mappings, can be null or empty
* @return a Map containing the input mapping, either the mapping extracted from sourceAsMap or
* the mapping at the specified index
* the mapping at the specified index
*/
private static Map<String, String> getDefaultInputMapping(
Map<String, Object> sourceAsMap,
Expand Down
Loading

0 comments on commit 2a33c65

Please sign in to comment.