Skip to content

Commit

Permalink
Move mappers to separate files (#448) (#450)
Browse files Browse the repository at this point in the history
* Move mappers to separate files

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] authored Jul 21, 2022
1 parent cd669d0 commit 9b06925
Show file tree
Hide file tree
Showing 23 changed files with 312 additions and 235 deletions.
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index;

import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.knn.index.KNNVectorFieldMapper;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.common.KNNConstants;

import java.io.Closeable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,19 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;
package org.opensearch.knn.index.mapper;

import lombok.Getter;
import org.opensearch.common.Strings;
import org.opensearch.common.ValidationException;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.common.KNNConstants;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.Explicit;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
Expand All @@ -36,9 +31,11 @@
import org.opensearch.index.mapper.ValueFetcher;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

Expand All @@ -47,17 +44,10 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;

/**
* Field Mapper for KNN vector type.
Expand All @@ -69,8 +59,6 @@
*/
public abstract class KNNVectorFieldMapper extends ParametrizedFieldMapper {

private static Logger logger = LogManager.getLogger(KNNVectorFieldMapper.class);

public static final String CONTENT_TYPE = "knn_vector";
public static final String KNN_FIELD = "knn_field";

Expand Down Expand Up @@ -99,11 +87,13 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
}
int value = XContentMapValues.nodeIntegerValue(o);
if (value > MAX_DIMENSION) {
throw new IllegalArgumentException("Dimension value cannot be greater than " + MAX_DIMENSION + " for vector: " + name);
throw new IllegalArgumentException(
String.format("Dimension value cannot be greater than %s for vector: %s", MAX_DIMENSION, name)
);
}

if (value <= 0) {
throw new IllegalArgumentException("Dimension value must be greater than 0 " + "for vector: " + name);
throw new IllegalArgumentException(String.format("Dimension value must be greater than 0 for vector: %s", name));
}
return value;
}, m -> toType(m).dimension);
Expand Down Expand Up @@ -285,12 +275,12 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
// is done before any mappers are built. Therefore, validation should be done during parsing
// so that it can fail early.
if (builder.knnMethodContext.get() != null && builder.modelId.get() != null) {
throw new IllegalArgumentException("Method and model can not be both specified in the mapping: " + name);
throw new IllegalArgumentException(String.format("Method and model can not be both specified in the mapping: %s", name));
}

// Dimension should not be null unless modelId is used
if (builder.dimension.getValue() == -1 && builder.modelId.get() == null) {
throw new IllegalArgumentException("Dimension value missing for vector: " + name);
throw new IllegalArgumentException(String.format("Dimension value missing for vector: %s", name));
}

return builder;
Expand Down Expand Up @@ -337,7 +327,7 @@ public Query existsQuery(QueryShardContext context) {
public Query termQuery(Object value, QueryShardContext context) {
throw new QueryShardException(
context,
"KNN vector do not support exact searching, use KNN queries " + "instead: [" + name() + "]"
String.format("KNN vector do not support exact searching, use KNN queries instead: [%s]", name())
);
}

Expand Down Expand Up @@ -392,16 +382,39 @@ protected void parseCreateField(ParseContext context) throws IOException {

protected void parseCreateField(ParseContext context, int dimension) throws IOException {

if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable " + "update knn.plugin.enabled setting to true");
validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

Optional<float[]> arrayOptional = getFloatsFromContext(context, dimension);

if (!arrayOptional.isPresent()) {
return;
}
final float[] array = arrayOptional.get();
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
if (fieldType.stored()) {
context.doc().add(new StoredField(name(), point.toString()));
}
context.path().remove();
}

void validateIfCircuitBreakerIsNotTriggered() {
if (KNNSettings.isCircuitBreakerTriggered()) {
throw new IllegalStateException(
"Indexing knn vector fields is rejected as circuit breaker triggered." + " Check _opendistro/_knn/stats for detailed state"
"Indexing knn vector fields is rejected as circuit breaker triggered. Check _opendistro/_knn/stats for detailed state"
);
}
}

/**
 * Guards an operation behind the plugin-enabled check.
 *
 * @throws IllegalStateException if {@code KNNSettings.isKNNPluginEnabled()} reports that the plugin is disabled
 */
void validateIfKNNPluginEnabled() {
    if (KNNSettings.isKNNPluginEnabled()) {
        return;
    }
    throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled setting to true");
}

Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
context.path().add(simpleName());

ArrayList<Float> vector = new ArrayList<>();
Expand Down Expand Up @@ -438,7 +451,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
context.path().remove();
return;
return Optional.empty();
}

if (dimension != vector.size()) {
Expand All @@ -451,14 +464,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
for (Float f : vector) {
array[i++] = f;
}

VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
if (fieldType.stored()) {
context.doc().add(new StoredField(name(), point.toString()));
}
context.path().remove();
return Optional.of(array);
}

@Override
Expand Down Expand Up @@ -505,187 +511,4 @@ public static class Defaults {
FIELD_TYPE.freeze();
}
}

