Skip to content

Commit

Permalink
add debug log
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <ihailong@amazon.com>
  • Loading branch information
Hailong-am committed Jan 12, 2024
1 parent 010ae7f commit 985f1d9
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public void addNewVectorField(String index, String modelId, ActionListener<Boole
}

Check warning on line 183 in src/main/java/org/opensearch/agent/indices/IndicesHelper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/indices/IndicesHelper.java#L183

Added line #L183 was not covered by tests

private Long getDimension(String modelId) {
return mlClients.getEmbeddingResult(modelId, List.of("today is sunny"), mlTaskResponse -> {
return mlClients.getEmbeddingResult(modelId, List.of("today is sunny"), true, mlTaskResponse -> {
ModelTensorOutput tensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
return tensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getShape()[0];

Check warning on line 188 in src/main/java/org/opensearch/agent/indices/IndicesHelper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/indices/IndicesHelper.java#L186-L188

Added lines #L186 - L188 were not covered by tests
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public void run() {
.map(sample -> (String) sample.get(INDEX_SUMMARY))
.collect(Collectors.toList());

Check warning on line 118 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L115-L118

Added lines #L115 - L118 were not covered by tests

List<ModelTensors> mlModelOutputs = mlClients.getEmbeddingResult(modelId, embeddingDocs, mlTaskResponse -> {
List<ModelTensors> mlModelOutputs = mlClients.getEmbeddingResult(modelId, embeddingDocs, true, mlTaskResponse -> {
ModelTensorOutput output = (ModelTensorOutput) mlTaskResponse.getOutput();
return output.getMlModelOutputs();

Check warning on line 122 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L120-L122

Added lines #L120 - L122 were not covered by tests
});
Expand Down Expand Up @@ -186,12 +186,12 @@ private List<Map<String, Object>> getAllIndexMappingAndSampleData() {
indexSummaryMap.put(INDEX_PATTERNS, indexPatterns);
indexSummaryMap.put(ALIASES, indexMetadata.getAliases().keySet());

Check warning on line 187 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L182-L187

Added lines #L182 - L187 were not covered by tests

Map<String, Object> sourceAsMap = indexMetadata.mapping().getSourceAsMap();
// if index don't have any mapping, ignore
if (sourceAsMap == null || sourceAsMap.isEmpty()) {
// if index have no mapping at all
if (indexMetadata.mapping() == null) {
log.debug("No mapping for index {}", indexName);
continue;

Check warning on line 192 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L191-L192

Added lines #L191 - L192 were not covered by tests
}
Map<String, Object> sourceAsMap = indexMetadata.mapping().getSourceAsMap();
try (XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON)) {
builder.map(sourceAsMap);
String mapping = builder.toString();

Check warning on line 197 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L194-L197

Added lines #L194 - L197 were not covered by tests
Expand Down
21 changes: 19 additions & 2 deletions src/main/java/org/opensearch/agent/job/MLClients.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
Expand All @@ -45,6 +46,9 @@
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLSearchAgentAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.search.SearchHit;
Expand All @@ -65,7 +69,7 @@ public MLClients(Client client, NamedXContentRegistry xContentRegistry) {
this.xContentRegistry = xContentRegistry;
}

Check warning on line 70 in src/main/java/org/opensearch/agent/job/MLClients.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/MLClients.java#L67-L70

Added lines #L67 - L70 were not covered by tests

public <T> T getEmbeddingResult(String modelId, List<String> texts, Function<MLTaskResponse, T> parser) {
public <T> T getEmbeddingResult(String modelId, List<String> texts, boolean deploy, Function<MLTaskResponse, T> parser) {
try {
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet
.builder()
Expand All @@ -80,8 +84,21 @@ public <T> T getEmbeddingResult(String modelId, List<String> texts, Function<MLT
MLTaskResponse mlTaskResponse = predictFuture.get(DEFAULT_TIMEOUT_SECOND, TimeUnit.SECONDS);
return parser.apply(mlTaskResponse);
} catch (Exception ex) {

Check warning on line 86 in src/main/java/org/opensearch/agent/job/MLClients.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/MLClients.java#L83-L86

Added lines #L83 - L86 were not covered by tests
if (deploy && ExceptionsHelper.stackTrace(ex).contains("Model not ready yet.")) {
log.info("Model {} has not deployed yet, try to deploy", modelId);
ActionFuture<MLDeployModelResponse> deployFuture = client
.execute(MLDeployModelAction.INSTANCE, new MLDeployModelRequest(modelId, false));

Check warning on line 90 in src/main/java/org/opensearch/agent/job/MLClients.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/MLClients.java#L88-L90

Added lines #L88 - L90 were not covered by tests
try {
MLDeployModelResponse mlDeployModelResponse = deployFuture.get(2, TimeUnit.MINUTES);

Check warning on line 92 in src/main/java/org/opensearch/agent/job/MLClients.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/MLClients.java#L92

Added line #L92 was not covered by tests
if (mlDeployModelResponse.getStatus().equals(MLTaskState.COMPLETED.name())) {
return getEmbeddingResult(modelId, texts, false, parser);

Check warning on line 94 in src/main/java/org/opensearch/agent/job/MLClients.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/MLClients.java#L94

Added line #L94 was not covered by tests
}
} catch (Exception e) {
throw ExceptionsHelper.convertToRuntime(e);
}

Check warning on line 98 in src/main/java/org/opensearch/agent/job/MLClients.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/MLClients.java#L96-L98

Added lines #L96 - L98 were not covered by tests
}
log.error("Invoke ML embedding failed", ex);
throw new RuntimeException(ex);
throw ExceptionsHelper.convertToRuntime(ex);

Check warning on line 101 in src/main/java/org/opensearch/agent/job/MLClients.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/MLClients.java#L100-L101

Added lines #L100 - L101 were not covered by tests
}
}

Expand Down
18 changes: 11 additions & 7 deletions src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

package org.opensearch.agent.tools;

import static org.opensearch.agent.job.IndexSummaryEmbeddingJob.DATA_STREAM;
import static org.opensearch.agent.job.IndexSummaryEmbeddingJob.INDEX_NAME;
import static org.opensearch.agent.job.IndexSummaryEmbeddingJob.INDEX_PATTERNS;
import static org.opensearch.agent.job.IndexSummaryEmbeddingJob.DATA_STREAM_FIELD;
import static org.opensearch.agent.job.IndexSummaryEmbeddingJob.INDEX_NAME_FIELD;
import static org.opensearch.agent.job.IndexSummaryEmbeddingJob.INDEX_PATTERNS_FIELD;

import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -53,7 +53,7 @@ public class IndexRoutingTool extends VectorDBTool {
public static final String TYPE = "IndexRoutingTool";

private static final String DEFAULT_DESCRIPTION = "Use this tool to select an appropriate index for user question, "
+ "This tool take user plain input and return list of most related indexes or `Not sure`. "
+ "It takes 1 argument which is a string of user question and return list of most related indexes or `Not sure`. "
+ "If the tool returns `Not sure`, mark it as final answer and ask Human to provide index name";

public static final int DEFAULT_K = 5;
Expand Down Expand Up @@ -127,6 +127,7 @@ protected Parser<SearchResponse, Object> searchResponseParser() {

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
log.debug("input={}", parameters.get(INPUT_FIELD));

Check warning on line 130 in src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java#L130

Added line #L130 was not covered by tests
// get index of knn-index
super.run(parameters, ActionListener.wrap(res -> {
List<Map<String, Object>> summaries = (List<Map<String, Object>>) res;

Check warning on line 133 in src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java#L132-L133

Added lines #L132 - L133 were not covered by tests
Expand All @@ -153,11 +154,13 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
// call LLM, MLModelTool
String question = parameters.get(INPUT_FIELD);
String prompt = buildFinalPrompt(summaryStr, question);
log.debug("prompt send to inference is {}", prompt);

Check warning on line 157 in src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java#L155-L157

Added lines #L155 - L157 were not covered by tests
// TODO use MLModelTool
mlClients.inference(inferenceModelId, prompt, ActionListener.wrap(r -> {
ModelTensorOutput output = (ModelTensorOutput) r.getOutput();
ModelTensor modelTensor = output.getMlModelOutputs().get(0).getMlModelTensors().get(0);
String response = (String) modelTensor.getDataAsMap().get("response");
log.debug("response back from inference mode is {}", response);
Set<String> validIndexes = findMatchedIndex(response, summaries);

Check warning on line 164 in src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java#L159-L164

Added lines #L159 - L164 were not covered by tests
listener.onResponse((T) (validIndexes.isEmpty() ? "Not sure" : validIndexes.iterator().next()));
}, exception -> { listener.onResponse((T) "Not sure"); }));
Expand All @@ -172,17 +175,18 @@ private Set<String> findMatchedIndex(String result, List<Map<String, Object>> ca

Map<String, Map<String, Object>> candidateIndexMap = candidates
.stream()
.collect(Collectors.toMap(m -> (String) m.get(INDEX_NAME), m -> m));
.collect(Collectors.toMap(m -> (String) m.get(INDEX_NAME_FIELD), m -> m));

Check warning on line 178 in src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java#L176-L178

Added lines #L176 - L178 were not covered by tests

Set<String> allCandidates = candidateIndexMap.keySet();
List<String> predictedIndexes = Arrays.stream(result.split(",")).map(String::trim).collect(Collectors.toList());
log.debug("all candidates are {}, predictedIndexes are {}", allCandidates, predictedIndexes);

Check warning on line 182 in src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java#L180-L182

Added lines #L180 - L182 were not covered by tests

for (String predictedIndex : predictedIndexes) {
if (allCandidates.contains(predictedIndex)) {
Map<String, Object> map = candidateIndexMap.get(predictedIndex);

Check warning on line 186 in src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java#L186

Added line #L186 was not covered by tests
// data stream back index
Optional<Object> dataStreamName = Optional.ofNullable(map.get(DATA_STREAM));
List<String> patterns = (List<String>) map.get(INDEX_PATTERNS);
Optional<Object> dataStreamName = Optional.ofNullable(map.get(DATA_STREAM_FIELD));
List<String> patterns = (List<String>) map.get(INDEX_PATTERNS_FIELD);
String indexPattern = getIndexPattern(patterns, predictedIndex);
validIndexes.add((String) dataStreamName.orElse(indexPattern));

Check warning on line 191 in src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java#L188-L191

Added lines #L188 - L191 were not covered by tests
} else if (predictedIndex.equals("Not sure")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public enum LLMProvider {
OPENAI("${prompt}"),
ANTHROPIC("Human: ${prompt}\\nAssistant:"),
ANTHROPIC("\\n\\nHuman: ${prompt} \\n\\nAssistant:"),
MISTRAL("<s>[INST] ${prompt} [/INST]"),
NONE("${prompt}");

Check warning on line 22 in src/main/java/org/opensearch/agent/tools/utils/LLMProvider.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/tools/utils/LLMProvider.java#L19-L22

Added lines #L19 - L22 were not covered by tests

Expand Down

0 comments on commit 985f1d9

Please sign in to comment.