Skip to content

Commit

Permalink
populate time fields for connectors on return (opensearch-project#2922)
Browse files Browse the repository at this point in the history
* populate time fields for connectors on return

fixes opensearch-project#2890 Currently any class that extends the AbstractConnector class has the fields createdTime and lastUpdatedTime set to null. The solution was instantiating the fields in the constructor of the AbstractConnector class, as well updating it within the HTTPConnector class whenever an update happens. Many tests were modified to catch the time fields being populated as such there will be many differences on the string in order to get around the timing issue when doing tests.

Signed-off-by: Brian Flores <iflorbri@amazon.com>

* fixes backward compatability issues with old connectors

fixes opensearch-project#2890 when applying a code change like this previous connectors would have a weird bug where upon calling GET on them would change the timestamp. In this commit, it remains the old connectors without time fields while new ones have time fields, newer connectors will have correct and updated Timestamp. Manual testing was done on a local cluster and two unit tests were done to inspect the time changes on creation and update

Signed-off-by: Brian Flores <iflorbri@amazon.com>

* fix failing MLRegisterModelInutTest.testToXContent tests

Originally this commit was cherry picked from the 2.x branch and as such code changes affected the new build that werent caught on the previous commit 8c006de. Reformatted tests that were failing as the behavior implemented in previous commits was to not display time fields if a connector does not have them in the first place. gradlew build was done to assure the tests passed

Signed-off-by: Brian Flores <iflorbri@amazon.com>

* Reverts back model tests that were modified incorrectly by connector change

When creating a code change to the connector it propagated the new change of the object that affected many UTs, but after changing the logic of indexing the new connector, change the old changes for the unit test involving models with connectors had to be reverted back. UTs specifically for the indexed connectors have been created in UpdateConnectorTransportActionTests were done to capture this

Signed-off-by: Brian Flores <iflorbri@amazon.com>

* Adds lastUpdateTime to Old Connectors

Previoulsy we didnt consider the old connectors to have time fields at all, But given offline discussion if we add time fields to old connectors users could get more information moving forward without breaking any backward features. The solution to this was setting the last updated time in the update connector api; now moving forward any connector gets attached a last updated time field. I updated the testUpdateConnectorDoesNotUpdateHTTPCOnnectorTimeFields method to check that lastUpdateTime has a timestamp but that createdTime has no time field.

Signed-off-by: Brian Flores <iflorbri@amazon.com>

* Fixes wildcard import in UpdateConnectorTransportActionTests

Signed-off-by: Brian Flores <iflorbri@amazon.com>

---------

Signed-off-by: Brian Flores <iflorbri@amazon.com>
  • Loading branch information
brianf-aws authored Oct 1, 2024
1 parent f4b4724 commit 88eaefd
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ public abstract class AbstractConnector implements Connector {
protected User owner;
@Setter
protected AccessMode access;
@Setter
protected Instant createdTime;
@Setter
protected Instant lastUpdateTime;
@Setter
protected ConnectorClientConfig connectorClientConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -44,6 +45,10 @@ public interface Connector extends ToXContentObject, Writeable {

String getProtocol();

void setCreatedTime(Instant createdTime);

void setLastUpdateTime(Instant lastUpdateTime);

User getOwner();

void setOwner(User user);
Expand Down
32 changes: 16 additions & 16 deletions common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,22 @@ public void toXContent_InternalConnector() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals(
"{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\","
+ "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\","
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}",
mlModelContent
);

String expectedConnectorResponse = "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\","
+ "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\","
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}";

assertEquals(expectedConnectorResponse, mlModelContent);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
connector.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);

Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content);
}

