Skip to content

Commit

Permalink
Optimize the default prompt and make prompt customizable for create a…
Browse files Browse the repository at this point in the history
…nomaly detector tool (#399)

* Optimize the prompt for create anomaly detector tool

Signed-off-by: gaobinlong <gbinlong@amazon.com>

* Remove whitespace

Signed-off-by: gaobinlong <gbinlong@amazon.com>

* Make prompt for CreateAnomalyDetectorToll customized

Signed-off-by: gaobinlong <gbinlong@amazon.com>

* format the code

Signed-off-by: gaobinlong <gbinlong@amazon.com>

* Fix test failure

Signed-off-by: gaobinlong <gbinlong@amazon.com>

* fix test failure

Signed-off-by: gaobinlong <gbinlong@amazon.com>

* Format the code

Signed-off-by: gaobinlong <gbinlong@amazon.com>

* Add more tests

Signed-off-by: gaobinlong <gbinlong@amazon.com>

---------

Signed-off-by: gaobinlong <gbinlong@amazon.com>
(cherry picked from commit 06a8537)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Sep 29, 2024
1 parent a26c924 commit a721189
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,18 @@ public static ModelType from(String value) {
* @param client the OpenSearch transport client
* @param modelId the model ID of LLM
*/
public CreateAnomalyDetectorTool(Client client, String modelId, String modelType) {
public CreateAnomalyDetectorTool(Client client, String modelId, String modelType, String contextPrompt) {
this.client = client;
this.modelId = modelId;
if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) {
throw new IllegalArgumentException("Unsupported model_type: " + modelType);
}
this.modelType = ModelType.from(modelType);
this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), "");
if (contextPrompt.isEmpty()) {
this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), "");
} else {
this.contextPrompt = contextPrompt;
}
}

/**
Expand Down Expand Up @@ -432,7 +436,8 @@ public CreateAnomalyDetectorTool create(Map<String, Object> map) {
if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) {
throw new IllegalArgumentException("Unsupported model_type: " + modelType);
}
return new CreateAnomalyDetectorTool(client, modelId, modelType);
String prompt = (String) map.getOrDefault("prompt", "");
return new CreateAnomalyDetectorTool(client, modelId, modelType, prompt);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. \n\nAssistant:\" turn\"",
"OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. "
"CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. \n\nAssistant:\" turn\"",
"OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}."
}
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ public void testToolWithPredictModelFailed() {
}));
}

@Test
public void testToolWithCustomPrompt() {
CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "prompt", "custom prompt"));
assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName());
assertEquals("modelId", tool.getModelId());
assertEquals("CLAUDE", tool.getModelType().toString());
assertEquals("custom prompt", tool.getContextPrompt());

tool
.run(
ImmutableMap.of("index", mockedIndexName),
ActionListener.<String>wrap(response -> assertEquals(mockedResult, response), log::info)
);
}

private void createMappings() {
indexMappings = new HashMap<>();
indexMappings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,25 @@ private String registerAgent() {
)
);
registerAgentRequestBody = registerAgentRequestBody.replace("<MODEL_ID>", modelId);
registerAgentRequestBody = registerAgentRequestBody
.replace(
"<CUSTOM_PROMPT>",
"Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, "
+ " {\\\"time_field\\\":\\\"timestamp\\\",\\\"indices\\\":[\\\"server_log*\\\"],\\\"feature_attributes\\\":"
+ "[{\\\"feature_name\\\":\\\"test\\\",\\\"feature_enabled\\\":true,"
+ "\\\"aggregation_query\\\":{\\\"test\\\":{\\\"sum\\\":{\\\"field\\\":\\\"value\\\"}}}}],\\\"category_field\\\":[\\\"ip\\\"]},"
+ " and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, "
+ "and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about "
+ "creating an anomaly detector for the index ${indexInfo.indexName}, "
+ "you need to give the key information: the top 3 suitable aggregation fields which are numeric types and "
+ "the suitable aggregation method for each field, "
+ "if there are no numeric type fields, both the aggregation field and method are empty string, "
+ " and also give the category field if there exists a keyword type field like ip, address, host, city, country or region,"
+ " if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list "
+ "wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited"
+ " list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. "
);

return createAgent(registerAgentRequestBody);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
{
"type": "CreateAnomalyDetectorTool",
"parameters": {
"model_id": "<MODEL_ID>"
"model_id": "<MODEL_ID>",
"prompt": "<CUSTOM_PROMPT>"
}
}
]
Expand Down

0 comments on commit a721189

Please sign in to comment.