Skip to content

Commit

Permalink
Record estimation stats during query optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
feilong-liu committed May 17, 2024
1 parent 8e34250 commit 0911789
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 12 deletions.
15 changes: 15 additions & 0 deletions presto-main/src/main/java/com/facebook/presto/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.transaction.TransactionId;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorSession;
Expand All @@ -25,6 +27,7 @@
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.spi.security.AccessControlContext;
import com.facebook.presto.spi.security.Identity;
Expand Down Expand Up @@ -99,6 +102,8 @@ public final class Session
private final OptimizerInformationCollector optimizerInformationCollector = new OptimizerInformationCollector();
private final OptimizerResultCollector optimizerResultCollector = new OptimizerResultCollector();
private final CTEInformationCollector cteInformationCollector = new CTEInformationCollector();
private final Map<PlanNodeId, PlanNodeStatsEstimate> planNodeEstimateMap = new HashMap<>();
private final Map<PlanNodeId, PlanCostEstimate> planNodeCostMap = new HashMap<>();

public Session(
QueryId queryId,
Expand Down Expand Up @@ -337,6 +342,16 @@ public CTEInformationCollector getCteInformationCollector()
return cteInformationCollector;
}

public Map<PlanNodeId, PlanNodeStatsEstimate> getPlanNodeEstimateMap()
{
return planNodeEstimateMap;
}

public Map<PlanNodeId, PlanCostEstimate> getPlanNodeCostMap()
{
return planNodeCostMap;
}

