diff --git a/src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java b/src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java index 30141377..50f1e124 100644 --- a/src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java +++ b/src/main/java/org/opensearch/agent/tools/IndexRoutingTool.java @@ -50,9 +50,49 @@ public class IndexRoutingTool extends VectorDBTool { + "This tool take user plain original question as input and return list of most related indexes or `Not sure`. " + "If the tool returns `Not sure`, mark it as final answer and ask Human to provide index name"; + private static final String DEFAULT_PROMPT_TEMPLATE = String + .format( + Locale.ROOT, + "Human: %s\\nAssistant:", + "You are an experienced engineer in OpenSearch and ElasticSearch. \n" + + "\n" + + "Given a question, your task is to choose the relevant indexes from a list of indexes.\n" + + "\n" + + "For every index, you will be given the index mapping, followed by sample data from the index.\n" + + "\n" + + "The data format is like:\n" + + "\n" + + "index-1: Index Mappings:\n" + + "mappings of index-1\n" + + "Sample data:\n" + + "data from index-1\n" + + "---\n" + + "index-2: Index Mappings:\n" + + "mappings of index-2\n" + + "Sample data:\n" + + "data from index-2\n" + + "---\n" + + "...\n" + + "\n" + + "Now the actual index mappings and sample data begins:\n" + + "{summaries}\n" + + "\n" + + "-------------------\n" + + "\n" + + "Format the output as a comma-separated sequence, e.g. index-1, index-2, index-3. If no indexes \n" + + "appear relevant to the question, return the empty string ''.\n" + + "\n" + + "Just return the index names, nothing else. \n" + + "If you are not sure, just return 'Not sure', nothing else.\n" + + "\n" + + "Question: {question}\n" + + "Answer:" + ); + public static final int DEFAULT_K = 5; public static String EMBEDDING_MODEL_ID = "embedding_model_id"; public static String INFERENCE_MODEL_ID = "inference_model_id"; + public static String PROMPT_TEMPLATE = "prompt_template"; @Setter @Getter @@ -67,6 +107,9 @@ public class IndexRoutingTool extends VectorDBTool { private final MLClients mlClients; + @Setter + private String promptTemplate; + public IndexRoutingTool( Client client, NamedXContentRegistry xContentRegistry, @@ -208,43 +251,8 @@ private Optional findWithSimilarity(String predictedIndex, Collection params = Map.of("question", question, "summaries", summaryString); - return new StringSubstitutor(params).replace(defaultTemplate); + return new StringSubstitutor(params).replace(promptTemplate); } @Override @@ -280,14 +288,28 @@ public void init(Client client, NamedXContentRegistry xContentRegistry) { public IndexRoutingTool create(Map params) { String embeddingModelId = (String) params.get(EMBEDDING_MODEL_ID); String inferenceModelId = (String) params.get(INFERENCE_MODEL_ID); + String promptTemplate = params.get(PROMPT_TEMPLATE) == null ? DEFAULT_PROMPT_TEMPLATE : (String) params.get(PROMPT_TEMPLATE); Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_K; Integer k = params.containsKey(K_FIELD) ? Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K; - return new IndexRoutingTool(client, xContentRegistry, docSize, k, embeddingModelId, inferenceModelId); + IndexRoutingTool tool = new IndexRoutingTool(client, xContentRegistry, docSize, k, embeddingModelId, inferenceModelId); + tool.setPromptTemplate(promptTemplate); + + return tool; } @Override public String getDefaultDescription() { return DEFAULT_DESCRIPTION; } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } } }