Skip to content

Commit

Permalink
Enhance FunctionCallRewriter
Browse files Browse the repository at this point in the history
Enhance FunctionCallRewriter to support function substitution for
multi-signature functions. Change to use Multimap to support one
function name mapped to multiple function substitutes. Select
the first function substitute where the original function
declaration matches.

Inherit the filter field and window fields of the original function
by the aggregate or window functions in the substitute. Introduce
the FunctionAndTypeManager to FunctionCallRewriter to identify
aggregate and window functions.
  • Loading branch information
gggrace14 committed May 25, 2024
1 parent f16032e commit 56c61e4
Show file tree
Hide file tree
Showing 7 changed files with 595 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,36 @@
import com.facebook.presto.sql.tree.Statement;

import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

public class QueryObjectBundle
extends QueryBundle
{
private final QualifiedName objectName;
private final Optional<String> rewrittenFunctionCalls;

public QueryObjectBundle(
QualifiedName objectName,
List<Statement> setupQueries,
Statement query,
List<Statement> teardownQueries,
ClusterType cluster)
ClusterType cluster,
Optional<String> rewrittenFunctionCalls)
{
super(setupQueries, query, teardownQueries, cluster);
this.objectName = requireNonNull(objectName, "objectName is null");
this.rewrittenFunctionCalls = requireNonNull(rewrittenFunctionCalls, "rewrittenFunctionCalls is null");
}

public QualifiedName getObjectName()
{
return objectName;
}

public Optional<String> getRewrittenFunctionCalls()
{
return rewrittenFunctionCalls;
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import com.facebook.presto.verifier.prestoaction.PrestoAction.ResultSetConverter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;
import org.intellij.lang.annotations.Language;

import java.sql.ResultSetMetaData;
Expand Down Expand Up @@ -78,6 +80,7 @@
import static com.facebook.presto.verifier.framework.VerifierUtil.PARSING_OPTIONS;
import static com.facebook.presto.verifier.framework.VerifierUtil.getColumnNames;
import static com.facebook.presto.verifier.framework.VerifierUtil.getColumnTypes;
import static com.facebook.presto.verifier.rewrite.FunctionCallRewriter.FunctionCallSubstitute;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
Expand All @@ -104,7 +107,7 @@ public QueryRewriter(
Map<ClusterType, QualifiedName> tablePrefixes,
Map<ClusterType, List<Property>> tableProperties)
{
this(sqlParser, typeManager, prestoAction, tablePrefixes, tableProperties, Optional.empty());
this(sqlParser, typeManager, prestoAction, tablePrefixes, tableProperties, ImmutableMultimap.of());
}

public QueryRewriter(
Expand All @@ -113,15 +116,14 @@ public QueryRewriter(
PrestoAction prestoAction,
Map<ClusterType, QualifiedName> tablePrefixes,
Map<ClusterType, List<Property>> tableProperties,
Optional<String> nonDeterministicFunctionSubstitutes)
Multimap<String, FunctionCallSubstitute> functionSubstitutes)
{
this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
this.typeManager = requireNonNull(typeManager, "typeManager is null");
this.prestoAction = requireNonNull(prestoAction, "prestoAction is null");
this.prefixes = ImmutableMap.copyOf(tablePrefixes);
this.tableProperties = ImmutableMap.copyOf(tableProperties);
this.functionCallRewriter =
requireNonNull(nonDeterministicFunctionSubstitutes, "nonDeterministicFunctionSubstitutes is null").map(functionSubstitutes -> FunctionCallRewriter.getInstance(functionSubstitutes));
this.functionCallRewriter = FunctionCallRewriter.getInstance(functionSubstitutes, typeManager);
}

public QueryObjectBundle rewriteQuery(@Language("SQL") String query, ClusterType clusterType)
Expand All @@ -135,8 +137,11 @@ public QueryObjectBundle rewriteQuery(@Language("SQL") String query, ClusterType
CreateTableAsSelect createTableAsSelect = (CreateTableAsSelect) statement;
QualifiedName temporaryTableName = generateTemporaryName(Optional.of(createTableAsSelect.getName()), prefix);
Query createQuery = createTableAsSelect.getQuery();
Optional<String> functionSubstitutions = Optional.empty();
if (functionCallRewriter.isPresent()) {
createQuery = (Query) functionCallRewriter.get().rewrite(createQuery);
FunctionCallRewriter.RewriterResult rewriterResult = functionCallRewriter.get().rewrite(createQuery);
createQuery = (Query) rewriterResult.getRewrittenNode();
functionSubstitutions = rewriterResult.getSubstitutions();
}
return new QueryObjectBundle(
temporaryTableName,
Expand All @@ -150,15 +155,19 @@ public QueryObjectBundle rewriteQuery(@Language("SQL") String query, ClusterType
createTableAsSelect.getColumnAliases(),
createTableAsSelect.getComment()),
ImmutableList.of(new DropTable(temporaryTableName, true)),
clusterType);
clusterType,
functionSubstitutions);
}
if (statement instanceof Insert) {
Insert insert = (Insert) statement;
QualifiedName originalTableName = insert.getTarget();
QualifiedName temporaryTableName = generateTemporaryName(Optional.of(originalTableName), prefix);
Query insertQuery = insert.getQuery();
Optional<String> functionSubstitutions = Optional.empty();
if (functionCallRewriter.isPresent()) {
insertQuery = (Query) functionCallRewriter.get().rewrite(insertQuery);
FunctionCallRewriter.RewriterResult rewriterResult = functionCallRewriter.get().rewrite(insertQuery);
insertQuery = (Query) rewriterResult.getRewrittenNode();
functionSubstitutions = rewriterResult.getSubstitutions();
}
return new QueryObjectBundle(
temporaryTableName,
Expand All @@ -174,13 +183,17 @@ public QueryObjectBundle rewriteQuery(@Language("SQL") String query, ClusterType
insert.getColumns(),
insertQuery),
ImmutableList.of(new DropTable(temporaryTableName, true)),
clusterType);
clusterType,
functionSubstitutions);
}
if (statement instanceof Query) {
QualifiedName temporaryTableName = generateTemporaryName(Optional.empty(), prefix);
Query queryBody = (Query) statement;
Optional<String> functionSubstitutions = Optional.empty();
if (functionCallRewriter.isPresent()) {
queryBody = (Query) functionCallRewriter.get().rewrite(queryBody);
FunctionCallRewriter.RewriterResult rewriterResult = functionCallRewriter.get().rewrite(queryBody);
queryBody = (Query) rewriterResult.getRewrittenNode();
functionSubstitutions = rewriterResult.getSubstitutions();
}
ResultSetMetaData metadata = getResultMetadata(queryBody);
List<Identifier> columnAliases = generateStorageColumnAliases(metadata);
Expand All @@ -198,7 +211,8 @@ public QueryObjectBundle rewriteQuery(@Language("SQL") String query, ClusterType
Optional.of(columnAliases),
Optional.empty()),
ImmutableList.of(new DropTable(temporaryTableName, true)),
clusterType);
clusterType,
functionSubstitutions);
}
if (statement instanceof CreateView) {
CreateView createView = (CreateView) statement;
Expand Down Expand Up @@ -232,7 +246,8 @@ public QueryObjectBundle rewriteQuery(@Language("SQL") String query, ClusterType
createView.isReplace(),
createView.getSecurity()),
ImmutableList.of(new DropView(temporaryViewName, true)),
clusterType);
clusterType,
Optional.empty());
}
if (statement instanceof CreateTable) {
CreateTable createTable = (CreateTable) statement;
Expand All @@ -247,7 +262,8 @@ public QueryObjectBundle rewriteQuery(@Language("SQL") String query, ClusterType
applyPropertyOverride(createTable.getProperties(), properties),
createTable.getComment()),
ImmutableList.of(new DropTable(temporaryTableName, true)),
clusterType);
clusterType,
Optional.empty());
}