/**
 * Field mapper for the original (legacy) implementation.
 *
 * The space type and the HNSW {@code m} / {@code ef_construction} values are carried as strings and
 * written into the Lucene field type as attributes. The static helpers at the bottom read those
 * values from index settings and fall back to the plugin defaults when a setting is absent.
 */
protected static class LegacyFieldMapper extends KNNVectorFieldMapper {

    protected String spaceType;
    protected String m;
    protected String efConstruction;

    private LegacyFieldMapper(
        String simpleName,
        KNNVectorFieldType mappedFieldType,
        MultiFields multiFields,
        CopyTo copyTo,
        Explicit<Boolean> ignoreMalformed,
        boolean stored,
        boolean hasDocValues,
        String spaceType,
        String m,
        String efConstruction
    ) {
        super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues);

        this.spaceType = spaceType;
        this.m = m;
        this.efConstruction = efConstruction;

        // Assemble the attributes on a local copy of the defaults, then freeze and publish it.
        final FieldType legacyFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);
        legacyFieldType.putAttribute(DIMENSION, String.valueOf(dimension));
        legacyFieldType.putAttribute(SPACE_TYPE, spaceType);
        // The legacy mapper always records NMSLIB as its engine.
        legacyFieldType.putAttribute(KNN_ENGINE, KNNEngine.NMSLIB.getName());

        // Attributes that only the legacy mapper writes
        legacyFieldType.putAttribute(HNSW_ALGO_M, m);
        legacyFieldType.putAttribute(HNSW_ALGO_EF_CONSTRUCTION, efConstruction);

        legacyFieldType.freeze();
        this.fieldType = legacyFieldType;
    }

    /**
     * @return a builder seeded with this mapper's name and legacy parameters, initialised from this mapper
     */
    @Override
    public ParametrizedFieldMapper.Builder getMergeBuilder() {
        final KNNVectorFieldMapper.Builder mergeBuilder = new KNNVectorFieldMapper.Builder(
            simpleName(),
            this.spaceType,
            this.m,
            this.efConstruction
        );
        return mergeBuilder.init(this);
    }

    /**
     * Reads the space type from the index settings.
     *
     * @param indexSettings settings to look the value up in
     * @return the configured space type, or {@code KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE} (after logging) when unset
     */
    static String getSpaceType(Settings indexSettings) {
        final String configured = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey());
        if (configured != null) {
            return configured;
        }

        logger.info(
            "[KNN] The setting \""
                + METHOD_PARAMETER_SPACE_TYPE
                + "\" was not set for the index. "
                + "Likely caused by recent version upgrade. Setting the setting to the default value="
                + KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE
        );
        return KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
    }

    /**
     * Reads the HNSW {@code m} parameter from the index settings.
     *
     * @param indexSettings settings to look the value up in
     * @return the configured value, or the string form of {@code KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M}
     *         (after logging) when unset
     */
    static String getM(Settings indexSettings) {
        final String configured = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey());
        if (configured != null) {
            return configured;
        }

        logger.info(
            "[KNN] The setting \""
                + HNSW_ALGO_M
                + "\" was not set for the index. "
                + "Likely caused by recent version upgrade. Setting the setting to the default value="
                + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M
        );
        return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M);
    }

    /**
     * Reads the HNSW {@code ef_construction} parameter from the index settings.
     *
     * @param indexSettings settings to look the value up in
     * @return the configured value, or the string form of
     *         {@code KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION} (after logging) when unset
     */
    static String getEfConstruction(Settings indexSettings) {
        final String configured = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey());
        if (configured != null) {
            return configured;
        }

        logger.info(
            "[KNN] The setting \""
                + HNSW_ALGO_EF_CONSTRUCTION
                + "\" was not set for"
                + " the index. Likely caused by recent version upgrade. Setting the setting to the default value="
                + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION
        );
        return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION);
    }
}

