Skip to content

Commit

Permalink
Remove Metadata from RowExpressionInterpreter
Browse files Browse the repository at this point in the history
The only reference to the Metadata in RowExpressionInterpreter is to
retrieve the FunctionAndTypeManager.  However, the FunctionAndTypeManager
actually doesn't even belong on the Metadata class and can be injected
where needed.  This refactoring reduces the call sites to
Metadata#getFunctionAndTypeManager.
  • Loading branch information
tdcmeehan committed May 31, 2024
1 parent 7659fe4 commit 3cd2744
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ private Expression simplifyExpression(Session session, Expression predicate, Typ

private RowExpression simplifyExpression(ConnectorSession session, RowExpression predicate)
{
RowExpressionInterpreter interpreter = new RowExpressionInterpreter(predicate, metadata, session, OPTIMIZED);
RowExpressionInterpreter interpreter = new RowExpressionInterpreter(predicate, metadata.getFunctionAndTypeManager(), session, OPTIMIZED);
Object value = interpreter.optimize();

if (value == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ private List<Object> getVariableValues(ValuesNode valuesNode, int symbolId, Sess
}
return valuesNode.getRows().stream()
.map(row -> row.get(symbolId))
.map(rowExpression -> evaluateConstantRowExpression(rowExpression, metadata, session.toConnectorSession()))
.map(rowExpression -> evaluateConstantRowExpression(rowExpression, metadata.getFunctionAndTypeManager(), session.toConnectorSession()))
.collect(toList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ else if (locality.equals(LOCAL)) {
private RowExpression bindChannels(RowExpression expression, Map<VariableReferenceExpression, Integer> sourceLayout)
{
Type type = expression.getType();
Object value = new RowExpressionInterpreter(expression, metadata, session.toConnectorSession(), OPTIMIZED).optimize();
Object value = new RowExpressionInterpreter(expression, metadata.getFunctionAndTypeManager(), session.toConnectorSession(), OPTIMIZED).optimize();
if (value instanceof RowExpression) {
RowExpression optimized = (RowExpression) value;
// building channel info
Expand Down Expand Up @@ -1643,7 +1643,7 @@ public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext
pageBuilder.declarePosition();
for (int i = 0; i < row.size(); i++) {
// evaluate the literal value
Object result = rowExpressionInterpreter(row.get(i), metadata, context.getSession().toConnectorSession()).evaluate();
Object result = rowExpressionInterpreter(row.get(i), metadata.getFunctionAndTypeManager(), context.getSession().toConnectorSession()).evaluate();
writeNativeValue(outputTypes.get(i), pageBuilder.getBlockBuilder(i), result);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ public class RowExpressionInterpreter
{
private static final long MAX_SERIALIZABLE_OBJECT_SIZE = 1000;
private final RowExpression expression;
private final Metadata metadata;
private final ConnectorSession session;
private final Level optimizationLevel;
private final InterpretedFunctionInvoker functionInvoker;
Expand All @@ -132,29 +131,33 @@ public class RowExpressionInterpreter

private final Visitor visitor;

public static Object evaluateConstantRowExpression(RowExpression expression, Metadata metadata, ConnectorSession session)
public static Object evaluateConstantRowExpression(RowExpression expression, FunctionAndTypeManager functionAndTypeManager, ConnectorSession session)
{
// evaluate the expression
Object result = new RowExpressionInterpreter(expression, metadata, session, EVALUATED).evaluate();
Object result = new RowExpressionInterpreter(expression, functionAndTypeManager, session, EVALUATED).evaluate();
verify(!(result instanceof RowExpression), "RowExpression interpreter returned an unresolved expression");
return result;
}

public static RowExpressionInterpreter rowExpressionInterpreter(RowExpression expression, Metadata metadata, ConnectorSession session)
public static RowExpressionInterpreter rowExpressionInterpreter(RowExpression expression, FunctionAndTypeManager functionAndTypeManager, ConnectorSession session)
{
return new RowExpressionInterpreter(expression, metadata, session, EVALUATED);
return new RowExpressionInterpreter(expression, functionAndTypeManager, session, EVALUATED);
}

public RowExpressionInterpreter(RowExpression expression, Metadata metadata, ConnectorSession session, Level optimizationLevel)
{
this(expression, metadata.getFunctionAndTypeManager(), session, optimizationLevel);
}
public RowExpressionInterpreter(RowExpression expression, FunctionAndTypeManager functionAndTypeManager, ConnectorSession session, Level optimizationLevel)
{
this.expression = requireNonNull(expression, "expression is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.session = requireNonNull(session, "session is null");
this.optimizationLevel = optimizationLevel;
this.functionInvoker = new InterpretedFunctionInvoker(metadata.getFunctionAndTypeManager());
this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager());
this.resolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
this.functionAndTypeManager = metadata.getFunctionAndTypeManager();
requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
this.functionInvoker = new InterpretedFunctionInvoker(functionAndTypeManager);
this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
this.resolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
this.functionAndTypeManager = functionAndTypeManager;

this.visitor = new Visitor();
}
Expand Down Expand Up @@ -221,7 +224,7 @@ public Object visitCall(CallExpression node, Object context)
}

FunctionHandle functionHandle = node.getFunctionHandle();
FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(node.getFunctionHandle());
FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(node.getFunctionHandle());
if (!functionMetadata.isCalledOnNullInput()) {
for (Object value : argumentValues) {
if (value == null) {
Expand Down Expand Up @@ -286,11 +289,11 @@ else if (implementationType.equals(JAVA)) {
RowExpression function = getSqlFunctionRowExpression(
functionMetadata,
functionImplementation,
metadata.getFunctionAndTypeManager(),
functionAndTypeManager,
session.getSqlFunctionProperties(),
session.getSessionFunctions(),
node.getArguments());
RowExpressionInterpreter rowExpressionInterpreter = new RowExpressionInterpreter(function, metadata, session, optimizationLevel);
RowExpressionInterpreter rowExpressionInterpreter = new RowExpressionInterpreter(function, functionAndTypeManager, session, optimizationLevel);
if (optimizationLevel.ordinal() >= EVALUATED.ordinal()) {
value = rowExpressionInterpreter.evaluate();
}
Expand Down Expand Up @@ -386,9 +389,9 @@ else if (Boolean.TRUE.equals(condition)) {

Type leftType = node.getArguments().get(0).getType();
Type rightType = node.getArguments().get(1).getType();
Type commonType = metadata.getFunctionAndTypeManager().getCommonSuperType(leftType, rightType).get();
FunctionHandle firstCast = metadata.getFunctionAndTypeManager().lookupCast(CAST, leftType, commonType);
FunctionHandle secondCast = metadata.getFunctionAndTypeManager().lookupCast(CAST, rightType, commonType);
Type commonType = functionAndTypeManager.getCommonSuperType(leftType, rightType).get();
FunctionHandle firstCast = functionAndTypeManager.lookupCast(CAST, leftType, commonType);
FunctionHandle secondCast = functionAndTypeManager.lookupCast(CAST, rightType, commonType);

// cast(first as <common type>) == cast(second as <common type>)
boolean equal = Boolean.TRUE.equals(invokeOperator(
Expand Down Expand Up @@ -717,16 +720,16 @@ private RowExpression createFailureFunction(RuntimeException exception, Type typ
requireNonNull(exception, "Exception is null");

String failureInfo = JsonCodec.jsonCodec(FailureInfo.class).toJson(Failures.toFailure(exception).toFailureInfo());
FunctionHandle jsonParse = metadata.getFunctionAndTypeManager().lookupFunction("json_parse", fromTypes(VARCHAR));
FunctionHandle jsonParse = functionAndTypeManager.lookupFunction("json_parse", fromTypes(VARCHAR));
Object json = functionInvoker.invoke(jsonParse, session.getSqlFunctionProperties(), utf8Slice(failureInfo));
FunctionHandle cast = metadata.getFunctionAndTypeManager().lookupCast(CAST, UNKNOWN, type);
FunctionHandle cast = functionAndTypeManager.lookupCast(CAST, UNKNOWN, type);
if (exception instanceof PrestoException) {
long errorCode = ((PrestoException) exception).getErrorCode().getCode();
FunctionHandle failureFunction = metadata.getFunctionAndTypeManager().lookupFunction("fail", fromTypes(INTEGER, JSON));
FunctionHandle failureFunction = functionAndTypeManager.lookupFunction("fail", fromTypes(INTEGER, JSON));
return call(CAST.name(), cast, type, call("fail", failureFunction, UNKNOWN, constant(errorCode, INTEGER), LiteralEncoder.toRowExpression(json, JSON)));
}

FunctionHandle failureFunction = metadata.getFunctionAndTypeManager().lookupFunction("fail", fromTypes(JSON));
FunctionHandle failureFunction = functionAndTypeManager.lookupFunction("fail", fromTypes(JSON));
return call(CAST.name(), cast, type, call("fail", failureFunction, UNKNOWN, LiteralEncoder.toRowExpression(json, JSON)));
}

Expand All @@ -742,7 +745,7 @@ private boolean hasUnresolvedValue(List<Object> values)

private Object invokeOperator(OperatorType operatorType, List<? extends Type> argumentTypes, List<Object> argumentValues)
{
FunctionHandle operatorHandle = metadata.getFunctionAndTypeManager().resolveOperator(operatorType, fromTypes(argumentTypes));
FunctionHandle operatorHandle = functionAndTypeManager.resolveOperator(operatorType, fromTypes(argumentTypes));
return functionInvoker.invoke(operatorHandle, session.getSqlFunctionProperties(), argumentValues);
}

Expand Down Expand Up @@ -863,15 +866,15 @@ private SpecialCallResult tryHandleCast(CallExpression callExpression, List<Obje
return changed(call(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), toRowExpression(value, source)));
}

if (metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(sourceType, targetType)) {
if (functionAndTypeManager.isTypeOnlyCoercion(sourceType, targetType)) {
return changed(value);
}
return notChanged();
}

private SpecialCallResult tryHandleLike(CallExpression callExpression, List<Object> argumentValues, List<Type> argumentTypes, Object context)
{
FunctionResolution resolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
FunctionResolution resolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
checkArgument(resolution.isLikeFunction(callExpression.getFunctionHandle()));
checkArgument(callExpression.getArguments().size() == 2);
RowExpression likePatternExpression = callExpression.getArguments().get(1);
Expand Down Expand Up @@ -935,20 +938,20 @@ private SpecialCallResult tryHandleLike(CallExpression callExpression, List<Obje
Slice unescapedPattern = unescapeLiteralLikePattern((Slice) nonCompiledPattern, (Slice) escape);
Type valueType = argumentTypes.get(0);
Type patternType = createVarcharType(unescapedPattern.length());
Optional<Type> commonSuperType = metadata.getFunctionAndTypeManager().getCommonSuperType(valueType, patternType);
Optional<Type> commonSuperType = functionAndTypeManager.getCommonSuperType(valueType, patternType);
checkArgument(commonSuperType.isPresent(), "Missing super type when optimizing %s", callExpression);
RowExpression valueExpression = LiteralEncoder.toRowExpression(callExpression.getSourceLocation(), value, valueType);
RowExpression patternExpression = LiteralEncoder.toRowExpression(callExpression.getSourceLocation(), unescapedPattern, patternType);
Type superType = commonSuperType.get();
if (!valueType.equals(superType)) {
FunctionHandle cast = metadata.getFunctionAndTypeManager().lookupCast(CAST, valueType, superType);
FunctionHandle cast = functionAndTypeManager.lookupCast(CAST, valueType, superType);
valueExpression = call(CAST.name(), cast, superType, valueExpression);
}
if (!patternType.equals(superType)) {
FunctionHandle cast = metadata.getFunctionAndTypeManager().lookupCast(CAST, patternType, superType);
FunctionHandle cast = functionAndTypeManager.lookupCast(CAST, patternType, superType);
patternExpression = call(CAST.name(), cast, superType, patternExpression);
}
FunctionHandle equal = metadata.getFunctionAndTypeManager().resolveOperator(EQUAL, fromTypes(superType, superType));
FunctionHandle equal = functionAndTypeManager.resolveOperator(EQUAL, fromTypes(superType, superType));
return changed(call(EQUAL.name(), equal, BOOLEAN, valueExpression, patternExpression).accept(this, context));
}
return notChanged();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ public LayoutConstraintEvaluatorForRowExpression(Metadata metadata, Session sess
{
this.assignments = assignments;

evaluator = new RowExpressionInterpreter(expression, metadata, session.toConnectorSession(), OPTIMIZED);
evaluator = new RowExpressionInterpreter(expression, metadata.getFunctionAndTypeManager(), session.toConnectorSession(), OPTIMIZED);
arguments = VariablesExtractor.extractUnique(expression).stream()
.map(assignments::get)
.collect(toImmutableSet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ private RowExpression evaluateMinMax(FunctionMetadata aggregationFunctionMetadat
scalarFunctionName,
returnType,
partitionedArguments),
metadata,
metadata.getFunctionAndTypeManager(),
connectorSession);
reducedArguments.add(constant(reducedValue, returnType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ public ActualProperties visitProject(ProjectNode node, List<ActualProperties> in
// to take advantage of constant-folding for complex expressions
// However, that currently causes errors when those expressions operate on arrays or row types
// ("ROW comparison not supported for fields with null elements", etc)
Object value = new RowExpressionInterpreter(expression, metadata, session.toConnectorSession(), OPTIMIZED).optimize();
Object value = new RowExpressionInterpreter(expression, metadata.getFunctionAndTypeManager(), session.toConnectorSession(), OPTIMIZED).optimize();

if (value instanceof VariableReferenceExpression) {
ConstantExpression existingConstantValue = constants.get(value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -637,13 +637,13 @@ private Optional<NormalizedSimpleComparison> toNormalizedSimpleComparison(Operat
left = leftExpression;
}
else {
left = new RowExpressionInterpreter(leftExpression, metadata, session, OPTIMIZED).optimize();
left = new RowExpressionInterpreter(leftExpression, metadata.getFunctionAndTypeManager(), session, OPTIMIZED).optimize();
}
if (rightExpression instanceof VariableReferenceExpression) {
right = rightExpression;
}
else {
right = new RowExpressionInterpreter(rightExpression, metadata, session, OPTIMIZED).optimize();
right = new RowExpressionInterpreter(rightExpression, metadata.getFunctionAndTypeManager(), session, OPTIMIZED).optimize();
}

if (left instanceof RowExpression == right instanceof RowExpression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ public RowExpressionOptimizer(Metadata metadata)
public RowExpression optimize(RowExpression rowExpression, Level level, ConnectorSession session)
{
if (level.ordinal() <= OPTIMIZED.ordinal()) {
return toRowExpression(rowExpression.getSourceLocation(), new RowExpressionInterpreter(rowExpression, metadata, session, level).optimize(), rowExpression.getType());
return toRowExpression(rowExpression.getSourceLocation(), new RowExpressionInterpreter(rowExpression, metadata.getFunctionAndTypeManager(), session, level).optimize(), rowExpression.getType());
}
throw new IllegalArgumentException("Not supported optimization level: " + level);
}

@Override
public Object optimize(RowExpression expression, Level level, ConnectorSession session, Function<VariableReferenceExpression, Object> variableResolver)
{
RowExpressionInterpreter interpreter = new RowExpressionInterpreter(expression, metadata, session, level);
RowExpressionInterpreter interpreter = new RowExpressionInterpreter(expression, metadata.getFunctionAndTypeManager(), session, level);
return interpreter.optimize(variableResolver::apply);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,7 @@ private static Object optimize(Expression expression)

private static Object optimize(RowExpression expression, Level level)
{
return new RowExpressionInterpreter(expression, METADATA, TEST_SESSION.toConnectorSession(), level).optimize(variable -> {
return new RowExpressionInterpreter(expression, METADATA.getFunctionAndTypeManager(), TEST_SESSION.toConnectorSession(), level).optimize(variable -> {
Symbol symbol = new Symbol(variable.getName());
Object value = symbolConstant(symbol);
if (value == null) {
Expand Down Expand Up @@ -1850,7 +1850,7 @@ private static Object evaluate(Expression expression, boolean deterministic)
{
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP);
Object expressionResult = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes).evaluate();
Object rowExpressionResult = rowExpressionInterpreter(TRANSLATOR.translateAndOptimize(expression), METADATA, TEST_SESSION.toConnectorSession()).evaluate();
Object rowExpressionResult = rowExpressionInterpreter(TRANSLATOR.translateAndOptimize(expression), METADATA.getFunctionAndTypeManager(), TEST_SESSION.toConnectorSession()).evaluate();

if (deterministic) {
assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult);
Expand Down
Loading

0 comments on commit 3cd2744

Please sign in to comment.