throw new IllegalStateException(format("Unsupported query type: %s", statement.getClass()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.sql.parser.IdentifierSymbol.AT_SIGN;
Expand All @@ -69,7 +70,8 @@ public class VerifierTestUtil
"INSERT INTO test SELECT * FROM source",
ParsingOptions.builder().setDecimalLiteralTreatment(AS_DOUBLE).build()),
ImmutableList.of(),
CONTROL);
CONTROL,
Optional.empty());

private static final MySqlOptions MY_SQL_OPTIONS = MySqlOptions.builder()
.setCommandTimeout(new Duration(90, SECONDS))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import java.util.Optional;

import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.parser.IdentifierSymbol.AT_SIGN;
import static com.facebook.presto.sql.parser.IdentifierSymbol.COLON;
Expand Down Expand Up @@ -65,6 +67,7 @@ private static QueryObjectBundle createBundle(String query)
ImmutableList.of(),
sqlParser.createStatement(query, PARSING_OPTIONS),
ImmutableList.of(),
CONTROL);
CONTROL,
Optional.empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ public <R> QueryResult<R> execute(Statement statement, QueryStage queryStage, Re
"INSERT INTO test SELECT * FROM source",
ParsingOptions.builder().setDecimalLiteralTreatment(AS_DOUBLE).build()),
ImmutableList.of(),
TEST);
TEST,
Optional.empty());
private static final QueryException HIVE_TOO_MANY_OPEN_PARTITIONS_EXCEPTION = new PrestoQueryException(
new RuntimeException(),
false,
Expand Down
Loading

0 comments on commit 56c61e4

Please sign in to comment.