From 88eaefdcfc2d1e37b8ec7bffd0e730840325a3f2 Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Mon, 30 Sep 2024 17:36:25 -0700 Subject: [PATCH] populate time fields for connectors on return (#2922) * populate time fields for connectors on return fixes #2890 Currently any class that extends the AbstractConnector class has the fields createdTime and lastUpdatedTime set to null. The solution was instantiating the fields in the constructor of the AbstractConnector class, as well updating it within the HTTPConnector class whenever an update happens. Many tests were modified to catch the time fields being populated as such there will be many differences on the string in order to get around the timing issue when doing tests. Signed-off-by: Brian Flores * fixes backward compatability issues with old connectors fixes #2890 when applying a code change like this previous connectors would have a weird bug where upon calling GET on them would change the timestamp. In this commit, it remains the old connectors without time fields while new ones have time fields, newer connectors will have correct and updated Timestamp. Manual testing was done on a local cluster and two unit tests were done to inspect the time changes on creation and update Signed-off-by: Brian Flores * fix failing MLRegisterModelInutTest.testToXContent tests Originally this commit was cherry picked from the 2.x branch and as such code changes affected the new build that werent caught on the previous commit 8c006de. Reformatted tests that were failing as the behavior implemented in previous commits was to not display time fields if a connector does not have them in the first place. gradlew build was done to assure the tests passed Signed-off-by: Brian Flores * Reverts back model tests that were modified incorrectly by connector change When creating a code change to the connector it propagated the new change of the object that affected many UTs, but after changing the logic of indexing the new connector, change the old changes for the unit test involving models with connectors had to be reverted back. UTs specifically for the indexed connectors have been created in UpdateConnectorTransportActionTests were done to capture this Signed-off-by: Brian Flores * Adds lastUpdateTime to Old Connectors Previoulsy we didnt consider the old connectors to have time fields at all, But given offline discussion if we add time fields to old connectors users could get more information moving forward without breaking any backward features. The solution to this was setting the last updated time in the update connector api; now moving forward any connector gets attached a last updated time field. I updated the testUpdateConnectorDoesNotUpdateHTTPCOnnectorTimeFields method to check that lastUpdateTime has a timestamp but that createdTime has no time field. Signed-off-by: Brian Flores * Fixes wildcard import in UpdateConnectorTransportActionTests Signed-off-by: Brian Flores --------- Signed-off-by: Brian Flores --- .../common/connector/AbstractConnector.java | 2 + .../ml/common/connector/Connector.java | 5 + .../ml/common/RemoteModelTests.java | 32 ++--- .../common/connector/HttpConnectorTest.java | 2 +- .../MLConnectorGetResponseTests.java | 30 ++--- .../model/MLUpdateModelInputTest.java | 5 +- .../register/MLRegisterModelInputTest.java | 40 +++++-- .../TransportCreateConnectorAction.java | 5 + .../UpdateConnectorTransportAction.java | 4 + .../UpdateConnectorTransportActionTests.java | 112 ++++++++++++++++++ 10 files changed, 194 insertions(+), 43 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index aac9a1acad..4849f79c93 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -65,7 +65,9 @@ public abstract class AbstractConnector implements Connector { protected User owner; @Setter protected AccessMode access; + @Setter protected Instant createdTime; + @Setter protected Instant lastUpdateTime; @Setter protected ConnectorClientConfig connectorClientConfig; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index c808f6628c..0a37641144 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -12,6 +12,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.time.Instant; import java.util.List; import java.util.Map; import java.util.Optional; @@ -44,6 +45,10 @@ public interface Connector extends ToXContentObject, Writeable { String getProtocol(); + void setCreatedTime(Instant createdTime); + + void setLastUpdateTime(Instant lastUpdateTime); + User getOwner(); void setOwner(User user); diff --git a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java index acb05c9b66..b09553e7d9 100644 --- a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java @@ -61,22 +61,22 @@ public void toXContent_InternalConnector() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals( - "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\"," - + "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\"," - + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," - + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," - + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}", - mlModelContent - ); + + String expectedConnectorResponse = "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\"," + + "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}"; + + assertEquals(expectedConnectorResponse, mlModelContent); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 0115ac1376..16bbc76bfa 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -85,12 +85,12 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); connector.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); + Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content); } @Test public void constructor_Parser() throws IOException { - XContentParser parser = XContentType.JSON .xContent() .createParser( diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java index 8e8b94dac4..936c71da95 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java @@ -59,21 +59,21 @@ public void toXContentTest() throws IOException { mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals( - "{\"name\":\"test_connector_name\",\"version\":\"1\"," - + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," - + "\"client_config\":{\"max_connection\":30," - + "\"connection_timeout\":30000,\"read_timeout\":30000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}", - jsonStr - ); + + String expectedControllerResponse = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30," + + "\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + + assertEquals(expectedControllerResponse, jsonStr); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index 9014f0ec49..46e89d0aa6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -64,8 +64,8 @@ public class MLUpdateModelInputTest { + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" - + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" - + "\"test-connector_id\",\"last_updated_time\":1}"; + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]}," + + "\"connector_id\":\"test-connector_id\",\"last_updated_time\":1}"; private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":" + "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" @@ -152,6 +152,7 @@ public void readInputStreamSuccessWithNullFields() throws IOException { @Test public void testToXContent() throws Exception { String jsonStr = serializationWithToXContent(updateModelInput); + assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index 89d887a325..7b919a4ae2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -167,11 +167,40 @@ public void testToXContent() throws Exception { input.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals(expectedInputStr, jsonStr); + + String expectedFunctionInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + + "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"description\":\"test description\"," + + "\"url\":\"url\",\"model_content_hash_value\":\"hash_value_test\",\"model_format\":\"ONNX\"," + + "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100," + + "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\"," + + "\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]," + + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}"; + + assertEquals(expectedFunctionInputStr, jsonStr); } @Test public void testToXContent_Incomplete() throws Exception { + input.setUrl(null); + input.setModelConfig(null); + input.setModelFormat(null); + input.setModelNodeIds(null); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + + String jsonStr = builder.toString(); + String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\"," + "\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\"," + "\"description\":\"test description\",\"model_content_hash_value\":\"hash_value_test\"," @@ -186,14 +215,7 @@ public void testToXContent_Incomplete() throws Exception { + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}"; - input.setUrl(null); - input.setModelConfig(null); - input.setModelFormat(null); - input.setModelNodeIds(null); - XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); - input.toXContent(builder, ToXContent.EMPTY_PARAMS); - assertNotNull(builder); - String jsonStr = builder.toString(); + assertEquals(expectedIncompleteInputStr, jsonStr); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 4cadcc936a..92b087f686 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import java.time.Instant; import java.util.HashSet; import java.util.List; @@ -134,6 +135,10 @@ private void indexConnector(Connector connector, ActionListener { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); + + assertNull(connector.getCreatedTime()); + assertNotNull(connector.getLastUpdateTime()); + } + + @Test + public void testUpdateConnectorUpdatesHttpConnectorTimeFields() { + HttpConnector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); + + Instant testInitialTime = Instant.now(); + connector.setCreatedTime(testInitialTime); + connector.setLastUpdateTime(testInitialTime); + + assert (connector.getCreatedTime().toEpochMilli() == connector.getLastUpdateTime().toEpochMilli()); + + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); + + assertTrue( + "Last update time must be bigger than the creation time", + connector.getLastUpdateTime().toEpochMilli() >= connector.getCreatedTime().toEpochMilli() + ); + } + @Test public void testExecuteConnectorAccessControlSuccess() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));