
Index routing tool #110

Draft · wants to merge 16 commits into main
7 changes: 4 additions & 3 deletions build.gradle
@@ -111,6 +111,7 @@ dependencies {
compileOnly("com.google.guava:guava:33.0.0-jre")
compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
implementation 'com.knuddels:jtokkit:0.6.1'

// Plugin dependencies
compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}"
@@ -119,16 +120,16 @@ dependencies {
implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"])
compileOnly "org.opensearch:common-utils:${version}"
compileOnly "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}"
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${version}"
implementation "org.opensearch:opensearch-job-scheduler-spi:${version}"


// ZipArchive dependencies used for integration tests
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${version}"
zipArchive "org.opensearch.plugin:opensearch-anomaly-detection:${version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${version}"
zipArchive group: 'org.opensearch.plugin', name:'neural-search', version: "${version}"
// zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${version}"
// zipArchive group: 'org.opensearch.plugin', name:'neural-search', version: "${version}"

// Test dependencies
testImplementation "org.opensearch.test:framework:${opensearch_version}"
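The new com.knuddels:jtokkit dependency above adds a BPE tokenizer. Below is a minimal, self-contained sketch of counting tokens with jtokkit, assuming it is used to keep index summaries within a token budget before they are sent for embedding; the encoding choice and sample text are illustrative and not taken from this change.

    import com.knuddels.jtokkit.Encodings;
    import com.knuddels.jtokkit.api.Encoding;
    import com.knuddels.jtokkit.api.EncodingRegistry;
    import com.knuddels.jtokkit.api.EncodingType;

    public class TokenBudgetSketch {
        public static void main(String[] args) {
            // The default registry is thread-safe and can be built once and shared.
            EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
            Encoding encoding = registry.getEncoding(EncodingType.CL100K_BASE);

            // Illustrative summary text; in the plugin this would be mapping plus sample-data text.
            String indexSummary = "host: keyword, timestamp: date, message: text";
            int tokenCount = encoding.countTokens(indexSummary);
            System.out.println("summary consumes " + tokenCount + " tokens");
        }
    }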
96 changes: 92 additions & 4 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
@@ -5,11 +5,17 @@

package org.opensearch.agent;

import static org.opensearch.agent.job.Constants.INDEX_SUMMARY_JOB_THREAD_POOL;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;

import org.opensearch.agent.indices.IndicesHelper;
import org.opensearch.agent.job.IndexSummaryEmbeddingJob;
import org.opensearch.agent.job.MLClients;
import org.opensearch.agent.job.SkillsClusterStateEventListener;
import org.opensearch.agent.tools.IndexRoutingTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.RAGTool;
@@ -23,25 +29,41 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.index.IndexModule;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexingOperationListener;
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.jobscheduler.spi.utils.LockService;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.spi.MLCommonsExtension;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SystemIndexPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import lombok.SneakyThrows;

public class ToolPlugin extends Plugin implements MLCommonsExtension {
public class ToolPlugin extends Plugin implements MLCommonsExtension, SystemIndexPlugin {

private Client client;
private ClusterService clusterService;
private NamedXContentRegistry xContentRegistry;
private MLClients mlClients;

private SkillsClusterStateEventListener clusterStateEventListener;

@SneakyThrows
@Override
@@ -62,6 +84,19 @@ public Collection<Object> createComponents(
this.clusterService = clusterService;
this.xContentRegistry = xContentRegistry;

mlClients = new MLClients(client, xContentRegistry, clusterService);
IndicesHelper indicesHelper = new IndicesHelper(clusterService, client, mlClients);
LockService lockService = new LockService(client, clusterService);
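// The cluster-state listener below schedules the index summary embedding job using these shared components.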
clusterStateEventListener = new SkillsClusterStateEventListener(
clusterService,
client,
environment.settings(),
threadPool,
indicesHelper,
mlClients,
lockService
);

PPLTool.Factory.getInstance().init(client);
VisualizationsTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
@@ -72,7 +107,9 @@
SearchAnomalyDetectorsTool.Factory.getInstance().init(client);
SearchAnomalyResultsTool.Factory.getInstance().init(client);
SearchMonitorsTool.Factory.getInstance().init(client);
return Collections.emptyList();
IndexRoutingTool.Factory.getInstance().init(client, xContentRegistry, clusterService);

return List.of(clusterStateEventListener);
}

@Override
@@ -88,7 +125,58 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchAlertsTool.Factory.getInstance(),
SearchAnomalyDetectorsTool.Factory.getInstance(),
SearchAnomalyResultsTool.Factory.getInstance(),
SearchMonitorsTool.Factory.getInstance()
SearchMonitorsTool.Factory.getInstance(),
IndexRoutingTool.Factory.getInstance()
);
}

@Override
public List<Setting<?>> getSettings() {
return List
.of(
SkillsClusterStateEventListener.SKILLS_INDEX_SUMMARY_JOB_INTERVAL,
SkillsClusterStateEventListener.SKILLS_INDEX_SUMMARY_JOB_ENABLED
);
}

@Override
public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings settings) {
return List
.of(
new SystemIndexDescriptor(
IndexSummaryEmbeddingJob.INDEX_SUMMARY_EMBEDDING_INDEX,
"System index for storing index metadata and sample data embeddings"
)
);
}

@Override
public void onIndexModule(IndexModule indexModule) {
if (indexModule.getIndex().getName().equals(CommonValue.ML_AGENT_INDEX)) {
// watch for newly created agents
indexModule.addIndexOperationListener(new IndexingOperationListener() {
@Override
public void postIndex(ShardId shardId, Engine.Index index, Engine.IndexResult result) {
if (result.isCreated()) {
clusterStateEventListener.onNewAgentCreated(index.id());
}
}
});

}
super.onIndexModule(indexModule);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
FixedExecutorBuilder indexSummaryJobThreadPool = new FixedExecutorBuilder(
settings,
INDEX_SUMMARY_JOB_THREAD_POOL,
OpenSearchExecutors.allocatedProcessors(settings) * 2,
100,
INDEX_SUMMARY_JOB_THREAD_POOL,
false
);
return List.of(indexSummaryJobThreadPool);
}
}
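The fixed executor registered in getExecutorBuilders gives the summary job its own thread pool. Below is a minimal sketch of how a job class might dispatch work onto that pool; the dispatcher class is illustrative and not part of this change.

    import org.opensearch.threadpool.ThreadPool;

    import static org.opensearch.agent.job.Constants.INDEX_SUMMARY_JOB_THREAD_POOL;

    // Illustrative only: shows how the named pool built by the FixedExecutorBuilder would be used.
    public class SummaryJobDispatchSketch {

        private final ThreadPool threadPool;

        public SummaryJobDispatchSketch(ThreadPool threadPool) {
            this.threadPool = threadPool;
        }

        public void submit(Runnable embeddingTask) {
            // executor(name) resolves the fixed pool registered by the plugin,
            // keeping embedding work off the generic and search thread pools.
            threadPool.executor(INDEX_SUMMARY_JOB_THREAD_POOL).execute(embeddingTask);
        }
    }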
191 changes: 191 additions & 0 deletions src/main/java/org/opensearch/agent/indices/IndicesHelper.java
@@ -0,0 +1,191 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.indices;

import static org.opensearch.agent.job.IndexSummaryEmbeddingJob.INDEX_SUMMARY_EMBEDDING_FIELD_PREFIX;
import static org.opensearch.ml.common.CommonValue.META;
import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD;

import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest;
import org.opensearch.agent.job.MLClients;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.model.ModelTensorOutput;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@RequiredArgsConstructor
@Log4j2
public class IndicesHelper {

ClusterService clusterService;
Client client;
MLClients mlClients;
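// Tracks, per skills system index, whether its mapping has already been brought up to date on this node.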
private static final Map<String, AtomicBoolean> indexMappingUpdated = new HashMap<>();

static {
for (SkillsIndexEnum index : SkillsIndexEnum.values()) {
indexMappingUpdated.put(index.getIndexName(), new AtomicBoolean(false));
}
}

public void initIndexSummaryEmbeddingIndex(ActionListener<Boolean> listener) {
initIndexIfAbsent(SkillsIndexEnum.SKILLS_INDEX_SUMMARY, listener);
}

public void initIndexIfAbsent(SkillsIndexEnum skillsIndexEnum, ActionListener<Boolean> listener) {
try (
ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext();
InputStream settingIns = this.getClass().getResourceAsStream(skillsIndexEnum.getSetting());
InputStream mappingIns = this.getClass().getResourceAsStream(skillsIndexEnum.getMapping())
) {
String setting = new String(Objects.requireNonNull(settingIns).readAllBytes(), StandardCharsets.UTF_8);
String mapping = new String(Objects.requireNonNull(mappingIns).readAllBytes(), StandardCharsets.UTF_8);
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, threadContext::restore);

if (!clusterService.state().metadata().hasIndex(skillsIndexEnum.getIndexName())) {
ActionListener<CreateIndexResponse> actionListener = ActionListener.wrap(r -> {
if (r.isAcknowledged()) {
log.info("create index:{}", skillsIndexEnum.getIndexName());
internalListener.onResponse(true);
} else {
internalListener.onResponse(false);
}
}, e -> {
log.error("Failed to create index " + skillsIndexEnum, e);
internalListener.onFailure(e);
});

CreateIndexRequest request = new CreateIndexRequest(skillsIndexEnum.getIndexName())
.mapping(mapping)
.settings(setting, MediaTypeRegistry.JSON);
client.admin().indices().create(request, actionListener);
} else {
log.debug("index:{} is already created", skillsIndexEnum.getIndexName());
if (indexMappingUpdated.containsKey(skillsIndexEnum.getIndexName())
&& !indexMappingUpdated.get(skillsIndexEnum.getIndexName()).get()) {
shouldUpdateIndex(skillsIndexEnum.getIndexName(), skillsIndexEnum.getVersion(), ActionListener.wrap(r -> {
if (r) {
// r is true when the index mapping should be updated to the new schema version
client
.admin()
.indices()
.putMapping(
new PutMappingRequest().indices(skillsIndexEnum.getIndexName()).source(mapping, MediaTypeRegistry.JSON),
ActionListener.wrap(response -> {
if (response.isAcknowledged()) {
internalListener.onResponse(true);
} else {
internalListener
.onFailure(new MLException("Failed to update index mapping for: " + skillsIndexEnum));
}
}, exception -> {
log.error("Failed to update index mapping for " + skillsIndexEnum, exception);
internalListener.onFailure(exception);
})
);
} else {
// no need to update the index if it does not exist or its mapping version is already up-to-date.
indexMappingUpdated.get(skillsIndexEnum.getIndexName()).set(true);
internalListener.onResponse(true);
}
}, e -> {
log.error("Failed to update index mapping", e);
internalListener.onFailure(e);
}));
} else {
// No need to update the index if it is not a tracked system index or its mapping has already been updated.
internalListener.onResponse(true);
}
}
} catch (Exception e) {
log.error("Failed to init index " + skillsIndexEnum, e);
listener.onFailure(e);
}
}

/**
* Check whether the index mapping should be updated based on its schema version.
* @param indexName index name
* @param newVersion new index mapping version
* @param listener action listener; receives true via onResponse when the index mapping should be updated
*/
public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener<Boolean> listener) {
IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName);
if (indexMetaData == null) {
listener.onResponse(Boolean.FALSE);
return;
}
Integer oldVersion = CommonValue.NO_SCHEMA_VERSION;
Map<String, Object> indexMapping = indexMetaData.mapping().getSourceAsMap();
Object meta = indexMapping.get(META);
if (meta instanceof Map) {
@SuppressWarnings("unchecked")
Map<String, Object> metaMapping = (Map<String, Object>) meta;
Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD);
if (schemaVersion instanceof Integer) {
oldVersion = (Integer) schemaVersion;
}
}
listener.onResponse(newVersion > oldVersion);
}

public void addNewVectorField(String index, String modelId, ActionListener<Boolean> listener) {
IndexMetadata metadata = clusterService.state().metadata().index(index);
String fieldName = String.format(Locale.ROOT, "%s_%s", INDEX_SUMMARY_EMBEDDING_FIELD_PREFIX, modelId);
Map<String, Object> fieldMap = (Map<String, Object>) metadata.mapping().getSourceAsMap().get("properties");
boolean vectorFieldExists = fieldMap.containsKey(fieldName);
if (!vectorFieldExists) {
try (InputStream vectorFieldIns = this.getClass().getResourceAsStream("/vector_field.json")) {
String vectorField = new String(Objects.requireNonNull(vectorFieldIns).readAllBytes(), StandardCharsets.UTF_8);
Long dimension = getDimension(modelId);
StringSubstitutor substitutor = new StringSubstitutor(Map.of("model_id", modelId, "dimension", dimension));
PutMappingRequest request = new PutMappingRequest(index).source(substitutor.replace(vectorField), XContentType.JSON);
client.admin().indices().putMapping(request, ActionListener.wrap(r -> {
if (r.isAcknowledged()) {
listener.onResponse(true);
} else {
listener.onResponse(false);
}
}, listener::onFailure));
} catch (Exception e) {
listener.onFailure(e);
}
} else {
listener.onResponse(true);
}
}

private Long getDimension(String modelId) {
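// Infer the embedding dimension by embedding one sample sentence and reading the shape of the returned tensor.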
return mlClients.getEmbeddingResult(modelId, List.of("today is sunny"), true, mlTaskResponse -> {
ModelTensorOutput tensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
return tensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getShape()[0];
});
}
}
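A minimal usage sketch for the helper above: ensure the summary embedding index exists and its mapping is current before a job writes to it. The wrapper class below is illustrative and not part of this change; initIndexSummaryEmbeddingIndex and ActionListener.wrap are used as defined in this file.

    import org.opensearch.agent.indices.IndicesHelper;
    import org.opensearch.core.action.ActionListener;

    public class SummaryIndexBootstrapSketch {

        private final IndicesHelper indicesHelper;

        public SummaryIndexBootstrapSketch(IndicesHelper indicesHelper) {
            this.indicesHelper = indicesHelper;
        }

        public void ensureSummaryIndex(Runnable onReady) {
            indicesHelper.initIndexSummaryEmbeddingIndex(ActionListener.wrap(ready -> {
                if (ready) {
                    // the index was created now, or it already existed with an up-to-date mapping
                    onReady.run();
                }
            }, e -> {
                // in real code, log the failure and skip this job run
            }));
        }
    }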