Skip to content

Commit

Permalink
includes Count Aggregator
Browse files Browse the repository at this point in the history
Signed-off-by: Sarthak Aggarwal <sarthagg@amazon.com>
  • Loading branch information
sarthakaggarwal97 committed Jun 27, 2024
1 parent d7ddd36 commit d90486f
Show file tree
Hide file tree
Showing 21 changed files with 448 additions and 342 deletions.
195 changes: 99 additions & 96 deletions server/src/main/java/org/apache/lucene/index/BaseStarTreeBuilder.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ public <T> Void visit(Binding<T> binding) {
scopeInstancesInUse.put(scope, binding.getSource());
}
}

return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,35 @@
* @opensearch.experimental
*/
public class CountValueAggregator implements ValueAggregator<Double> {
public static final StarTreeNumericType STAR_TREE_NUMERIC_TYPE = StarTreeNumericType.DOUBLE;
public static final StarTreeNumericType VALUE_AGGREGATOR_TYPE = StarTreeNumericType.DOUBLE;

@Override
public MetricStat getAggregationType() {
return MetricStat.COUNT;
}

@Override
public StarTreeNumericType getStarTreeNumericType() {
return STAR_TREE_NUMERIC_TYPE;
public StarTreeNumericType getAggregatedValueType() {
return VALUE_AGGREGATOR_TYPE;
}

@Override
public Double getInitialAggregatedValue(Long segmentDocValue, StarTreeNumericType starTreeNumericType) {
public Double getInitialAggregatedValueForSegmentDocValue(Long segmentDocValue, StarTreeNumericType starTreeNumericType) {
return 1.0;
}

@Override
public Double applySegmentRawValue(Double value, Long segmentDocValue, StarTreeNumericType starTreeNumericType) {
public Double mergeAggregatedValueAndSegmentValue(Double value, Long segmentDocValue, StarTreeNumericType starTreeNumericType) {
return value + 1;
}

@Override
public Double applyAggregatedValue(Double value, Double aggregatedValue) {
public Double mergeAggregatedValues(Double value, Double aggregatedValue) {
return value + aggregatedValue;
}

@Override
public Double getAggregatedValue(Double value) {
public Double getInitialAggregatedValue(Double value) {
return value;
}

Expand All @@ -58,17 +58,17 @@ public int getMaxAggregatedValueByteSize() {
public Long toLongValue(Double value) {
try {
return NumericUtils.doubleToSortableLong(value);
} catch (IllegalArgumentException | NullPointerException | IllegalStateException e) {
throw new IllegalArgumentException("Cannot convert " + value + " to sortable long", e);
} catch (Exception e) {
throw new IllegalStateException("Cannot convert " + value + " to sortable long", e);
}
}

@Override
public Double toStarTreeNumericTypeValue(Long value, StarTreeNumericType type) {
try {
return type.getDoubleValue(value);
} catch (IllegalArgumentException | NullPointerException | IllegalStateException e) {
throw new IllegalArgumentException("Cannot convert " + value + " to sortable aggregation type", e);
} catch (Exception e) {
throw new IllegalStateException("Cannot convert " + value + " to sortable aggregation type", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,34 @@
* Builds aggregation function and doc values field pair to support various aggregations
* @opensearch.experimental
*/
public class MetricAggregationDescriptor implements Comparable<MetricAggregationDescriptor> {

public static final String DELIMITER = "__";
public static final String STAR = "*";
public static final MetricAggregationDescriptor COUNT_STAR = new MetricAggregationDescriptor(
MetricStat.COUNT,
STAR,
IndexNumericFieldData.NumericType.DOUBLE,
null
);

private final String metricStatName;
public class MetricAggregatorInfo implements Comparable<MetricAggregatorInfo> {

public static final String DELIMITER = "_";
private final String metric;
private final String starFieldName;
private final MetricStat metricStat;
private final String field;
private final ValueAggregator valueAggregators;
private final StarTreeNumericType starTreeNumericType;
private final DocIdSetIterator metricStatReader;

/**
* Constructor for MetricAggregationDescriptor
* Constructor for MetricAggregatorInfo
*/
public MetricAggregationDescriptor(
public MetricAggregatorInfo(
MetricStat metricStat,
String field,
String starFieldName,
IndexNumericFieldData.NumericType numericType,
DocIdSetIterator metricStatReader
) {
this.metricStat = metricStat;
this.valueAggregators = ValueAggregatorFactory.getValueAggregator(metricStat);
this.starTreeNumericType = StarTreeNumericType.fromNumericType(numericType);
this.metricStatReader = metricStatReader;
if (metricStat == MetricStat.COUNT) {
this.field = STAR;
} else {
this.field = field;
}
this.metricStatName = toFieldName();
this.field = field;
this.starFieldName = starFieldName;
this.metric = toFieldName();
}

/**
Expand All @@ -74,8 +65,8 @@ public String getField() {
/**
* @return the metric stat name
*/
public String getMetricStatName() {
return metricStatName;
public String getMetric() {
return metric;
}

/**
Expand All @@ -86,9 +77,9 @@ public ValueAggregator getValueAggregators() {
}

/**
* @return star tree numeric type
* @return star tree aggregated value type
*/
public StarTreeNumericType getStarTreeNumericType() {
public StarTreeNumericType getAggregatedValueType() {
return starTreeNumericType;
}

Expand All @@ -103,14 +94,7 @@ public DocIdSetIterator getMetricStatReader() {
* @return field name with metric type and field
*/
public String toFieldName() {
return toFieldName(metricStat, field);
}

/**
* Builds field name with metric type and field
*/
public static String toFieldName(MetricStat metricType, String field) {
return metricType.getTypeName() + DELIMITER + field;
return starFieldName + DELIMITER + field + DELIMITER + metricStat.getTypeName();
}

@Override
Expand All @@ -123,8 +107,8 @@ public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj instanceof MetricAggregationDescriptor) {
MetricAggregationDescriptor anotherPair = (MetricAggregationDescriptor) obj;
if (obj instanceof MetricAggregatorInfo) {
MetricAggregatorInfo anotherPair = (MetricAggregatorInfo) obj;
return metricStat == anotherPair.metricStat && field.equals(anotherPair.field);
}
return false;
Expand All @@ -136,9 +120,9 @@ public String toString() {
}

@Override
public int compareTo(MetricAggregationDescriptor other) {
return Comparator.comparing((MetricAggregationDescriptor o) -> o.field)
.thenComparing((MetricAggregationDescriptor o) -> o.metricStat)
public int compareTo(MetricAggregatorInfo other) {
return Comparator.comparing((MetricAggregatorInfo o) -> o.field)
.thenComparing((MetricAggregatorInfo o) -> o.metricStat)
.compare(this, other);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,41 @@
*/
public class SumValueAggregator implements ValueAggregator<Double> {

public static final StarTreeNumericType STAR_TREE_NUMERIC_TYPE = StarTreeNumericType.DOUBLE;
public static final StarTreeNumericType VALUE_AGGREGATOR_TYPE = StarTreeNumericType.DOUBLE;

@Override
public MetricStat getAggregationType() {
return MetricStat.SUM;
}

@Override
public StarTreeNumericType getStarTreeNumericType() {
return STAR_TREE_NUMERIC_TYPE;
public StarTreeNumericType getAggregatedValueType() {
return VALUE_AGGREGATOR_TYPE;
}

@Override
public Double getInitialAggregatedValue(Long segmentDocValue, StarTreeNumericType starTreeNumericType) {
public Double getInitialAggregatedValueForSegmentDocValue(Long segmentDocValue, StarTreeNumericType starTreeNumericType) {
return starTreeNumericType.getDoubleValue(segmentDocValue);
}

@Override
public Double applySegmentRawValue(Double value, Long segmentDocValue, StarTreeNumericType starTreeNumericType) {
public Double mergeAggregatedValueAndSegmentValue(Double value, Long segmentDocValue, StarTreeNumericType starTreeNumericType) {
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
kahanSummation.add(value);
kahanSummation.add(starTreeNumericType.getDoubleValue(segmentDocValue));
return kahanSummation.value();
}

@Override
public Double applyAggregatedValue(Double value, Double aggregatedValue) {
public Double mergeAggregatedValues(Double value, Double aggregatedValue) {
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
kahanSummation.add(value);
kahanSummation.add(aggregatedValue);
return kahanSummation.value();
}

@Override
public Double getAggregatedValue(Double value) {
public Double getInitialAggregatedValue(Double value) {
return value;
}

Expand All @@ -66,17 +66,17 @@ public int getMaxAggregatedValueByteSize() {
public Long toLongValue(Double value) {
try {
return NumericUtils.doubleToSortableLong(value);
} catch (IllegalArgumentException | NullPointerException | IllegalStateException e) {
throw new IllegalArgumentException("Cannot convert " + value + " to sortable long", e);
} catch (Exception e) {
throw new IllegalStateException("Cannot convert " + value + " to sortable long", e);
}
}

@Override
public Double toStarTreeNumericTypeValue(Long value, StarTreeNumericType type) {
try {
return type.getDoubleValue(value);
} catch (IllegalArgumentException | NullPointerException | IllegalStateException e) {
throw new IllegalArgumentException("Cannot convert " + value + " to sortable aggregation type", e);
} catch (Exception e) {
throw new IllegalStateException("Cannot convert " + value + " to sortable aggregation type", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,27 @@ public interface ValueAggregator<A> {
/**
* Returns the data type of the aggregated value.
*/
StarTreeNumericType getStarTreeNumericType();
StarTreeNumericType getAggregatedValueType();

/**
* Returns the initial aggregated value.
*/
A getInitialAggregatedValue(Long segmentDocValue, StarTreeNumericType starTreeNumericType);
A getInitialAggregatedValueForSegmentDocValue(Long segmentDocValue, StarTreeNumericType starTreeNumericType);

/**
* Applies a segment doc value to the current aggregated value.
*/
A applySegmentRawValue(A value, Long segmentDocValue, StarTreeNumericType starTreeNumericType);
A mergeAggregatedValueAndSegmentValue(A value, Long segmentDocValue, StarTreeNumericType starTreeNumericType);

/**
* Applies an aggregated value to the current aggregated value.
*/
A applyAggregatedValue(A value, A aggregatedValue);
A mergeAggregatedValues(A value, A aggregatedValue);

/**
* Clones an aggregated value.
*/
A getAggregatedValue(A value);
A getInitialAggregatedValue(A value);

/**
* Returns the maximum size in bytes of the aggregated values seen so far.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ public static StarTreeNumericType getAggregatedValueType(MetricStat aggregationT
switch (aggregationType) {
// other metric types (count, min, max, avg) will be supported in the future
case SUM:
return SumValueAggregator.STAR_TREE_NUMERIC_TYPE;
return SumValueAggregator.VALUE_AGGREGATOR_TYPE;
case COUNT:
return CountValueAggregator.STAR_TREE_NUMERIC_TYPE;
return CountValueAggregator.VALUE_AGGREGATOR_TYPE;
default:
throw new IllegalStateException("Unsupported aggregation type: " + aggregationType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ public enum StarTreeNumericType {
HALF_FLOAT(IndexNumericFieldData.NumericType.HALF_FLOAT, StarTreeNumericTypeConverters::halfFloatPointToDouble),
FLOAT(IndexNumericFieldData.NumericType.FLOAT, StarTreeNumericTypeConverters::floatPointToDouble),
LONG(IndexNumericFieldData.NumericType.LONG, StarTreeNumericTypeConverters::longToDouble),
DOUBLE(IndexNumericFieldData.NumericType.DOUBLE, StarTreeNumericTypeConverters::sortableLongtoDouble);
DOUBLE(IndexNumericFieldData.NumericType.DOUBLE, StarTreeNumericTypeConverters::sortableLongtoDouble),
INT(IndexNumericFieldData.NumericType.INT, StarTreeNumericTypeConverters::intToDouble),
SHORT(IndexNumericFieldData.NumericType.SHORT, StarTreeNumericTypeConverters::shortToDouble),
UNSIGNED_LONG(IndexNumericFieldData.NumericType.UNSIGNED_LONG, StarTreeNumericTypeConverters::unsignedlongToDouble);

final IndexNumericFieldData.NumericType numericType;
final Function<Long, Double> converter;
Expand All @@ -44,6 +47,12 @@ public static StarTreeNumericType fromNumericType(IndexNumericFieldData.NumericT
return StarTreeNumericType.LONG;
case DOUBLE:
return StarTreeNumericType.DOUBLE;
case INT:
return StarTreeNumericType.INT;
case SHORT:
return StarTreeNumericType.SHORT;
case UNSIGNED_LONG:
return StarTreeNumericType.UNSIGNED_LONG;
default:
throw new UnsupportedOperationException("Unknown numeric type [" + numericType + "]");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.sandbox.document.HalfFloatPoint;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.Numbers;
import org.opensearch.common.annotation.ExperimentalApi;

/**
Expand All @@ -31,7 +32,20 @@ public static double longToDouble(Long value) {
return (double) value;
}

public static double intToDouble(Long value) {
return (double) value;
}

public static double shortToDouble(Long value) {
return (double) value;
}

public static Double sortableLongtoDouble(Long value) {
return NumericUtils.sortableLongToDouble(value);
}

public static double unsignedlongToDouble(Long value) {
return Numbers.unsignedLongToDouble(value);
}

}
Loading

0 comments on commit d90486f

Please sign in to comment.