Skip to content

Commit

Permalink
Change confidenceLevel from boolean to Enum
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavmuk04 authored and kaikalur committed Jun 11, 2024
1 parent e0dac35 commit f876f17
Show file tree
Hide file tree
Showing 24 changed files with 127 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Optional;

import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static java.lang.Math.min;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -66,6 +67,11 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(AggregationNode node, Stat
public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<VariableReferenceExpression> groupByVariables, Map<VariableReferenceExpression, Aggregation> aggregations)
{
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();

if (isGlobalAggregation(groupByVariables)) {
result.setConfidence(FACT);
}

for (VariableReferenceExpression groupByVariable : groupByVariables) {
VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable);
result.addVariableStatistics(groupByVariable, symbolStatistics.mapNullsFraction(nullsFraction -> {
Expand Down Expand Up @@ -99,4 +105,9 @@ private static VariableStatsEstimate estimateAggregationStats(Aggregation aggreg
// TODO implement simple aggregations like: min, max, count, sum
return VariableStatsEstimate.unknown();
}

/**
 * Tells whether an aggregation is global, i.e. has no grouping keys.
 *
 * @param groupingKeys the variables the aggregation groups by
 * @return {@code true} if {@code groupingKeys} is empty, {@code false} otherwise
 */
private static boolean isGlobalAggregation(Collection<VariableReferenceExpression> groupingKeys)
{
return groupingKeys.isEmpty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.util.Optional;

import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT;
import static com.facebook.presto.sql.planner.plan.Patterns.enforceSingleRow;

public class EnforceSingleRowStatsRule
Expand All @@ -44,6 +45,7 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(EnforceSingleRowNode node,
{
return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats.getStats(node.getSource()))
.setOutputRowCount(1)
.setConfidence(FACT)
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

import static com.facebook.presto.SystemSessionProperties.shouldOptimizerUseHistograms;
import static com.facebook.presto.cost.PlanNodeStatsEstimate.buildFrom;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT;
import static com.facebook.presto.sql.planner.plan.Patterns.exchange;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
Expand All @@ -51,13 +53,13 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ExchangeNode node, StatsPr
{
Optional<PlanNodeStatsEstimate> estimate = Optional.empty();
double totalSize = 0;
boolean confident = true;
ConfidenceLevel confidenceLevel = FACT;
for (int i = 0; i < node.getSources().size(); i++) {
PlanNode source = node.getSources().get(i);
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(source);
totalSize += sourceStats.getOutputSizeInBytes();
if (!sourceStats.isConfident()) {
confident = false;
if (sourceStats.confidenceLevel().ordinal() < confidenceLevel.ordinal()) {
confidenceLevel = sourceStats.confidenceLevel();
}

PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputVariables(sourceStats, node.getInputs().get(i), node.getOutputVariables());
Expand All @@ -74,7 +76,7 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ExchangeNode node, StatsPr
verify(estimate.isPresent());
return Optional.of(buildFrom(estimate.get())
.setTotalSize(totalSize)
.setConfident(confident)
.setConfidence(confidenceLevel)
.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(LimitNode node, StatsProvi
// LIMIT actually limits (or when there was no row count estimated for source)
return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats)
.setOutputRowCount(node.getCount())
.setConfidence(sourceStats.confidenceLevel())
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import java.util.Set;
import java.util.function.Function;

import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.facebook.presto.util.MoreMath.firstNonNaN;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
Expand All @@ -51,7 +53,7 @@
public class PlanNodeStatsEstimate
{
private static final double DEFAULT_DATA_SIZE_PER_COLUMN = 50;
private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, false, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown());
private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, LOW, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown());

private final double outputRowCount;
private final double totalSize;
Expand All @@ -74,7 +76,7 @@ public static PlanNodeStatsEstimate unknown()
public PlanNodeStatsEstimate(
@JsonProperty("outputRowCount") double outputRowCount,
@JsonProperty("totalSize") double totalSize,
@JsonProperty("confident") boolean confident,
@JsonProperty("confident") ConfidenceLevel confidenceLevel,
@JsonProperty("variableStatistics") Map<VariableReferenceExpression, VariableStatsEstimate> variableStatistics,
@JsonProperty("joinNodeStatsEstimate") JoinNodeStatsEstimate joinNodeStatsEstimate,
@JsonProperty("tableWriterNodeStatsEstimate") TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate,
Expand All @@ -83,12 +85,12 @@ public PlanNodeStatsEstimate(
this(outputRowCount,
totalSize,
HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null")),
new CostBasedSourceInfo(confident), joinNodeStatsEstimate, tableWriterNodeStatsEstimate, partialAggregationStatsEstimate);
new CostBasedSourceInfo(confidenceLevel), joinNodeStatsEstimate, tableWriterNodeStatsEstimate, partialAggregationStatsEstimate);
}

private PlanNodeStatsEstimate(double outputRowCount, double totalSize, boolean confident, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
private PlanNodeStatsEstimate(double outputRowCount, double totalSize, ConfidenceLevel confidenceLevel, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
{
this(outputRowCount, totalSize, variableStatistics, new CostBasedSourceInfo(confident));
this(outputRowCount, totalSize, variableStatistics, new CostBasedSourceInfo(confidenceLevel));
}

public PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics, SourceInfo sourceInfo)
Expand Down Expand Up @@ -126,9 +128,9 @@ public double getTotalSize()
}

@JsonProperty
public boolean isConfident()
public ConfidenceLevel confidenceLevel()
{
return sourceInfo.isConfident();
return sourceInfo.confidenceLevel();
}

public SourceInfo getSourceInfo()
Expand Down Expand Up @@ -327,7 +329,7 @@ public PlanStatisticsWithSourceInfo toPlanStatisticsWithSourceInfo(PlanNodeId id
new PlanStatistics(
Estimate.estimateFromDouble(outputRowCount),
Estimate.estimateFromDouble(totalSize),
sourceInfo.isConfident() ? 1 : 0,
sourceInfo.confidenceLevel() == LOW ? 0 : 1,
new JoinNodeStatistics(
Estimate.estimateFromDouble(joinNodeStatsEstimate.getNullJoinBuildKeyCount()),
Estimate.estimateFromDouble(joinNodeStatsEstimate.getJoinBuildKeyCount()),
Expand All @@ -349,27 +351,27 @@ public static Builder builder()
// we should propagate totalSize as default to simplify the relevant operations in rules that do not change this field.
public static Builder buildFrom(PlanNodeStatsEstimate other)
{
return new Builder(other.getOutputRowCount(), NaN, other.isConfident(), other.variableStatistics);
return new Builder(other.getOutputRowCount(), NaN, other.confidenceLevel(), other.variableStatistics);
}

public static final class Builder
{
private double outputRowCount;
private double totalSize;
private boolean confident;
private ConfidenceLevel confidenceLevel;
private PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics;
private PartialAggregationStatsEstimate partialAggregationStatsEstimate;

public Builder()
{
this(NaN, NaN, false, HashTreePMap.empty());
this(NaN, NaN, LOW, HashTreePMap.empty());
}

private Builder(double outputRowCount, double totalSize, boolean confident, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
private Builder(double outputRowCount, double totalSize, ConfidenceLevel confidenceLevel, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
{
this.outputRowCount = outputRowCount;
this.totalSize = totalSize;
this.confident = confident;
this.confidenceLevel = confidenceLevel;
this.variableStatistics = variableStatistics;
this.partialAggregationStatsEstimate = PartialAggregationStatsEstimate.unknown();
}
Expand All @@ -386,9 +388,9 @@ public Builder setTotalSize(double totalSize)
return this;
}

public Builder setConfident(boolean confident)
public Builder setConfidence(ConfidenceLevel confidenceLevel)
{
this.confident = confident;
this.confidenceLevel = confidenceLevel;
return this;
}

Expand Down Expand Up @@ -420,7 +422,7 @@ public PlanNodeStatsEstimate build()
{
return new PlanNodeStatsEstimate(outputRowCount,
totalSize,
confident,
confidenceLevel,
variableStatistics,
JoinNodeStatsEstimate.unknown(),
TableWriterNodeStatsEstimate.unknown(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static java.util.Objects.requireNonNull;

Expand All @@ -50,9 +52,12 @@ public Pattern<ProjectNode> getPattern()
protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());

boolean noChange = noChangeToSourceColumns(node);
ConfidenceLevel newConfidence = noChange ? sourceStats.confidenceLevel() : LOW;
PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder()
.setOutputRowCount(sourceStats.getOutputRowCount())
.setConfident(sourceStats.isConfident() && noChangeToSourceColumns(node));
.setConfidence(newConfidence);

for (Map.Entry<VariableReferenceExpression, RowExpression> entry : node.getAssignments().entrySet()) {
RowExpression expression = entry.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT;
import static com.facebook.presto.sql.planner.plan.Patterns.tableScan;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -69,7 +70,9 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(TableScanNode node, StatsP
return Optional.of(PlanNodeStatsEstimate.builder()
.setOutputRowCount(tableStatistics.getRowCount().getValue())
.setTotalSize(tableStatistics.getTotalSize().getValue())
.setConfident(true)

// TODO handle the confidence level properly when filters are pushed down into the table scan
.setConfidence(FACT)
.addVariableStatistics(outputVariableStats)
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.cost.StatsUtil.toStatsRepresentation;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT;
import static com.facebook.presto.sql.planner.RowExpressionInterpreter.evaluateConstantRowExpression;
import static com.facebook.presto.sql.planner.plan.Patterns.values;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -60,7 +61,7 @@ public Optional<PlanNodeStatsEstimate> calculate(ValuesNode node, StatsProvider
{
PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
statsBuilder.setOutputRowCount(node.getRows().size())
.setConfident(true);
.setConfidence(FACT);

for (int variableId = 0; variableId < node.getOutputVariables().size(); ++variableId) {
VariableReferenceExpression variable = node.getOutputVariables().get(variableId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.AUTOMATIC;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER;
import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan;
Expand Down Expand Up @@ -337,19 +339,19 @@ private boolean partialAggregationNotUseful(AggregationNode aggregationNode, Exc
double inputSize = exchangeStats.getOutputSizeInBytes(exchangeNode);
double outputSize = aggregationStats.getOutputSizeInBytes(aggregationNode);
PartialAggregationStatsEstimate partialAggregationStatsEstimate = aggregationStats.getPartialAggregationStatsEstimate();
boolean isConfident = exchangeStats.isConfident();
ConfidenceLevel confidenceLevel = exchangeStats.confidenceLevel();
// keep old behavior of skipping partial aggregation only for single-key aggregations
boolean numberOfKeyCheck = usePartialAggregationHistory(context.getSession()) || numAggregationKeys == 1;
if (!isUnknown(partialAggregationStatsEstimate) && usePartialAggregationHistory(context.getSession())) {
isConfident = aggregationStats.isConfident();
confidenceLevel = aggregationStats.confidenceLevel();
// use rows instead of bytes when use_partial_aggregation_history flag is on
inputSize = partialAggregationStatsEstimate.getInputRowCount();
outputSize = partialAggregationStatsEstimate.getOutputRowCount();
}
double byteReductionThreshold = getPartialAggregationByteReductionThreshold(context.getSession());

// calling this function means we are using a cost-based strategy for this optimization
return numberOfKeyCheck && isConfident && outputSize > inputSize * byteReductionThreshold;
return numberOfKeyCheck && confidenceLevel != LOW && outputSize > inputSize * byteReductionThreshold;
}

private static boolean isLambda(RowExpression rowExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.stream.Collectors;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.google.common.collect.ImmutableList.toImmutableList;

public class AggregationNodeUtils
Expand Down Expand Up @@ -77,7 +78,7 @@ public static boolean isAllLowCardinalityGroupByKeys(AggregationNode aggregation
List<VariableReferenceExpression> groupbyKeys = aggregationNode.getGroupingSets().getGroupingKeys().stream().collect(Collectors.toList());
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types);
PlanNodeStatsEstimate estimate = statsProvider.getStats(scanNode);
if (!estimate.isConfident()) {
if (estimate.confidenceLevel() == LOW) {
// For safety, assume the group-by keys are low cardinality when the estimate is not confident
// TODO(kaikalur) : maybe return low card only for partition keys if/when we can detect that
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.cost.EstimateAssertion.assertEstimateEquals;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel;
import static com.google.common.collect.Sets.union;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
Expand Down Expand Up @@ -51,9 +52,9 @@ public PlanNodeStatsAssertion totalSize(double expected)
return this;
}

public PlanNodeStatsAssertion confident(boolean expected)
public PlanNodeStatsAssertion confident(ConfidenceLevel expected)
{
assertEquals(actual.isConfident(), expected);
assertEquals(actual.confidenceLevel(), expected);
return this;
}

Expand Down Expand Up @@ -100,7 +101,7 @@ public PlanNodeStatsAssertion variablesWithKnownStats(VariableReferenceExpressio
public PlanNodeStatsAssertion equalTo(PlanNodeStatsEstimate expected)
{
assertEstimateEquals(actual.getOutputRowCount(), expected.getOutputRowCount(), "outputRowCount mismatch");
assertEquals(actual.isConfident(), expected.isConfident());
assertEquals(actual.confidenceLevel(), expected.confidenceLevel());

for (VariableReferenceExpression variable : union(expected.getVariablesWithKnownStatistics(), actual.getVariablesWithKnownStatistics())) {
assertVariableStatsEqual(variable, actual.getVariableStatistics(variable), expected.getVariableStatistics(variable));
Expand Down
Loading

0 comments on commit f876f17

Please sign in to comment.