diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregationBuilder.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregationBuilder.java index a7516a6fd6b24..fb259c2ac6173 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregationBuilder.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregationBuilder.java @@ -68,6 +68,7 @@ public final class CardinalityAggregationBuilder extends ValuesSourceAggregation private static final ParseField REHASH = new ParseField("rehash").withAllDeprecated("no replacement - values will always be rehashed"); public static final ParseField PRECISION_THRESHOLD_FIELD = new ParseField("precision_threshold"); + public static final ParseField EXECUTION_HINT_FIELD = new ParseField(("execution_hint")); public static final ObjectParser PARSER = ObjectParser.fromBuilder( NAME, @@ -76,6 +77,7 @@ public final class CardinalityAggregationBuilder extends ValuesSourceAggregation static { ValuesSourceAggregationBuilder.declareFields(PARSER, true, false, false); PARSER.declareLong(CardinalityAggregationBuilder::precisionThreshold, CardinalityAggregationBuilder.PRECISION_THRESHOLD_FIELD); + PARSER.declareString(CardinalityAggregationBuilder::executionHint, CardinalityAggregationBuilder.EXECUTION_HINT_FIELD); PARSER.declareLong((b, v) -> {/*ignore*/}, REHASH); } @@ -85,6 +87,8 @@ public static void registerAggregators(ValuesSourceRegistry.Builder builder) { private Long precisionThreshold = null; + private String executionHint = null; + public CardinalityAggregationBuilder(String name) { super(name); } @@ -96,6 +100,7 @@ public CardinalityAggregationBuilder( ) { super(clone, factoriesBuilder, metadata); this.precisionThreshold = clone.precisionThreshold; + this.executionHint = clone.executionHint; } @Override @@ -111,6 +116,7 @@ public CardinalityAggregationBuilder(StreamInput in) throws IOException { if (in.readBoolean()) { precisionThreshold = in.readLong(); } + executionHint = in.readOptionalString(); } @Override @@ -125,6 +131,7 @@ protected void innerWriteTo(StreamOutput out) throws IOException { if (hasPrecisionThreshold) { out.writeLong(precisionThreshold); } + out.writeOptionalString(executionHint); } @Override @@ -155,6 +162,13 @@ public Long precisionThreshold() { return precisionThreshold; } + public CardinalityAggregationBuilder executionHint(String executionHint) { + this.executionHint = executionHint; + return this; + } + + public String executionHint() { return executionHint; } + @Override protected CardinalityAggregatorFactory innerBuild( QueryShardContext queryShardContext, @@ -162,7 +176,7 @@ protected CardinalityAggregatorFactory innerBuild( AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder ) throws IOException { - return new CardinalityAggregatorFactory(name, config, precisionThreshold, queryShardContext, parent, subFactoriesBuilder, metadata); + return new CardinalityAggregatorFactory(name, config, precisionThreshold, executionHint, queryShardContext, parent, subFactoriesBuilder, metadata); } @Override @@ -170,12 +184,15 @@ public XContentBuilder doXContentBody(XContentBuilder builder, Params params) th if (precisionThreshold != null) { builder.field(PRECISION_THRESHOLD_FIELD.getPreferredName(), precisionThreshold); } + if (executionHint != null) { + builder.field(EXECUTION_HINT_FIELD.getPreferredName(), executionHint); + } return builder; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), precisionThreshold); + return Objects.hash(super.hashCode(), precisionThreshold, executionHint); } @Override @@ -184,7 +201,8 @@ public boolean equals(Object obj) { if (obj == null || getClass() != obj.getClass()) return false; if (super.equals(obj) == false) return false; CardinalityAggregationBuilder other = (CardinalityAggregationBuilder) obj; - return Objects.equals(precisionThreshold, other.precisionThreshold); + return Objects.equals(precisionThreshold, other.precisionThreshold) + && Objects.equals(executionHint, other.executionHint); } @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java index 0f3d975960364..d24f2cece1904 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java @@ -89,6 +89,7 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue private static final Logger logger = LogManager.getLogger(CardinalityAggregator.class); + private final CardinalityAggregatorFactory.ExecutionMode executionMode; private final int precision; private final ValuesSource valuesSource; @@ -111,6 +112,7 @@ public CardinalityAggregator( String name, ValuesSourceConfig valuesSourceConfig, int precision, + CardinalityAggregatorFactory.ExecutionMode executionMode, SearchContext context, Aggregator parent, Map metadata @@ -121,6 +123,7 @@ public CardinalityAggregator( this.precision = precision; this.counts = valuesSource == null ? null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1); this.valuesSourceConfig = valuesSourceConfig; + this.executionMode = executionMode; } @Override @@ -129,6 +132,7 @@ public ScoreMode scoreMode() { } private Collector pickCollector(LeafReaderContext ctx) throws IOException { + logger.info("ValuesSource Type: " + valuesSource); if (valuesSource == null) { emptyCollectorsUsed++; return new EmptyCollector(); @@ -151,6 +155,9 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException { if (maxOrd == 0) { emptyCollectorsUsed++; return new EmptyCollector(); + } else if (executionMode == CardinalityAggregatorFactory.ExecutionMode.ORDINAL) { // Force OrdinalsCollector + ordinalsCollectorsUsed++; + collector = new OrdinalsCollector(counts, ordinalValues, context.bigArrays()); } else { final long ordinalsMemoryUsage = OrdinalsCollector.memoryOverhead(maxOrd); final long countsMemoryUsage = HyperLogLogPlusPlus.memoryUsage(precision); @@ -261,6 +268,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc postCollectLastCollector(); collector = pickCollector(ctx); + logger.info("Collector chosen: " + collector); return collector; } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorFactory.java index 980667b45324e..ff8ea80c7b665 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorFactory.java @@ -53,12 +53,45 @@ */ class CardinalityAggregatorFactory extends ValuesSourceAggregatorFactory { + public static enum ExecutionMode { + + UNSET(null), + DIRECT("direct"), + ORDINAL("ordinal"); + + private final String hintString; + + ExecutionMode(String hintString) { + this.hintString = hintString; + } + + public static ExecutionMode fromString(String value) { + if (value == null) { + return UNSET; + } + switch(value) { + case "direct": return DIRECT; + case "ordinal": return ORDINAL; + default: + throw new IllegalArgumentException("Unknown `execution_hint`: [" + value + "], expected any of [direct, ordinals]"); + } + } + + @Override + public String toString() { + return hintString; + } + } + + private final ExecutionMode executionMode; + private final Long precisionThreshold; CardinalityAggregatorFactory( String name, ValuesSourceConfig config, Long precisionThreshold, + String executionHint, QueryShardContext queryShardContext, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, @@ -66,6 +99,7 @@ class CardinalityAggregatorFactory extends ValuesSourceAggregatorFactory { ) throws IOException { super(name, config, queryShardContext, parent, subFactoriesBuilder, metadata); this.precisionThreshold = precisionThreshold; + this.executionMode = ExecutionMode.fromString(executionHint); } public static void registerAggregators(ValuesSourceRegistry.Builder builder) { @@ -74,7 +108,7 @@ public static void registerAggregators(ValuesSourceRegistry.Builder builder) { @Override protected Aggregator createUnmapped(SearchContext searchContext, Aggregator parent, Map metadata) throws IOException { - return new CardinalityAggregator(name, config, precision(), searchContext, parent, metadata); + return new CardinalityAggregator(name, config, precision(), executionMode, searchContext, parent, metadata); } @Override @@ -86,7 +120,7 @@ protected Aggregator doCreateInternal( ) throws IOException { return queryShardContext.getValuesSourceRegistry() .getAggregator(CardinalityAggregationBuilder.REGISTRY_KEY, config) - .build(name, config, precision(), searchContext, parent, metadata); + .build(name, config, precision(), executionMode, searchContext, parent, metadata); } @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorSupplier.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorSupplier.java index d5cb0242762fd..b98ffbcc1b8e3 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorSupplier.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorSupplier.java @@ -49,6 +49,7 @@ Aggregator build( String name, ValuesSourceConfig valuesSourceConfig, int precision, + CardinalityAggregatorFactory.ExecutionMode executionMode, SearchContext context, Aggregator parent, Map metadata