@Test
public void constructor_Parser() throws IOException {

XContentParser parser = XContentType.JSON
.xContent()
.createParser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,21 @@ public void toXContentTest() throws IOException {
mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals(
"{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,"
+ "\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}",
jsonStr
);

String expectedControllerResponse = "{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,"
+ "\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

assertEquals(expectedControllerResponse, jsonStr);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ public class MLUpdateModelInputTest {
+ "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":"
+ "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":"
+ "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":"
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":"
+ "\"test-connector_id\",\"last_updated_time\":1}";
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},"
+ "\"connector_id\":\"test-connector_id\",\"last_updated_time\":1}";

private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":"
+ "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":"
Expand Down Expand Up @@ -152,6 +152,7 @@ public void readInputStreamSuccessWithNullFields() throws IOException {
@Test
public void testToXContent() throws Exception {
String jsonStr = serializationWithToXContent(updateModelInput);

assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,40 @@ public void testToXContent() throws Exception {
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals(expectedInputStr, jsonStr);

String expectedFunctionInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\","
+ "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"description\":\"test description\","
+ "\"url\":\"url\",\"model_content_hash_value\":\"hash_value_test\",\"model_format\":\"ONNX\","
+ "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,"
+ "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\","
+ "\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"],"
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}";

assertEquals(expectedFunctionInputStr, jsonStr);
}

@Test
public void testToXContent_Incomplete() throws Exception {
input.setUrl(null);
input.setModelConfig(null);
input.setModelFormat(null);
input.setModelNodeIds(null);
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);

String jsonStr = builder.toString();

String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\","
+ "\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\","
+ "\"description\":\"test description\",\"model_content_hash_value\":\"hash_value_test\","
Expand All @@ -186,14 +215,7 @@ public void testToXContent_Incomplete() throws Exception {
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}";
input.setUrl(null);
input.setModelConfig(null);
input.setModelFormat(null);
input.setModelNodeIds(null);
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();

assertEquals(expectedIncompleteInputStr, jsonStr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;

import java.time.Instant;
import java.util.HashSet;
import java.util.List;

Expand Down Expand Up @@ -134,6 +135,10 @@ private void indexConnector(Connector connector, ActionListener<MLCreateConnecto
listener.onResponse(response);
}, listener::onFailure);

Instant currentTime = Instant.now();
connector.setCreatedTime(currentTime);
connector.setLastUpdateTime(currentTime);

IndexRequest indexRequest = new IndexRequest(ML_CONNECTOR_INDEX);
indexRequest.source(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -93,6 +94,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
if (Boolean.TRUE.equals(hasPermission)) {
connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt);
connector.validateConnectorURL(trustedConnectorEndpointsRegex);

connector.setLastUpdateTime(Instant.now());

UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
updateRequest.doc(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.io.IOException;
import java.nio.file.Path;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -202,6 +203,117 @@ public void setup() throws IOException {
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));
}

@Test
public void testUpdateConnectorDoesNotUpdateHttpConnectorTimeFields() {
HttpConnector connector = HttpConnector
.builder()
.name("test")
.protocol("http")
.version("1")
.credential(Map.of("api_key", "credential_value"))
.parameters(Map.of("param1", "value1"))
.actions(
Arrays
.asList(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.openai.com/v1/chat/completions")
.headers(Map.of("Authorization", "Bearer ${credential.api_key}"))
.requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }")
.build()
)
)
.build();

assertNull(connector.getCreatedTime());
assertNull(connector.getLastUpdateTime());

doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<Connector> listener = invocation.getArgument(2);
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

updateConnectorTransportAction.doExecute(task, updateRequest, actionListener);

assertNull(connector.getCreatedTime());
assertNotNull(connector.getLastUpdateTime());
}

@Test
public void testUpdateConnectorUpdatesHttpConnectorTimeFields() {
HttpConnector connector = HttpConnector
.builder()
.name("test")
.protocol("http")
.version("1")
.credential(Map.of("api_key", "credential_value"))
.parameters(Map.of("param1", "value1"))
.actions(
Arrays
.asList(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.openai.com/v1/chat/completions")
.headers(Map.of("Authorization", "Bearer ${credential.api_key}"))
.requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }")
.build()
)
)
.build();

Instant testInitialTime = Instant.now();
connector.setCreatedTime(testInitialTime);
connector.setLastUpdateTime(testInitialTime);

assert (connector.getCreatedTime().toEpochMilli() == connector.getLastUpdateTime().toEpochMilli());

doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<Connector> listener = invocation.getArgument(2);
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

updateConnectorTransportAction.doExecute(task, updateRequest, actionListener);

assertTrue(
"Last update time must be bigger than the creation time",
connector.getLastUpdateTime().toEpochMilli() >= connector.getCreatedTime().toEpochMilli()
);
}

@Test
public void testExecuteConnectorAccessControlSuccess() {
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));
Expand Down

0 comments on commit 88eaefd

Please sign in to comment.