public Session beginTransactionId(TransactionId transactionId, TransactionManager transactionManager, AccessControl accessControl)
{
requireNonNull(transactionId, "transactionId is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ public final class SystemSessionProperties
public static final String SKIP_HASH_GENERATION_FOR_JOIN_WITH_TABLE_SCAN_INPUT = "skip_hash_generation_for_join_with_table_scan_input";
public static final String GENERATE_DOMAIN_FILTERS = "generate_domain_filters";
public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression";
public static final String PRINT_ESTIMATION_STATS_FROM_CACHE = "print_estimation_stats_from_cache";

// TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future.
public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled";
Expand Down Expand Up @@ -1914,6 +1915,11 @@ public SystemSessionProperties(
"Rewrite left join with is null check to semi join",
featuresConfig.isRewriteExpressionWithConstantVariable(),
false),
booleanProperty(
PRINT_ESTIMATION_STATS_FROM_CACHE,
"In the end of query optimization, print the estimation stats from cache populated during optimization instead of calculating from ground",
featuresConfig.isRewriteExpressionWithConstantVariable(),
false),
new PropertyMetadata<>(
DEFAULT_VIEW_SECURITY_MODE,
format("Set default view security mode. Options are: %s",
Expand Down Expand Up @@ -3207,4 +3213,9 @@ public static CreateView.Security getDefaultViewSecurityMode(Session session)
{
return session.getSystemProperty(DEFAULT_VIEW_SECURITY_MODE, CreateView.Security.class);
}

public static boolean isPrintEstimationStatsFromCacheEnabled(Session session)
{
return session.getSystemProperty(PRINT_ESTIMATION_STATS_FROM_CACHE, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,20 @@ public PlanCostEstimate getCost(PlanNode node)

try {
if (node instanceof GroupReference) {
return getGroupCost((GroupReference) node);
PlanCostEstimate result = getGroupCost((GroupReference) node);
session.getPlanNodeCostMap().put(node.getId(), result);
return result;
}

PlanCostEstimate cost = cache.get(node);
if (cost != null) {
session.getPlanNodeCostMap().put(node.getId(), cost);
return cost;
}

cost = calculateCost(node);
verify(cache.put(node, cost) == null, "Cost already set");
session.getPlanNodeCostMap().put(node.getId(), cost);
return cost;
}
catch (RuntimeException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,20 @@ public PlanNodeStatsEstimate getStats(PlanNode node)

try {
if (node instanceof GroupReference) {
return getGroupStats((GroupReference) node);
PlanNodeStatsEstimate result = getGroupStats((GroupReference) node);
session.getPlanNodeEstimateMap().put(node.getId(), result);
return result;
}

PlanNodeStatsEstimate stats = cache.get(node);
if (stats != null) {
session.getPlanNodeEstimateMap().put(node.getId(), stats);
return stats;
}

stats = statsCalculator.calculateStats(node, this, lookup, session, types);
verify(cache.put(node, stats) == null, "Stats already set");
session.getPlanNodeEstimateMap().put(node.getId(), stats);
return stats;
}
catch (RuntimeException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.fasterxml.jackson.annotation.JsonCreator;
Expand All @@ -24,6 +25,7 @@
import java.util.Map;
import java.util.Objects;

import static com.facebook.presto.SystemSessionProperties.isPrintEstimationStatsFromCacheEnabled;
import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -77,15 +79,16 @@ public StatsAndCosts getForSubplan(PlanNode root)
return new StatsAndCosts(filteredStats.build(), filteredCosts.build());
}

public static StatsAndCosts create(PlanNode root, StatsProvider statsProvider, CostProvider costProvider)
public static StatsAndCosts create(PlanNode root, StatsProvider statsProvider, CostProvider costProvider, Session session, boolean useCache)
{
Iterable<PlanNode> planIterator = Traverser.forTree(PlanNode::getSources)
.depthFirstPreOrder(root);
ImmutableMap.Builder<PlanNodeId, PlanNodeStatsEstimate> stats = ImmutableMap.builder();
ImmutableMap.Builder<PlanNodeId, PlanCostEstimate> costs = ImmutableMap.builder();
boolean printStatsFromCache = isPrintEstimationStatsFromCacheEnabled(session);
for (PlanNode node : planIterator) {
stats.put(node.getId(), statsProvider.getStats(node));
costs.put(node.getId(), costProvider.getCost(node));
stats.put(node.getId(), useCache && printStatsFromCache ? session.getPlanNodeEstimateMap().getOrDefault(node.getId(), PlanNodeStatsEstimate.unknown()) : statsProvider.getStats(node));
costs.put(node.getId(), useCache && printStatsFromCache ? session.getPlanNodeCostMap().getOrDefault(node.getId(), PlanCostEstimate.unknown()) : costProvider.getCost(node));
}
return new StatsAndCosts(stats.build(), costs.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ private StatsAndCosts computeStats(PlanNode root, TypeProvider types)
(node instanceof JoinNode) || (node instanceof SemiJoinNode)).matches()) {
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types);
CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.empty(), session);
return StatsAndCosts.create(root, statsProvider, costProvider);
return StatsAndCosts.create(root, statsProvider, costProvider, session, true);
}
return StatsAndCosts.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ public class FeaturesConfig
private long kHyperLogLogAggregationGroupNumberLimit;
private boolean limitNumberOfGroupsForKHyperLogLogAggregations = true;
private boolean generateDomainFilters;
private boolean printEstimationStatsFromCache = true;
private CreateView.Security defaultViewSecurityMode = DEFINER;

public enum PartitioningPrecisionStrategy
Expand Down Expand Up @@ -3101,4 +3102,17 @@ public FeaturesConfig setDefaultViewSecurityMode(CreateView.Security securityMod
this.defaultViewSecurityMode = securityMode;
return this;
}

public boolean isPrintEstimationStatsFromCache()
{
return this.printEstimationStatsFromCache;
}

@Config("optimizer.print-estimation-stats-from-cache")
@ConfigDescription("In the end of query optimization, print the estimation stats from cache populated during optimization instead of calculating from ground")
public FeaturesConfig setPrintEstimationStatsFromCache(boolean printEstimationStatsFromCache)
{
this.printEstimationStatsFromCache = printEstimationStatsFromCache;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ private CostAssertionBuilder assertCostSingleStageFragmentedPlan(
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator(stats), session, typeProvider);
CostProvider costProvider = new TestingCostProvider(costs, costCalculatorUsingExchanges, statsProvider, session);
// Explicitly generate the statsAndCosts, bypass fragment generation and sanity checks for mock plans.
StatsAndCosts statsAndCosts = StatsAndCosts.create(node, statsProvider, costProvider).getForSubplan(node);
StatsAndCosts statsAndCosts = StatsAndCosts.create(node, statsProvider, costProvider, session, false).getForSubplan(node);
return new CostAssertionBuilder(statsAndCosts.getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown()));
}

Expand Down Expand Up @@ -807,7 +807,7 @@ private PlanCostEstimate calculateCostFragmentedPlan(PlanNode node, StatsCalcula
TypeProvider typeProvider = TypeProvider.copyOf(types);
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider);
CostProvider costProvider = new CachingCostProvider(costCalculatorUsingExchanges, statsProvider, Optional.empty(), session);
SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider)));
SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider, session, false)));
return subPlan.getFragment().getStatsAndCosts().getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ public void testDefaults()
.setDefaultWriterReplicationCoefficient(3.0)
.setDefaultViewSecurityMode(DEFINER)
.setCteHeuristicReplicationThreshold(4)
.setLegacyJsonCast(true));
.setLegacyJsonCast(true)
.setPrintEstimationStatsFromCache(true));
}

