diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index bd018698..71fc77b5 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -143,14 +143,21 @@ public static ModelType from(String value) { * @param client the OpenSearch transport client * @param modelId the model ID of LLM + * @param modelType the model type of the LLM; must be OPENAI or CLAUDE (compared case-insensitively) + * @param contextPrompt custom prompt for the LLM; when null or empty, the default prompt for the model type is used + * @throws IllegalArgumentException if modelType is neither OPENAI nor CLAUDE */ - public CreateAnomalyDetectorTool(Client client, String modelId, String modelType) { + public CreateAnomalyDetectorTool(Client client, String modelId, String modelType, String contextPrompt) { this.client = client; this.modelId = modelId; if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { throw new IllegalArgumentException("Unsupported model_type: " + modelType); } this.modelType = ModelType.from(modelType); - this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), ""); + if (contextPrompt == null || contextPrompt.isEmpty()) { + this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), ""); + } else { + this.contextPrompt = contextPrompt; + } } /** @@ -432,7 +439,8 @@ public CreateAnomalyDetectorTool create(Map map) { if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { throw new IllegalArgumentException("Unsupported model_type: " + modelType); } - return new CreateAnomalyDetectorTool(client, modelId, modelType); + String prompt = (String) map.getOrDefault("prompt", ""); + return new CreateAnomalyDetectorTool(client, modelId, modelType, prompt); } @Override diff --git a/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json b/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json index 9b69bce7..f2169219 100644 --- 
a/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json +++ b/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json @@ -1,4 +1,4 @@ { - "CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. 
\n\nAssistant:\" turn\"", - "OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. " + "CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. 
Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. \n\nAssistant:\" turn\"", + "OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) 
and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}." } diff --git a/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java index 0749ab70..322ac4ae 100644 --- a/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java @@ -243,6 +243,23 @@ public void testToolWithPredictModelFailed() { })); } + @Test + public void testToolWithCustomPrompt() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "custom prompt")); + assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("CLAUDE", tool.getModelType().toString()); + assertEquals("custom prompt", tool.getContextPrompt()); + + tool + .run( + ImmutableMap.of("index", mockedIndexName), + ActionListener.wrap(response -> assertEquals(mockedResult, response), log::info) + ); + } + private void createMappings() { indexMappings = new HashMap<>(); indexMappings diff --git a/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java b/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java index 648a381b..e524211d 100644 --- 
a/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java +++ b/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java @@ -340,6 +340,26 @@ private String registerAgent() { ) ); - registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); + /* Each value gets its own non-empty placeholder token: String.replace with an empty target matches at every position of the body. */ + registerAgentRequestBody = registerAgentRequestBody.replace("<MODEL_ID>", modelId); + registerAgentRequestBody = registerAgentRequestBody + .replace( + "<PROMPT>", + "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, " + " {\\\"time_field\\\":\\\"timestamp\\\",\\\"indices\\\":[\\\"server_log*\\\"],\\\"feature_attributes\\\":" + "[{\\\"feature_name\\\":\\\"test\\\",\\\"feature_enabled\\\":true," + "\\\"aggregation_query\\\":{\\\"test\\\":{\\\"sum\\\":{\\\"field\\\":\\\"value\\\"}}}}],\\\"category_field\\\":[\\\"ip\\\"]}," + " and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, " + "and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about " + "creating an anomaly detector for the index ${indexInfo.indexName}, " + "you need to give the key information: the top 3 suitable aggregation fields which are numeric types and " + "the suitable aggregation method for each field, " + "if there are no numeric type fields, both the aggregation field and method are empty string, " + " and also give the category field if there exists a keyword type field like ip, address, host, city, country or region," + " if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list " + "wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited" + " list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. 
" + ); return createAgent(registerAgentRequestBody); } } diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json index 3ad9477e..cfab9ba7 100644 --- a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json @@ -5,7 +5,8 @@ { "type": "CreateAnomalyDetectorTool", "parameters": { - "model_id": "" + "model_id": "<MODEL_ID>", + "prompt": "<PROMPT>" } } ]