Skip to content

Commit

Permalink
Fix analyzer for lambda in aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
feilong-liu committed Apr 19, 2024
1 parent d2b2ae9 commit 71ad56f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,14 @@ public static boolean isConstant(Expression expression)
tempExpression = ((Cast) tempExpression).getExpression();
}

if (tempExpression instanceof Literal || tempExpression instanceof ArrayConstructor) {
if (tempExpression instanceof Literal) {
return true;
}

if (tempExpression instanceof ArrayConstructor) {
return ((ArrayConstructor) tempExpression).getValues().stream().allMatch(ExpressionTreeUtils::isConstant);
}

// ROW an MAP are special so we explicitly do that here.
if (tempExpression instanceof Row) {
return (((Row) tempExpression).getItems().stream().allMatch(ExpressionTreeUtils::isConstant));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import static com.facebook.presto.common.type.TypeUtils.isEnumType;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isConstant;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -119,19 +120,27 @@ public Expression rewrite(Expression expression)
Expression mapped = translateNamesToSymbols(expression);

// then rewrite subexpressions in terms of the current mappings
return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Boolean>()
{
@Override
public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
public Expression rewriteExpression(Expression node, Boolean context, ExpressionTreeRewriter<Boolean> treeRewriter)
{
if (expressionToVariables.containsKey(node)) {
// Do not rewrite if node is constant and within a lambda expression
if (expressionToVariables.containsKey(node) && !((context.equals(Boolean.TRUE) && isConstant(node)))) {
return new SymbolReference(expression.getLocation(), expressionToVariables.get(node).getName());
}

Expression translated = expressionToExpressions.getOrDefault(node, node);
return treeRewriter.defaultRewrite(translated, context);
}
}, mapped);

@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Boolean context, ExpressionTreeRewriter<Boolean> treeRewriter)
{
Expression result = super.rewriteLambdaExpression(node, true, treeRewriter);
return result;
}
}, mapped, false);
}

public void put(Expression expression, VariableReferenceExpression variable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7520,4 +7520,14 @@ public void testGuardConstraintFramework()
assertQuery("select orderkey from (select * from (select * from orders where 1=0)) group by rollup(orderkey)",
"values (null)");
}

@Test
public void testLambdaInAggregation()
{
assertQuery("SELECT id, reduce_agg(value, 0, (a, b) -> a + b+0, (a, b) -> a + b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id", "values (1, 9), (2, 90)");
assertQuery("SELECT id, reduce_agg(value, 's', (a, b) -> concat(a, b, 's'), (a, b) -> concat(a, b, 's')) FROM ( VALUES (1, '2'), (1, '3'), (1, '4'), (2, '20'), (2, '30'), (2, '40') ) AS t(id, value) GROUP BY id",
"values (1, 's2s3s4s'), (2, 's20s30s40s')");
assertQueryFails("SELECT id, reduce_agg(value, array[id, value], (a, b) -> a || b, (a, b) -> a || b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id",
".*REDUCE_AGG only supports non-NULL literal as the initial value.*");
}
}

0 comments on commit 71ad56f

Please sign in to comment.