@Test
Expand Down Expand Up @@ -485,6 +486,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.default-writer-replication-coefficient", "5.0")
.put("default-view-security-mode", INVOKER.name())
.put("cte-heuristic-replication-threshold", "2")
.put("optimizer.print-estimation-stats-from-cache", "false")
.build();

FeaturesConfig expected = new FeaturesConfig()
Expand Down Expand Up @@ -696,7 +698,8 @@ public void testExplicitPropertyMappings()
.setDefaultWriterReplicationCoefficient(5.0)
.setDefaultViewSecurityMode(INVOKER)
.setCteHeuristicReplicationThreshold(2)
.setLegacyJsonCast(false);
.setLegacyJsonCast(false)
.setPrintEstimationStatsFromCache(false);
assertFullMapping(properties, expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static void assertPlan(Session session, Metadata metadata, StatsProvider
// TODO (Issue #13231) add back printing unresolved plan once we have no need to translate OriginalExpression to RowExpression
if (!matches.isMatch()) {
PlanNode resolvedPlan = resolveGroupReferences(actual.getRoot(), lookup);
String resolvedFormattedPlan = textLogicalPlan(planSanitizer.apply(resolvedPlan), actual.getTypes(), StatsAndCosts.create(resolvedPlan, statsProvider, node -> PlanCostEstimate.unknown()), metadata.getFunctionAndTypeManager(), session, 0);
String resolvedFormattedPlan = textLogicalPlan(planSanitizer.apply(resolvedPlan), actual.getTypes(), StatsAndCosts.create(resolvedPlan, statsProvider, node -> PlanCostEstimate.unknown(), session, false), metadata.getFunctionAndTypeManager(), session, 0);
throw new AssertionError(format(
"Plan does not match, expected [\n\n%s\n] but found [\n\n%s\n]",
pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ private String formatPlan(PlanNode plan, TypeProvider types)
{
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types);
CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, session);
return inTransaction(session -> textLogicalPlan(plan, types, StatsAndCosts.create(plan, statsProvider, costProvider), metadata.getFunctionAndTypeManager(), session, 2, false, isVerboseOptimizerInfoEnabled(session)));
return inTransaction(session -> textLogicalPlan(plan, types, StatsAndCosts.create(plan, statsProvider, costProvider, session, false), metadata.getFunctionAndTypeManager(), session, 2, false, isVerboseOptimizerInfoEnabled(session)));
}

private <T> T inTransaction(Function<Session, T> transactionSessionConsumer)
Expand Down

0 comments on commit 0911789

Please sign in to comment.