Skip to content

Commit

Permalink
Base class for Integ test; add integ test for NeuralSparseSearchTool (#…
Browse files Browse the repository at this point in the history
…86)

* add common components

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add common components

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add basic components

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* rebase main

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add basic components for it, add it

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* rebase main

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* change neural sparse model to pretrained tokenizer

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* rm redundant line

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add comments

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add register connector

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

---------

Signed-off-by: zhichao-aws <zhichaog@amazon.com>
  • Loading branch information
zhichao-aws authored Jan 2, 2024
1 parent 051b79a commit 34ae75f
Show file tree
Hide file tree
Showing 6 changed files with 631 additions and 0 deletions.
22 changes: 22 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ dependencies {
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${version}"
zipArchive "org.opensearch.plugin:opensearch-anomaly-detection:${version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${version}"
zipArchive group: 'org.opensearch.plugin', name:'neural-search', version: "${version}"

// Test dependencies
testImplementation "org.opensearch.test:framework:${opensearch_version}"
Expand Down Expand Up @@ -348,6 +350,26 @@ testClusters.integTest {
}
}

// Remote Integration Tests
task integTestRemote(type: RestIntegTestTask) {
testClassesDirs = sourceSets.test.output.classesDirs
classpath = sourceSets.test.runtimeClasspath

systemProperty "https", System.getProperty("https")
systemProperty "user", System.getProperty("user")
systemProperty "password", System.getProperty("password")

systemProperty 'cluster.number_of_nodes', "${_numNodes}"

systemProperty 'tests.security.manager', 'false'
// Run tests with remote cluster only if rest case is defined
if (System.getProperty("tests.rest.cluster") != null) {
filter {
includeTestsMatching "org.opensearch.integTest.*IT"
}
}
}

// Automatically sets up the integration test cluster locally
run {
useCluster testClusters.integTest
Expand Down
241 changes: 241 additions & 0 deletions src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.integTest;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.apache.commons.lang3.StringUtils;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.io.entity.StringEntity;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
import org.opensearch.client.*;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;

import com.google.common.collect.ImmutableList;
import com.google.gson.Gson;

import lombok.SneakyThrows;

public abstract class BaseAgentToolsIT extends OpenSearchSecureRestTestCase {
public static final Gson gson = new Gson();
private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5;
private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;

/**
* Update cluster settings to run ml models
*/
@Before
public void updateClusterSettings() {
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
// default threshold for native circuit breaker is 90, it may be not enough on test runner machine
updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true);
}

@SneakyThrows
protected void updateClusterSettings(String settingKey, Object value) {
XContentBuilder builder = XContentFactory
.jsonBuilder()
.startObject()
.startObject("persistent")
.field(settingKey, value)
.endObject()
.endObject();
Response response = makeRequest(
client(),
"PUT",
"_cluster/settings",
null,
builder.toString(),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
);

assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

@SneakyThrows
private Map parseResponseToMap(Response response) {
Map<String, Object> responseInMap = XContentHelper
.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity()), false);
response.getEntity().toString();
return responseInMap;
}

@SneakyThrows
private Object parseFieldFromResponse(Response response, String field) {
assertNotNull(field);
Map map = parseResponseToMap(response);
Object result = map.get(field);
assertNotNull(result);
return result;
}

protected String createConnector(String requestBody) {
Response response = makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, requestBody, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
return parseFieldFromResponse(response, MLModel.CONNECTOR_ID_FIELD).toString();
}

protected String registerModel(String requestBody) {
Response response = makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, requestBody, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
return parseFieldFromResponse(response, MLTask.TASK_ID_FIELD).toString();
}

protected String deployModel(String modelId) {
Response response = makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, (String) null, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
return parseFieldFromResponse(response, MLTask.TASK_ID_FIELD).toString();
}

