From 858cd841efb6e68d435a62355c034cf3beec694c Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 4 Aug 2022 10:31:30 -0700 Subject: [PATCH] fix index mapping (#384) (#389) Signed-off-by: Yaliang Wu (cherry picked from commit b9036b1afc819cb8dea15c7bf6259ab5e88a656a) Co-authored-by: Yaliang Wu --- .../client/MachineLearningNodeClientTest.java | 2 +- .../org/opensearch/ml/common/CommonValue.java | 95 ++++++++++ .../org/opensearch/ml/common/MLModel.java | 33 ++-- .../java/org/opensearch/ml/common/MLTask.java | 6 +- .../opensearch/ml/common/MLModelTests.java | 2 +- .../model/MLModelGetResponseTest.java | 7 +- plugin/build.gradle | 3 +- .../models/DeleteModelTransportAction.java | 2 +- .../models/GetModelTransportAction.java | 2 +- .../tasks/DeleteTaskTransportAction.java | 2 +- .../action/tasks/GetTaskTransportAction.java | 2 +- .../org/opensearch/ml/indices/MLIndex.java | 47 +++++ .../ml/indices/MLIndicesHandler.java | 139 ++++++++------ .../ml/plugin/MachineLearningPlugin.java | 4 +- .../ml/rest/RestMLSearchModelAction.java | 2 +- .../ml/rest/RestMLSearchTaskAction.java | 2 +- .../opensearch/ml/rest/RestMLStatsAction.java | 2 +- .../ml/task/MLPredictTaskRunner.java | 2 +- .../org/opensearch/ml/task/MLTaskManager.java | 2 +- .../org/opensearch/ml/task/MLTaskRunner.java | 7 - .../ml/task/MLTrainingTaskRunner.java | 2 +- .../ml/indices/MLIndicesHandlerTests.java | 176 ++++++++++++++++-- .../ml/rest/RestMLSearchModelActionTests.java | 2 +- .../ml/rest/RestMLStatsActionTests.java | 2 +- .../opensearch/ml/utils/IntegTestUtils.java | 5 +- .../org/opensearch/ml/utils/TestHelper.java | 74 ++++++++ 26 files changed, 514 insertions(+), 110 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/CommonValue.java create mode 100644 plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 5fea523828..87c0372291 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -308,7 +308,7 @@ public void searchModel() { verify(client).execute(eq(MLModelSearchAction.INSTANCE), isA(SearchRequest.class), any()); verify(searchModelActionListener).onResponse(argumentCaptor.capture()); Map source = argumentCaptor.getValue().getHits().getAt(0).getSourceAsMap(); - assertEquals(modelContent, source.get(MLModel.MODEL_CONTENT)); + assertEquals(modelContent, source.get(MLModel.MODEL_CONTENT_FIELD)); } @Test diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java new file mode 100644 index 0000000000..6fc7adc1f6 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +public class CommonValue { + + public static Integer NO_SCHEMA_VERSION = 0; + public static final String USER = "user"; + public static final String META = "_meta"; + public static final String SCHEMA_VERSION_FIELD = "schema_version"; + + public static final String ML_MODEL_INDEX = ".plugins-ml-model"; + public static final String ML_TASK_INDEX = ".plugins-ml-task"; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 1; + public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1; + public static final String USER_FIELD_MAPPING = " \"" + + CommonValue.USER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " }\n"; + public static final String ML_MODEL_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_MODEL_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLModel.ALGORITHM_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLModel.MODEL_VERSION_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_FIELD + + "\" : {\"type\": \"binary\"},\n" + + USER_FIELD_MAPPING + + " }\n" + + "}"; + + public static final String ML_TASK_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_TASK_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLTask.MODEL_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.TASK_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.FUNCTION_NAME_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.STATE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.INPUT_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.PROGRESS_FIELD + + "\": {\"type\": \"float\"},\n" + + " \"" + + MLTask.OUTPUT_INDEX_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.WORKER_NODE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.LAST_UPDATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.ERROR_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + MLTask.IS_ASYNC_TASK_FIELD + + "\" : {\"type\" : \"boolean\"}, \n" + + USER_FIELD_MAPPING + + " }\n" + + "}"; +} diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 957f8e1da2..bd7fe330ab 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -18,14 +18,15 @@ import java.util.Base64; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.USER; @Getter public class MLModel implements ToXContentObject { - public static final String ALGORITHM = "algorithm"; - public static final String MODEL_NAME = "name"; - public static final String MODEL_VERSION = "version"; - public static final String MODEL_CONTENT = "content"; - public static final String USER = "user"; + public static final String ALGORITHM_FIELD = "algorithm"; + public static final String MODEL_NAME_FIELD = "name"; + public static final String MODEL_VERSION_FIELD = "version"; + public static final String OLD_MODEL_CONTENT_FIELD = "content"; + public static final String MODEL_CONTENT_FIELD = "model_content"; private String name; private FunctionName algorithm; @@ -75,16 +76,16 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); if (name != null) { - builder.field(MODEL_NAME, name); + builder.field(MODEL_NAME_FIELD, name); } if (algorithm != null) { - builder.field(ALGORITHM, algorithm); + builder.field(ALGORITHM_FIELD, algorithm); } if (version != null) { - builder.field(MODEL_VERSION, version); + builder.field(MODEL_VERSION_FIELD, version); } if (content != null) { - builder.field(MODEL_CONTENT, content); + builder.field(MODEL_CONTENT_FIELD, content); } if (user != null) { builder.field(USER, user); @@ -98,6 +99,7 @@ public static MLModel parse(XContentParser parser) throws IOException { FunctionName algorithm = null; Integer version = null; String content = null; + String oldContent = null; User user = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -106,19 +108,22 @@ public static MLModel parse(XContentParser parser) throws IOException { parser.nextToken(); switch (fieldName) { - case MODEL_NAME: + case MODEL_NAME_FIELD: name = parser.text(); break; - case MODEL_CONTENT: + case MODEL_CONTENT_FIELD: content = parser.text(); break; - case MODEL_VERSION: + case OLD_MODEL_CONTENT_FIELD: + oldContent = parser.text(); + break; + case MODEL_VERSION_FIELD: version = parser.intValue(false); break; case USER: user = User.parse(parser); break; - case ALGORITHM: + case ALGORITHM_FIELD: algorithm = FunctionName.from(parser.text()); break; default: @@ -130,7 +135,7 @@ public static MLModel parse(XContentParser parser) throws IOException { .name(name) .algorithm(algorithm) .version(version) - .content(content) + .content(content == null ? oldContent : content) .user(user) .build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index ab4188a3e5..4bb22193a0 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -22,6 +22,7 @@ import java.time.Instant; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.USER; @Getter @EqualsAndHashCode @@ -39,7 +40,6 @@ public class MLTask implements ToXContentObject, Writeable { public static final String CREATE_TIME_FIELD = "create_time"; public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String ERROR_FIELD = "error"; - public static final String USER_FIELD = "user"; public static final String IS_ASYNC_TASK_FIELD = "is_async"; @Setter @@ -177,7 +177,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params builder.field(ERROR_FIELD, error); } if (user != null) { - builder.field(USER_FIELD, user); + builder.field(USER, user); } builder.field(IS_ASYNC_TASK_FIELD, async); return builder.endObject(); @@ -246,7 +246,7 @@ public static MLTask parse(XContentParser parser) throws IOException { case ERROR_FIELD: error = parser.text(); break; - case USER_FIELD: + case USER: user = User.parse(parser); break; case IS_ASYNC_TASK_FIELD: diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java index 4b1f0b6009..7e964379da 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java @@ -40,7 +40,7 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"version\":1,\"content\":\"test_content\"}", mlModelContent); + assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"version\":1,\"model_content\":\"test_content\"}", mlModelContent); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java index 36765fbeac..b78d3868d3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.transport.model; import org.junit.Before; @@ -56,7 +61,7 @@ public void toXContentTest() throws IOException { assertEquals("{\"name\":\"model\"," + "\"algorithm\":\"KMEANS\"," + "\"version\":1," + - "\"content\":\"content\"," + + "\"model_content\":\"content\"," + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}}", jsonStr); } } diff --git a/plugin/build.gradle b/plugin/build.gradle index 04a2ae6cea..30ac82942e 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -230,7 +230,8 @@ jacocoTestReport { List jacocoExclusions = [ // TODO: add more unit test to meet the minimal test coverage. 'org.opensearch.ml.constant.CommonValue', - 'org.opensearch.ml.plugin.MachineLearningPlugin*' + 'org.opensearch.ml.plugin.MachineLearningPlugin*', + 'org.opensearch.ml.indices.MLIndicesHandler' ] jacocoTestCoverageVerification { diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 090dbd0cbd..f8c1a193b6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.models; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import lombok.AccessLevel; import lombok.experimental.FieldDefaults; diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 55d2a44f82..0b70d370f3 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -6,7 +6,7 @@ package org.opensearch.ml.action.models; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import lombok.AccessLevel; diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java index 717d98dae5..214a950722 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.tasks; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import lombok.extern.log4j.Log4j2; diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 987815356c..f1a9c686a8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -6,7 +6,7 @@ package org.opensearch.ml.action.tasks; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import lombok.extern.log4j.Log4j2; diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java new file mode 100644 index 0000000000..9057c43dd5 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.indices; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_SCHEMA_VERSION; + +public enum MLIndex { + MODEL(ML_MODEL_INDEX, false, ML_MODEL_INDEX_MAPPING, ML_MODEL_INDEX_SCHEMA_VERSION), + TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION); + + private final String indexName; + // whether we use an alias for the index + private final boolean alias; + private final String mapping; + private final Integer version; + + MLIndex(String name, boolean alias, String mapping, Integer version) { + this.indexName = name; + this.alias = alias; + this.mapping = mapping; + this.version = version; + } + + public String getIndexName() { + return indexName; + } + + public boolean isAlias() { + return alias; + } + + public String getMapping() { + return mapping; + } + + public Integer getVersion() { + return version; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java index f44c0fd1a9..eb7a53247d 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java @@ -5,6 +5,15 @@ package org.opensearch.ml.indices; +import static org.opensearch.ml.common.CommonValue.META; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + import lombok.AccessLevel; import lombok.RequiredArgsConstructor; import lombok.experimental.FieldDefaults; @@ -13,80 +22,41 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.exception.MLException; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor @Log4j2 public class MLIndicesHandler { - public static final String ML_MODEL_INDEX = ".plugins-ml-model"; - public static final String ML_TASK_INDEX = ".plugins-ml-task"; - private static final String ML_MODEL_INDEX_MAPPING = "{\n" - + " \"properties\": {\n" - + " \"task_id\": { \"type\": \"keyword\" },\n" - + " \"algorithm\": {\"type\": \"keyword\"},\n" - + " \"model_name\" : { \"type\": \"keyword\"},\n" - + " \"model_version\" : { \"type\": \"keyword\"},\n" - + " \"model_content\" : { \"type\": \"binary\"}\n" - + " }\n" - + "}"; - - private static final String ML_TASK_INDEX_MAPPING = "{\n" - + " \"properties\": {\n" - + " \"model_id\": {\"type\": \"keyword\"},\n" - + " \"task_type\": {\"type\": \"keyword\"},\n" - + " \"function_name\": {\"type\": \"keyword\"},\n" - + " \"state\": {\"type\": \"keyword\"},\n" - + " \"input_type\": {\"type\": \"keyword\"},\n" - + " \"progress\": {\"type\": \"float\"},\n" - + " \"output_index\": {\"type\": \"keyword\"},\n" - + " \"worker_node\": {\"type\": \"keyword\"},\n" - + " \"create_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"last_update_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"error\": {\"type\": \"text\"},\n" - + " \"user\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " }\n" - + " }\n" - + "}"; ClusterService clusterService; Client client; - public void initModelIndexIfAbsent() { - initMLIndexIfAbsent(ML_MODEL_INDEX, ML_MODEL_INDEX_MAPPING); - } - - public boolean doesModelIndexExist() { - return clusterService.state().metadata().hasIndex(ML_MODEL_INDEX); - } - - private void initMLIndexIfAbsent(String indexName, String mapping) { - if (!clusterService.state().metadata().hasIndex(indexName)) { - client.admin().indices().prepareCreate(indexName).get(); - log.info("create index:{}", indexName); - } else { - log.info("index:{} is already created", indexName); - } + private static final Map indexMappingUpdated = new HashMap<>(); + static { + indexMappingUpdated.put(ML_MODEL_INDEX, new AtomicBoolean(false)); + indexMappingUpdated.put(ML_TASK_INDEX, new AtomicBoolean(false)); } public void initModelIndexIfAbsent(ActionListener listener) { - initMLIndexIfAbsent(ML_MODEL_INDEX, ML_MODEL_INDEX_MAPPING, listener); + initMLIndexIfAbsent(MLIndex.MODEL, listener); } public void initMLTaskIndex(ActionListener listener) { - initMLIndexIfAbsent(ML_TASK_INDEX, ML_TASK_INDEX_MAPPING, listener); + initMLIndexIfAbsent(MLIndex.TASK, listener); } - public void initMLIndexIfAbsent(String indexName, String mapping, ActionListener listener) { + public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) { + String indexName = index.getIndexName(); + String mapping = index.getMapping(); + if (!clusterService.state().metadata().hasIndex(indexName)) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener actionListener = ActionListener.wrap(r -> { @@ -108,8 +78,67 @@ public void initMLIndexIfAbsent(String indexName, String mapping, ActionListener } } else { log.info("index:{} is already created", indexName); - listener.onResponse(true); + if (indexMappingUpdated.containsKey(indexName) && !indexMappingUpdated.get(indexName).get()) { + shouldUpdateIndex(indexName, index.getVersion(), ActionListener.wrap(r -> { + if (r) { + // return true if should update index + client + .admin() + .indices() + .putMapping( + new PutMappingRequest().indices(indexName).source(mapping, XContentType.JSON), + ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + indexMappingUpdated.get(indexName).set(true); + listener.onResponse(true); + } else { + listener.onFailure(new MLException("Failed to update index: " + indexName)); + } + }, exception -> { + log.error("Failed to update index " + indexName, exception); + listener.onFailure(exception); + }) + ); + } else { + // no need to update index if it does not exist or the version is already up-to-date. + indexMappingUpdated.get(indexName).set(true); + listener.onResponse(true); + } + }, e -> { + log.error("Failed to update index mapping", e); + listener.onFailure(e); + })); + } else { + // No need to update index if it's not ML system index or it's already updated. + listener.onResponse(true); + } + } + } + + /** + * Check if we should update index based on schema version. + * @param indexName index name + * @param newVersion new index mapping version + * @param listener action listener, if should update index, will pass true to its onResponse method + */ + public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener listener) { + IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); + if (indexMetaData == null) { + listener.onResponse(Boolean.FALSE); + return; + } + Integer oldVersion = CommonValue.NO_SCHEMA_VERSION; + Map indexMapping = indexMetaData.mapping().getSourceAsMap(); + Object meta = indexMapping.get(META); + if (meta != null && meta instanceof Map) { + @SuppressWarnings("unchecked") + Map metaMapping = (Map) meta; + Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD); + if (schemaVersion instanceof Integer) { + oldVersion = (Integer) schemaVersion; + } } + listener.onResponse(newVersion > oldVersion); } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 826b5afb73..3976883753 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -5,8 +5,8 @@ package org.opensearch.ml.plugin; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import java.util.Collection; import java.util.Collections; diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelAction.java index b0a7c7a2bb..ee9374056c 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelAction.java @@ -5,7 +5,7 @@ package org.opensearch.ml.rest; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import org.opensearch.ml.common.MLModel; diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java index 20a48b3935..25f69e23b8 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java @@ -5,7 +5,7 @@ package org.opensearch.ml.rest; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import org.opensearch.ml.common.MLTask; diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java index 161c4b7937..d189942503 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java @@ -6,7 +6,7 @@ package org.opensearch.ml.rest; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import java.io.IOException; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 9b7674a3de..b64fb3fa2a 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -6,7 +6,7 @@ package org.opensearch.ml.task; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.permission.AccessController.checkUserPermissions; import static org.opensearch.ml.permission.AccessController.getUserContext; import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index ce0f70eabb..5f93c1d904 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -5,8 +5,8 @@ package org.opensearch.ml.task; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.common.MLTask.LAST_UPDATE_TIME_FIELD; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_TASK_INDEX; import java.time.Instant; import java.util.HashMap; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index a36765ad9d..9fc23fe9dc 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -40,13 +40,6 @@ public abstract class MLTaskRunner listener = ActionListener.wrap(r -> { assertTrue(r); }, e -> { throw new RuntimeException(e); }); mlIndicesHandler.initMLTaskIndex(listener); } public void testInitMLTaskIndexWithExistingIndex() throws ExecutionException, InterruptedException { - CreateIndexRequest request = new CreateIndexRequest(ML_TASK_INDEX); + CreateIndexRequest request = new CreateIndexRequest(ML_TASK_INDEX).mapping(ML_TASK_INDEX_MAPPING); client.admin().indices().create(request).get(); testInitMLTaskIndex(); } + + public void testInitMLModelIndexIfAbsentWithExistingIndex() throws ExecutionException, InterruptedException, IOException { + testInitMLIndexIfAbsentWithExistingIndex(ML_MODEL_INDEX, OLD_ML_MODEL_INDEX_MAPPING_V0); + } + + public void testInitMLTaskIndexIfAbsentWithExistingIndex() throws ExecutionException, InterruptedException, IOException { + testInitMLIndexIfAbsentWithExistingIndex(ML_TASK_INDEX, OLD_ML_TASK_INDEX_MAPPING_V0); + } + + private void testInitMLIndexIfAbsentWithExistingIndex(String indexName, String oldIndexMapping) throws ExecutionException, + InterruptedException, + IOException { + mlIndicesHandler + .shouldUpdateIndex( + indexName, + 1, + ActionListener.wrap(shouldUpdate -> { assertFalse(shouldUpdate); }, e -> { throw new RuntimeException(e); }) + ); + CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(oldIndexMapping); + client.admin().indices().create(request).get(); + mlIndicesHandler + .shouldUpdateIndex( + indexName, + 1, + ActionListener.wrap(shouldUpdate -> { assertTrue(shouldUpdate); }, e -> { throw new RuntimeException(e); }) + ); + assertNull(getIndexSchemaVersion(indexName)); + ActionListener listener = ActionListener.wrap(r -> { + assertTrue(r); + Integer indexSchemaVersion = getIndexSchemaVersion(indexName); + if (indexSchemaVersion != null) { + assertEquals(1, indexSchemaVersion.intValue()); + mlIndicesHandler + .shouldUpdateIndex( + indexName, + 1, + ActionListener.wrap(shouldUpdate -> { assertFalse(shouldUpdate); }, e -> { throw new RuntimeException(e); }) + ); + } + }, e -> { throw new RuntimeException(e); }); + mlIndicesHandler.initModelIndexIfAbsent(listener); + } + + public void testInitMLModelIndexIfAbsentWithNonExistingIndex() { + ActionListener listener = ActionListener.wrap(r -> { assertTrue(r); }, e -> { throw new RuntimeException(e); }); + mlIndicesHandler.initModelIndexIfAbsent(listener); + } + + public void testInitMLModelIndexIfAbsentWithNonExistingIndex_Exception() { + Client mockClient = mock(Client.class); + Object[] objects = setUpMockClient(mockClient); + IndicesAdminClient adminClient = (IndicesAdminClient) objects[0]; + MLIndicesHandler mlIndicesHandler = (MLIndicesHandler) objects[1]; + String errorMessage = "test exception"; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException(errorMessage)); + return null; + }).when(adminClient).create(any(), any()); + ActionListener listener = ActionListener + .wrap(r -> { throw new RuntimeException("unexpected result"); }, e -> { assertEquals(errorMessage, e.getMessage()); }); + mlIndicesHandler.initModelIndexIfAbsent(listener); + + when(mockClient.threadPool()).thenThrow(new RuntimeException(errorMessage)); + mlIndicesHandler.initModelIndexIfAbsent(listener); + } + + public void testInitMLModelIndexIfAbsentWithNonExistingIndex_FalseAcknowledge() { + Client mockClient = mock(Client.class); + Object[] objects = setUpMockClient(mockClient); + IndicesAdminClient adminClient = (IndicesAdminClient) objects[0]; + MLIndicesHandler mlIndicesHandler = (MLIndicesHandler) objects[1]; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + CreateIndexResponse response = new CreateIndexResponse(false, false, ML_MODEL_INDEX); + actionListener.onResponse(response); + return null; + }).when(adminClient).create(any(), any()); + ActionListener listener = ActionListener.wrap(r -> { assertFalse(r); }, e -> { throw new RuntimeException(e); }); + mlIndicesHandler.initModelIndexIfAbsent(listener); + } + + private Object[] setUpMockClient(Client mockClient) { + AdminClient admin = spy(client.admin()); + when(mockClient.admin()).thenReturn(admin); + IndicesAdminClient adminClient = spy(client.admin().indices()); + + MLIndicesHandler mlIndicesHandler = new MLIndicesHandler(clusterService, mockClient); + when(admin.indices()).thenReturn(adminClient); + + when(mockClient.threadPool()).thenReturn(client.threadPool()); + + return new Object[] { adminClient, mlIndicesHandler }; + } + + private Integer getIndexSchemaVersion(String indexName) { + IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); + if (indexMetaData == null) { + return null; + } + Integer oldVersion = null; + Map indexMapping = indexMetaData.mapping().getSourceAsMap(); + Object meta = indexMapping.get(META); + if (meta != null && meta instanceof Map) { + Map metaMapping = (Map) meta; + Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD); + if (schemaVersion instanceof Integer) { + oldVersion = (Integer) schemaVersion; + } + } + return oldVersion; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java index 51c49b446b..571fbc9e6e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java @@ -12,7 +12,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.TestHelper.getSearchAllRestRequest; import java.io.IOException; diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index 2fb4cc48c6..f6a0e56760 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -13,7 +13,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.TestHelper.getStatsRestRequest; import java.io.IOException; diff --git a/plugin/src/test/java/org/opensearch/ml/utils/IntegTestUtils.java b/plugin/src/test/java/org/opensearch/ml/utils/IntegTestUtils.java index eb50aa8262..f574f4855c 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/IntegTestUtils.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/IntegTestUtils.java @@ -5,6 +5,8 @@ package org.opensearch.ml.utils; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -52,7 +54,6 @@ import org.opensearch.test.OpenSearchIntegTestCase; public class IntegTestUtils extends OpenSearchIntegTestCase { - public static final String ML_MODEL = ".plugins-ml-model"; public static final String TESTING_DATA = "{\n" + "\"k1\":1.1,\n" + "\"k2\":1.2,\n" @@ -145,7 +146,7 @@ public static SearchResponse waitModelAvailable1(String taskId) throws Interrupt SearchSourceBuilder modelSearchSourceBuilder = new SearchSourceBuilder(); QueryBuilder queryBuilder = QueryBuilders.termQuery("taskId", taskId); modelSearchSourceBuilder.query(queryBuilder); - SearchRequest modelSearchRequest = new SearchRequest(new String[] { ML_MODEL }, modelSearchSourceBuilder); + SearchRequest modelSearchRequest = new SearchRequest(new String[] { ML_MODEL_INDEX }, modelSearchSourceBuilder); SearchResponse modelSearchResponse = null; int i = 0; while ((modelSearchResponse == null || modelSearchResponse.getHits().getTotalHits().value == 0) && i < 500) { diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 329cff997d..df8b7e4ac3 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -7,12 +7,15 @@ import static org.apache.http.entity.ContentType.APPLICATION_JSON; import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.cluster.node.DiscoveryNodeRole.DATA_ROLE; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.net.InetAddress; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -25,16 +28,25 @@ import org.apache.http.entity.StringEntity; import org.apache.http.nio.entity.NStringEntity; import org.apache.logging.log4j.util.Strings; +import org.opensearch.Version; import org.opensearch.client.Request; import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.client.WarningsHandler; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.ToXContent; @@ -239,4 +251,66 @@ private static NamedXContentRegistry getXContentRegistry() { entries.add(LocalSampleCalculatorInput.XCONTENT_REGISTRY); return new NamedXContentRegistry(entries); } + + public static ClusterState state( + ClusterName name, + String indexName, + String mapping, + DiscoveryNode localNode, + DiscoveryNode clusterManagerNode, + List allNodes + ) throws IOException { + DiscoveryNodes.Builder discoBuilder = DiscoveryNodes.builder(); + for (DiscoveryNode node : allNodes) { + discoBuilder.add(node); + } + if (clusterManagerNode != null) { + discoBuilder.masterNodeId(clusterManagerNode.getId()); + } + discoBuilder.localNodeId(localNode.getId()); + + Settings indexSettings = Settings + .builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .build(); + final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID"); + IndexMetadata indexMetaData = IndexMetadata.builder(indexName).settings(existingSettings).putMapping(mapping).build(); + + final ImmutableOpenMap indices = ImmutableOpenMap + .builder() + .fPut(indexName, indexMetaData) + .build(); + ClusterState clusterState = ClusterState.builder(name).metadata(Metadata.builder().indices(indices).build()).build(); + + return clusterState; + } + + public static ClusterState state(int numDataNodes, String indexName, String mapping) throws IOException { + DiscoveryNode clusterManagerNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + List allNodes = new ArrayList<>(); + allNodes.add(clusterManagerNode); + for (int i = 1; i <= numDataNodes - 1; i++) { + allNodes + .add( + new DiscoveryNode( + "foo" + i, + "foo" + i, + new TransportAddress(InetAddress.getLoopbackAddress(), 9300 + i), + Collections.emptyMap(), + Collections.singleton(DATA_ROLE), + Version.CURRENT + ) + ); + } + return state(new ClusterName("test"), indexName, mapping, clusterManagerNode, clusterManagerNode, allNodes); + } }