/**
 * Field mapper for a method definition in the mapping.
 *
 * The constructor copies the default field type and records, as attributes, the dimension, the
 * space type and engine name taken from the supplied {@link KNNMethodContext}, and the engine's
 * map representation of the method serialised as JSON. The field type is frozen afterwards.
 */
protected static class MethodFieldMapper extends KNNVectorFieldMapper {

    /**
     * @param knnMethodContext method configuration; its space type and engine are read here
     * @throws RuntimeException if the method parameters cannot be serialised to JSON; the
     *         underlying {@link IOException} is attached as the cause
     */
    private MethodFieldMapper(
        String simpleName,
        KNNVectorFieldType mappedFieldType,
        MultiFields multiFields,
        CopyTo copyTo,
        Explicit<Boolean> ignoreMalformed,
        boolean stored,
        boolean hasDocValues,
        KNNMethodContext knnMethodContext
    ) {

        super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues);

        this.knnMethod = knnMethodContext;

        this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);

        this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension));
        this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue());

        KNNEngine knnEngine = knnMethodContext.getKnnEngine();
        this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName());

        try {
            this.fieldType.putAttribute(
                PARAMETERS,
                Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext)))
            );
        } catch (IOException ioe) {
            // Chain the original exception so its stack trace is not lost; the message text is unchanged.
            throw new RuntimeException("Unable to create KNNVectorFieldMapper: " + ioe, ioe);
        }

        this.fieldType.freeze();
    }
}

/**
 * Field mapper for a model referenced in the mapping.
 *
 * Only the model id is stored on the field type; the dimension is looked up from the model's
 * metadata each time a field is parsed.
 */
protected static class ModelFieldMapper extends KNNVectorFieldMapper {

    private ModelFieldMapper(
        String simpleName,
        KNNVectorFieldType mappedFieldType,
        MultiFields multiFields,
        CopyTo copyTo,
        Explicit<Boolean> ignoreMalformed,
        boolean stored,
        boolean hasDocValues,
        ModelDao modelDao,
        String modelId
    ) {
        super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues);

        this.modelDao = modelDao;
        this.modelId = modelId;

        // Build on a local copy of the defaults, freeze it, then publish it.
        final FieldType modelFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);
        modelFieldType.putAttribute(MODEL_ID, modelId);
        modelFieldType.freeze();
        this.fieldType = modelFieldType;
    }

    /**
     * Resolves the model's metadata and delegates to the dimension-aware overload.
     *
     * @throws IllegalStateException if no metadata exists for the configured model id
     */
    @Override
    protected void parseCreateField(ParseContext context) throws IOException {
        // The model cannot be validated at index creation because of an issue with reading cluster
        // state while mappers are being created, so the check happens here, when ingestion starts.
        final ModelMetadata metadata = this.modelDao.getMetadata(modelId);

        if (metadata != null) {
            parseCreateField(context, metadata.getDimension());
            return;
        }

        final String indexName = context.mapperService().index().getName();
        throw new IllegalStateException(
            "Model \""
                + modelId
                + "\" from "
                + indexName
                + "'s mapping does not exist. Because the "
                + "\""
                + MODEL_ID
                + "\" parameter is not updateable, this index will need to "
                + "be recreated with a valid model."
        );
    }
}
}
Loading

0 comments on commit 9b06925

Please sign in to comment.