@SneakyThrows
protected Response waitTaskComplete(String taskId) {
for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) {
Response response = makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, (String) null, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
String state = parseFieldFromResponse(response, MLTask.STATE_FIELD).toString();
if (state.equals(MLTaskState.COMPLETED.toString())) {
return response;
}
if (state.equals(MLTaskState.FAILED.toString())
|| state.equals(MLTaskState.CANCELLED.toString())
|| state.equals(MLTaskState.COMPLETED_WITH_ERROR.toString())) {
fail("The task failed with state " + state);
}
Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
}
fail("The task failed to complete after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
return null;
}

// Register the model then deploy it. Returns the model_id until the model is deployed
protected String registerModelThenDeploy(String requestBody) {
String registerModelTaskId = registerModel(requestBody);
Response registerTaskResponse = waitTaskComplete(registerModelTaskId);
String modelId = parseFieldFromResponse(registerTaskResponse, MLTask.MODEL_ID_FIELD).toString();
String deployModelTaskId = deployModel(modelId);
waitTaskComplete(deployModelTaskId);
return modelId;
}

protected void createIndexWithConfiguration(String indexName, String indexConfiguration) throws Exception {
Response response = makeRequest(client(), "PUT", indexName, null, indexConfiguration, null);
assertEquals("true", parseFieldFromResponse(response, "acknowledged").toString());
assertEquals(indexName, parseFieldFromResponse(response, "index").toString());
}

@SneakyThrows
protected void addDocToIndex(String indexName, String docId, List<String> fieldNames, List<Object> fieldContents) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
for (int i = 0; i < fieldNames.size(); i++) {
builder.field(fieldNames.get(i), fieldContents.get(i));
}
builder.endObject();
Response response = makeRequest(
client(),
"POST",
"/" + indexName + "/_doc/" + docId + "?refresh=true",
null,
builder.toString(),
null
);
assertEquals(RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

public String createAgent(String requestBody) {
Response response = makeRequest(client(), "POST", "/_plugins/_ml/agents/_register", null, requestBody, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
return parseFieldFromResponse(response, AgentMLInput.AGENT_ID_FIELD).toString();
}

private String parseStringResponseFromExecuteAgentResponse(Response response) {
Map responseInMap = parseResponseToMap(response);
Optional<String> optionalResult = Optional
.ofNullable(responseInMap)
.map(m -> (List) m.get(ModelTensorOutput.INFERENCE_RESULT_FIELD))
.map(l -> (Map) l.get(0))
.map(m -> (List) m.get(ModelTensors.OUTPUT_FIELD))
.map(l -> (Map) l.get(0))
.map(m -> (String) (m.get(ModelTensor.RESULT_FIELD)));
return optionalResult.get();
}

// execute the agent, and return the String response from the json structure
// {"inference_results": [{"output": [{"name": "response","result": "the result to return."}]}]}
public String executeAgent(String agentId, String requestBody) {
Response response = makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, requestBody, null);
return parseStringResponseFromExecuteAgentResponse(response);
}

public static Response makeRequest(
RestClient client,
String method,
String endpoint,
Map<String, String> params,
String jsonEntity,
List<Header> headers
) {
HttpEntity httpEntity = StringUtils.isBlank(jsonEntity) ? null : new StringEntity(jsonEntity, ContentType.APPLICATION_JSON);
return makeRequest(client, method, endpoint, params, httpEntity, headers);
}

public static Response makeRequest(
RestClient client,
String method,
String endpoint,
Map<String, String> params,
HttpEntity entity,
List<Header> headers
) {
return makeRequest(client, method, endpoint, params, entity, headers, false);
}

@SneakyThrows
public static Response makeRequest(
RestClient client,
String method,
String endpoint,
Map<String, String> params,
HttpEntity entity,
List<Header> headers,
boolean strictDeprecationMode
) {
Request request = new Request(method, endpoint);

RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder();
if (headers != null) {
headers.forEach(header -> options.addHeader(header.getName(), header.getValue()));
}
options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE);
request.setOptions(options.build());

if (params != null) {
params.forEach(request::addParameter);
}
if (entity != null) {
request.setEntity(entity);
}
return client.performRequest(request);
}
}
Loading

0 comments on commit 34ae75f

Please sign in to comment.