Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce mode and compression param resolution #2034

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions release-notes/opensearch-knn.release-notes-2.17.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Compatible with OpenSearch 2.17.0
* k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984)
* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
* Add support for byte vector with Faiss Engine IVF algorithm [#2002](https://github.com/opensearch-project/k-NN/pull/2002)
* Add mode/compression configuration support for disk-based vector search [#2034](https://github.com/opensearch-project/k-NN/pull/2034)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines [#1950](https://github.com/opensearch-project/k-NN/pull/1950)
### Bug Fixes
Expand Down
8 changes: 7 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import org.opensearch.common.lucene.Lucene;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
Expand Down Expand Up @@ -182,7 +184,11 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
shardPath,
spaceType,
modelId,
VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()))
FieldInfoExtractor.extractQuantizationConfig(fieldInfo) == QuantizationConfig.EMPTY
? VectorDataType.get(
fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())
)
: VectorDataType.BINARY
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,45 +88,41 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr
String quantizationStateFileName = getQuantizationStateFileName(segmentReadState);
int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber();

try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) {
CodecUtil.retrieveChecksum(input);
int numFields = getNumFields(input);

long position = -1;
int length = 0;

// Read each field's metadata from the index section, break when correct field is found
for (int i = 0; i < numFields; i++) {
int tempFieldNumber = input.readInt();
int tempLength = input.readInt();
long tempPosition = input.readVLong();
if (tempFieldNumber == fieldNumber) {
position = tempPosition;
length = tempLength;
break;
}
IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ);
CodecUtil.retrieveChecksum(input);
int numFields = getNumFields(input);

long position = -1;
int length = 0;

// Read each field's metadata from the index section, break when correct field is found
for (int i = 0; i < numFields; i++) {
int tempFieldNumber = input.readInt();
int tempLength = input.readInt();
long tempPosition = input.readVLong();
if (tempFieldNumber == fieldNumber) {
position = tempPosition;
length = tempLength;
break;
}
}

if (position == -1 || length == 0) {
throw new IllegalArgumentException(String.format("Field %s not found", field));
}
if (position == -1 || length == 0) {
throw new IllegalArgumentException(String.format("Field %s not found", field));
}

byte[] stateBytes = readStateBytes(input, position, length);

// Deserialize the byte array to a quantization state object
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
switch (scalarQuantizationType) {
case ONE_BIT:
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
case TWO_BIT:
case FOUR_BIT:
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
default:
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
}
} catch (Exception e) {
log.warn(String.format("Unable to read the quantization state file for segment %s", segmentReadState.segmentInfo.name), e);
return null;
byte[] stateBytes = readStateBytes(input, position, length);

// Deserialize the byte array to a quantization state object
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
switch (scalarQuantizationType) {
case ONE_BIT:
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
case TWO_BIT:
case FOUR_BIT:
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
default:
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
import org.apache.lucene.util.IOUtils;
import org.opensearch.common.UUIDs;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;
Expand All @@ -50,8 +46,8 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader {

public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException {
this.segmentReadState = state;
primeQuantizationStateCache();
this.flatVectorsReader = flatVectorsReader;
primeQuantizationStateCache();
}

/**
Expand Down Expand Up @@ -197,28 +193,9 @@ public long ramBytesUsed() {

private void primeQuantizationStateCache() throws IOException {
quantizationStateCacheKeyPerField = new HashMap<>();
Map<String, byte[]> stateMap = KNN990QuantizationStateReader.read(segmentReadState);
for (Map.Entry<String, byte[]> entry : stateMap.entrySet()) {
FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(entry.getKey());
QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo);
if (quantizationParams instanceof ScalarQuantizationParams) {
QuantizationState quantizationState;
ScalarQuantizationParams scalarQuantizationParams = (ScalarQuantizationParams) quantizationParams;
switch (scalarQuantizationParams.getSqType()) {
case ONE_BIT:
quantizationState = OneBitScalarQuantizationState.fromByteArray(entry.getValue());
break;
case TWO_BIT:
case FOUR_BIT:
quantizationState = MultiBitScalarQuantizationState.fromByteArray(entry.getValue());
break;
default:
throw new IllegalArgumentException("Unknown Scalar Quantization Type");
}
String cacheKey = UUIDs.base64UUID();
quantizationStateCacheKeyPerField.put(entry.getKey(), cacheKey);
quantizationStateCacheManager.addQuantizationState(cacheKey, quantizationState);
}
for (FieldInfo fieldInfo : segmentReadState.fieldInfos) {
String cacheKey = UUIDs.base64UUID();
quantizationStateCacheKeyPerField.put(fieldInfo.getName(), cacheKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
Expand Down Expand Up @@ -255,7 +257,12 @@ private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, Model mod
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID));
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob());
IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
if (FieldInfoExtractor.extractQuantizationConfig(fieldInfo) != QuantizationConfig.EMPTY) {
IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
} else {
IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
}

return parameters;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

/**
* This object provides additional context that the user does not provide when {@link KNNMethodContext} is
Expand All @@ -27,5 +29,10 @@ public final class KNNMethodConfigContext {
private VectorDataType vectorDataType;
private Integer dimension;
private Version versionCreated;
@Builder.Default
private Mode mode = Mode.NOT_CONFIGURED;
@Builder.Default
private CompressionLevel compressionLevel = CompressionLevel.NOT_CONFIGURED;

public static final KNNMethodConfigContext EMPTY = KNNMethodConfigContext.builder().build();
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.core.common.Strings;
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.util.Arrays;
import java.util.Locale;
import java.util.stream.Collectors;

/**
* Enum representing the compression level for float vectors. Compression in this sense refers to compressing a
Expand All @@ -20,20 +19,23 @@
*/
@AllArgsConstructor
public enum CompressionLevel {
NOT_CONFIGURED(-1, ""),
x1(1, "1x"),
x2(2, "2x"),
x4(4, "4x"),
x8(8, "8x"),
x16(16, "16x"),
x32(32, "32x");
NOT_CONFIGURED(-1, "", null),
x1(1, "1x", null),
x2(2, "2x", null),
x4(4, "4x", new RescoreContext(1.0f)),
x8(8, "8x", new RescoreContext(1.5f)),
x16(16, "16x", new RescoreContext(2.0f)),
x32(32, "32x", new RescoreContext(2.0f));

// Internally, an empty string is easier to deal with than null. However, from the mapping,
// we do not want users to pass in the empty string and instead want null. So we make the conversion here
static final String[] NAMES_ARRAY = Arrays.stream(CompressionLevel.values())
.map(compressionLevel -> compressionLevel == NOT_CONFIGURED ? null : compressionLevel.getName())
.collect(Collectors.toList())
.toArray(new String[0]);
public static final String[] NAMES_ARRAY = new String[] {
NOT_CONFIGURED.getName(),
x1.getName(),
x2.getName(),
x8.getName(),
x16.getName(),
x32.getName() };

/**
* Default is set to 1x and is a noop
Expand Down Expand Up @@ -62,6 +64,8 @@ public static CompressionLevel fromName(String name) {
private final int compressionLevel;
@Getter
private final String name;
@Getter
private final RescoreContext defaultRescoreContext;

/**
* Gets the number of bits used to represent a float in order to achieve this compression. For instance, for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ default Optional<KNNMethodContext> getKnnMethodContext() {
return Optional.empty();
}

/**
* Return the mode to be used for this field
*
* @return {@link Mode}
*/
default Mode getMode() {
return Mode.NOT_CONFIGURED;
}

/**
* Return compression level to be used for this field
*
* @return {@link CompressionLevel}
*/
default CompressionLevel getCompressionLevel() {
return CompressionLevel.NOT_CONFIGURED;
}

/**
*
* @return the dimension of the index; for model based indices, it will be null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
b.startObject(n);
v.toXContent(b, ToXContent.EMPTY_PARAMS);
b.endObject();
}), m -> m.getMethodComponentContext().getName()).setValidator(v -> {
if (v == null) return;

ValidationException validationException;
if (v.isTrainingRequired()) {
validationException = new ValidationException();
validationException.addValidationError(String.format(Locale.ROOT, "\"%s\" requires training.", KNN_METHOD));
throw validationException;
}
});
}), m -> m.getMethodComponentContext().getName());

protected final Parameter<String> mode = Parameter.restrictedStringParam(
KNNConstants.MODE_PARAMETER,
Expand Down Expand Up @@ -354,13 +345,34 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
} else if (builder.modelId.get() != null) {
validateFromModel(builder);
} else {
validateMode(builder);
resolveKNNMethodComponents(builder, parserContext);
validateFromKNNMethod(builder);
}

return builder;
}

private void validateMode(KNNVectorFieldMapper.Builder builder) {
boolean isKNNMethodContextConfigured = builder.originalParameters.getKnnMethodContext() != null;
boolean isModeConfigured = builder.mode.isConfigured() || builder.compressionLevel.isConfigured();
if (isModeConfigured && isKNNMethodContextConfigured) {
throw new MapperParsingException(
String.format(
Locale.ROOT,
"Compression and mode can not be specified in a \"method\" mapping configuration for field: %s",
builder.name
)
);
}

if (isModeConfigured && builder.vectorDataType.getValue() != VectorDataType.FLOAT) {
throw new MapperParsingException(
String.format(Locale.ROOT, "Compression and mode cannot be used for non-float32 data type for field %s", builder.name)
);
}
}

private void validateFromFlat(KNNVectorFieldMapper.Builder builder) {
if (builder.modelId.get() != null || builder.knnMethodContext.get() != null) {
throw new IllegalArgumentException("Cannot set modelId or method parameters when index.knn setting is false");
Expand All @@ -378,9 +390,15 @@ private void validateFromModel(KNNVectorFieldMapper.Builder builder) {
}

private void validateFromKNNMethod(KNNVectorFieldMapper.Builder builder) {
ValidationException validationException;
if (builder.originalParameters.getResolvedKnnMethodContext().isTrainingRequired()) {
validationException = new ValidationException();
validationException.addValidationError(String.format(Locale.ROOT, "\"%s\" requires training.", KNN_METHOD));
throw validationException;
}

if (builder.originalParameters.getResolvedKnnMethodContext() != null) {
ValidationException validationException = builder.originalParameters.getResolvedKnnMethodContext()
.validate(builder.knnMethodConfigContext);
validationException = builder.originalParameters.getResolvedKnnMethodContext().validate(builder.knnMethodConfigContext);
if (validationException != null) {
throw validationException;
}
Expand Down Expand Up @@ -410,9 +428,11 @@ private void validateCompressionAndModeNotSet(KNNVectorFieldMapper.Builder build
private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, ParserContext parserContext) {
builder.setKnnMethodConfigContext(
KNNMethodConfigContext.builder()
.vectorDataType(builder.vectorDataType.getValue())
.vectorDataType(builder.originalParameters.getVectorDataType())
.versionCreated(parserContext.indexVersionCreated())
.dimension(builder.dimension.getValue())
.dimension(builder.originalParameters.getDimension())
.mode(Mode.fromName(builder.originalParameters.getMode()))
.compressionLevel(CompressionLevel.fromName(builder.originalParameters.getCompressionLevel()))
.build()
);

Expand All @@ -421,8 +441,17 @@ private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, Pa
builder.originalParameters.setResolvedKnnMethodContext(
createKNNMethodContextFromLegacy(parserContext.getSettings(), parserContext.indexVersionCreated())
);
}
setDefaultSpaceType(builder.originalParameters.getResolvedKnnMethodContext(), builder.vectorDataType.getValue());
} else if (Mode.isConfigured(Mode.fromName(builder.mode.get()))
|| CompressionLevel.isConfigured(CompressionLevel.fromName(builder.compressionLevel.get()))) {
builder.originalParameters.setResolvedKnnMethodContext(
ModeBasedResolver.INSTANCE.resolveKNNMethodContext(
builder.knnMethodConfigContext.getMode(),
builder.knnMethodConfigContext.getCompressionLevel(),
false
)
);
}
setDefaultSpaceType(builder.originalParameters.getResolvedKnnMethodContext(), builder.originalParameters.getVectorDataType());
}

private boolean isKNNDisabled(Settings settings) {
Expand Down
Loading
Loading