From a7732f3e6df7d7d49cbf5f8327ec4b8d9709906b Mon Sep 17 00:00:00 2001 From: feilong-liu Date: Wed, 15 May 2024 16:36:34 -0700 Subject: [PATCH] Record estimation stats during query optimization When optimizer returns a optimized plan, it will also return the estimation of stats for each node with the plan, However, instead of returning the exact stats which are used in optimization, it's actually recalculating the stats. This can be a problem. For example, currently CBO returns empty stats if the aggregation step is not single for an aggregation This means that, we will not get any CBO stats for partial and final aggregation, and all other node which are downstream of the aggregation. In this PR, it will record the stats during query optimization. For the same node, later stats will override previous ones. --- .../main/java/com/facebook/presto/Session.java | 15 +++++++++++++++ .../facebook/presto/SystemSessionProperties.java | 12 ++++++++++++ .../facebook/presto/cost/CachingCostProvider.java | 6 +++++- .../presto/cost/CachingStatsProvider.java | 6 +++++- .../com/facebook/presto/cost/StatsAndCosts.java | 9 ++++++--- .../java/com/facebook/presto/sql/Optimizer.java | 2 +- .../presto/sql/analyzer/FeaturesConfig.java | 14 ++++++++++++++ .../facebook/presto/cost/TestCostCalculator.java | 4 ++-- .../presto/sql/analyzer/TestFeaturesConfig.java | 7 +++++-- .../presto/sql/planner/assertions/PlanAssert.java | 2 +- .../planner/iterative/rule/test/RuleAssert.java | 2 +- 11 files changed, 67 insertions(+), 12 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/Session.java b/presto-main/src/main/java/com/facebook/presto/Session.java index 7485bad80671..a9fbca0eee24 100644 --- a/presto-main/src/main/java/com/facebook/presto/Session.java +++ b/presto-main/src/main/java/com/facebook/presto/Session.java @@ -17,6 +17,8 @@ import com.facebook.presto.common.function.SqlFunctionProperties; import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.cost.PlanCostEstimate; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSession; @@ -25,6 +27,7 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AccessControlContext; import com.facebook.presto.spi.security.Identity; @@ -99,6 +102,8 @@ public final class Session private final OptimizerInformationCollector optimizerInformationCollector = new OptimizerInformationCollector(); private final OptimizerResultCollector optimizerResultCollector = new OptimizerResultCollector(); private final CTEInformationCollector cteInformationCollector = new CTEInformationCollector(); + private final Map planNodeStatsMap = new HashMap<>(); + private final Map planNodeCostMap = new HashMap<>(); public Session( QueryId queryId, @@ -337,6 +342,16 @@ public CTEInformationCollector getCteInformationCollector() return cteInformationCollector; } + public Map getPlanNodeStatsMap() + { + return planNodeStatsMap; + } + + public Map getPlanNodeCostMap() + { + return planNodeCostMap; + } + public Session beginTransactionId(TransactionId transactionId, TransactionManager transactionManager, AccessControl accessControl) { requireNonNull(transactionId, "transactionId is null"); diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 641ba55ee014..067997132c1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -324,6 +324,7 @@ public final class SystemSessionProperties public static final String SKIP_HASH_GENERATION_FOR_JOIN_WITH_TABLE_SCAN_INPUT = "skip_hash_generation_for_join_with_table_scan_input"; public static final String GENERATE_DOMAIN_FILTERS = "generate_domain_filters"; public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression"; + public static final String PRINT_ESTIMATED_STATS_FROM_CACHE = "print_estimated_stats_from_cache"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled"; @@ -1915,6 +1916,12 @@ public SystemSessionProperties( "Rewrite left join with is null check to semi join", featuresConfig.isRewriteExpressionWithConstantVariable(), false), + booleanProperty( + PRINT_ESTIMATED_STATS_FROM_CACHE, + "When printing estimated plan stats after optimization is complete, such as in an EXPLAIN query or for logging in a QueryCompletedEvent, " + + "get stats from a cache that was populated during query optimization rather than recalculating the stats on the final plan.", + featuresConfig.isPrintEstimatedStatsFromCache(), + false), new PropertyMetadata<>( DEFAULT_VIEW_SECURITY_MODE, format("Set default view security mode. Options are: %s", @@ -3218,4 +3225,9 @@ public static boolean isJoinPrefilterEnabled(Session session) { return session.getSystemProperty(JOIN_PREFILTER_BUILD_SIDE, Boolean.class); } + + public static boolean isPrintEstimatedStatsFromCacheEnabled(Session session) + { + return session.getSystemProperty(PRINT_ESTIMATED_STATS_FROM_CACHE, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingCostProvider.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostProvider.java index 400ca281a3bc..e77bf3b68ae3 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CachingCostProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostProvider.java @@ -64,16 +64,20 @@ public PlanCostEstimate getCost(PlanNode node) try { if (node instanceof GroupReference) { - return getGroupCost((GroupReference) node); + PlanCostEstimate result = getGroupCost((GroupReference) node); + session.getPlanNodeCostMap().put(node.getId(), result); + return result; } PlanCostEstimate cost = cache.get(node); if (cost != null) { + session.getPlanNodeCostMap().put(node.getId(), cost); return cost; } cost = calculateCost(node); verify(cache.put(node, cost) == null, "Cost already set"); + session.getPlanNodeCostMap().put(node.getId(), cost); return cost; } catch (RuntimeException e) { diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsProvider.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsProvider.java index 366de73026f2..413121e2616b 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsProvider.java @@ -69,16 +69,20 @@ public PlanNodeStatsEstimate getStats(PlanNode node) try { if (node instanceof GroupReference) { - return getGroupStats((GroupReference) node); + PlanNodeStatsEstimate result = getGroupStats((GroupReference) node); + session.getPlanNodeStatsMap().put(node.getId(), result); + return result; } PlanNodeStatsEstimate stats = cache.get(node); if (stats != null) { + session.getPlanNodeStatsMap().put(node.getId(), stats); return stats; } stats = statsCalculator.calculateStats(node, this, lookup, session, types); verify(cache.put(node, stats) == null, "Stats already set"); + session.getPlanNodeStatsMap().put(node.getId(), stats); return stats; } catch (RuntimeException e) { diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatsAndCosts.java b/presto-main/src/main/java/com/facebook/presto/cost/StatsAndCosts.java index 5577404b9140..c238abe4d139 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/StatsAndCosts.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatsAndCosts.java @@ -14,6 +14,7 @@ package com.facebook.presto.cost; +import com.facebook.presto.Session; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.fasterxml.jackson.annotation.JsonCreator; @@ -24,6 +25,7 @@ import java.util.Map; import java.util.Objects; +import static com.facebook.presto.SystemSessionProperties.isPrintEstimatedStatsFromCacheEnabled; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; @@ -77,15 +79,16 @@ public StatsAndCosts getForSubplan(PlanNode root) return new StatsAndCosts(filteredStats.build(), filteredCosts.build()); } - public static StatsAndCosts create(PlanNode root, StatsProvider statsProvider, CostProvider costProvider) + public static StatsAndCosts create(PlanNode root, StatsProvider statsProvider, CostProvider costProvider, Session session) { Iterable planIterator = Traverser.forTree(PlanNode::getSources) .depthFirstPreOrder(root); ImmutableMap.Builder stats = ImmutableMap.builder(); ImmutableMap.Builder costs = ImmutableMap.builder(); + boolean printStatsFromCache = isPrintEstimatedStatsFromCacheEnabled(session); for (PlanNode node : planIterator) { - stats.put(node.getId(), statsProvider.getStats(node)); - costs.put(node.getId(), costProvider.getCost(node)); + stats.put(node.getId(), printStatsFromCache ? session.getPlanNodeStatsMap().getOrDefault(node.getId(), PlanNodeStatsEstimate.unknown()) : statsProvider.getStats(node)); + costs.put(node.getId(), printStatsFromCache ? session.getPlanNodeCostMap().getOrDefault(node.getId(), PlanCostEstimate.unknown()) : costProvider.getCost(node)); } return new StatsAndCosts(stats.build(), costs.build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/Optimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/Optimizer.java index 992cf84b35f1..317722e5a7bc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/Optimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/Optimizer.java @@ -152,7 +152,7 @@ private StatsAndCosts computeStats(PlanNode root, TypeProvider types) (node instanceof JoinNode) || (node instanceof SemiJoinNode)).matches()) { StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.empty(), session); - return StatsAndCosts.create(root, statsProvider, costProvider); + return StatsAndCosts.create(root, statsProvider, costProvider, session); } return StatsAndCosts.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 8d13f68ffa6a..c4cf91b70999 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -307,6 +307,7 @@ public class FeaturesConfig private long kHyperLogLogAggregationGroupNumberLimit; private boolean limitNumberOfGroupsForKHyperLogLogAggregations = true; private boolean generateDomainFilters; + private boolean printEstimatedStatsFromCache; private CreateView.Security defaultViewSecurityMode = DEFINER; public enum PartitioningPrecisionStrategy @@ -3101,4 +3102,17 @@ public FeaturesConfig setDefaultViewSecurityMode(CreateView.Security securityMod this.defaultViewSecurityMode = securityMode; return this; } + + public boolean isPrintEstimatedStatsFromCache() + { + return this.printEstimatedStatsFromCache; + } + + @Config("optimizer.print-estimated-stats-from-cache") + @ConfigDescription("In the end of query optimization, print the estimation stats from cache populated during optimization instead of calculating from ground") + public FeaturesConfig setPrintEstimatedStatsFromCache(boolean printEstimatedStatsFromCache) + { + this.printEstimatedStatsFromCache = printEstimatedStatsFromCache; + return this; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index 5cebf55ee29a..e57b035a0385 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -688,7 +688,7 @@ private CostAssertionBuilder assertCostSingleStageFragmentedPlan( StatsProvider statsProvider = new CachingStatsProvider(statsCalculator(stats), session, typeProvider); CostProvider costProvider = new TestingCostProvider(costs, costCalculatorUsingExchanges, statsProvider, session); // Explicitly generate the statsAndCosts, bypass fragment generation and sanity checks for mock plans. - StatsAndCosts statsAndCosts = StatsAndCosts.create(node, statsProvider, costProvider).getForSubplan(node); + StatsAndCosts statsAndCosts = StatsAndCosts.create(node, statsProvider, costProvider, session).getForSubplan(node); return new CostAssertionBuilder(statsAndCosts.getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown())); } @@ -807,7 +807,7 @@ private PlanCostEstimate calculateCostFragmentedPlan(PlanNode node, StatsCalcula TypeProvider typeProvider = TypeProvider.copyOf(types); StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider); CostProvider costProvider = new CachingCostProvider(costCalculatorUsingExchanges, statsProvider, Optional.empty(), session); - SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider))); + SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider, session))); return subPlan.getFragment().getStatsAndCosts().getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index ce9d8a8673b5..6a195ce96f25 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -270,7 +270,8 @@ public void testDefaults() .setDefaultWriterReplicationCoefficient(3.0) .setDefaultViewSecurityMode(DEFINER) .setCteHeuristicReplicationThreshold(4) - .setLegacyJsonCast(true)); + .setLegacyJsonCast(true) + .setPrintEstimatedStatsFromCache(false)); } @Test @@ -485,6 +486,7 @@ public void testExplicitPropertyMappings() .put("optimizer.default-writer-replication-coefficient", "5.0") .put("default-view-security-mode", INVOKER.name()) .put("cte-heuristic-replication-threshold", "2") + .put("optimizer.print-estimated-stats-from-cache", "true") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -696,7 +698,8 @@ public void testExplicitPropertyMappings() .setDefaultWriterReplicationCoefficient(5.0) .setDefaultViewSecurityMode(INVOKER) .setCteHeuristicReplicationThreshold(2) - .setLegacyJsonCast(false); + .setLegacyJsonCast(false) + .setPrintEstimatedStatsFromCache(true); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java index d041c94ff5a4..d5269a7f0473 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java @@ -53,7 +53,7 @@ public static void assertPlan(Session session, Metadata metadata, StatsProvider // TODO (Issue #13231) add back printing unresolved plan once we have no need to translate OriginalExpression to RowExpression if (!matches.isMatch()) { PlanNode resolvedPlan = resolveGroupReferences(actual.getRoot(), lookup); - String resolvedFormattedPlan = textLogicalPlan(planSanitizer.apply(resolvedPlan), actual.getTypes(), StatsAndCosts.create(resolvedPlan, statsProvider, node -> PlanCostEstimate.unknown()), metadata.getFunctionAndTypeManager(), session, 0); + String resolvedFormattedPlan = textLogicalPlan(planSanitizer.apply(resolvedPlan), actual.getTypes(), StatsAndCosts.create(resolvedPlan, statsProvider, node -> PlanCostEstimate.unknown(), session), metadata.getFunctionAndTypeManager(), session, 0); throw new AssertionError(format( "Plan does not match, expected [\n\n%s\n] but found [\n\n%s\n]", pattern, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 1a6895460265..8d8f9e28fb5f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -240,7 +240,7 @@ private String formatPlan(PlanNode plan, TypeProvider types) { StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, session); - return inTransaction(session -> textLogicalPlan(plan, types, StatsAndCosts.create(plan, statsProvider, costProvider), metadata.getFunctionAndTypeManager(), session, 2, false, isVerboseOptimizerInfoEnabled(session))); + return inTransaction(session -> textLogicalPlan(plan, types, StatsAndCosts.create(plan, statsProvider, costProvider, session), metadata.getFunctionAndTypeManager(), session, 2, false, isVerboseOptimizerInfoEnabled(session))); } private T inTransaction(Function transactionSessionConsumer)