Skip to content

Commit

Permalink
Fix Greatest/Least for new NaN definition
Browse files Browse the repository at this point in the history
  • Loading branch information
rschlussel committed Jun 6, 2024
1 parent b17a916 commit e64a504
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.annotation.UsedByGeneratedCode;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.CallSiteBinder;
import com.facebook.presto.bytecode.ClassDefinition;
Expand All @@ -30,7 +29,6 @@
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunctionVisibility;
Expand All @@ -50,7 +48,6 @@
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.RETURN_NULL_ON_NULL;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.function.Signature.orderableTypeParameter;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
Expand All @@ -61,16 +58,13 @@
import static com.facebook.presto.util.Reflection.methodHandle;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;

public abstract class AbstractGreatestLeast
extends SqlScalarFunction
{
private static final MethodHandle CHECK_NOT_NAN = methodHandle(AbstractGreatestLeast.class, "checkNotNaN", String.class, double.class);

private final OperatorType operatorType;

protected AbstractGreatestLeast(QualifiedObjectName name, OperatorType operatorType)
Expand Down Expand Up @@ -120,14 +114,6 @@ public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariab
methodHandle);
}

@UsedByGeneratedCode
public static void checkNotNaN(String name, double value)
{
if (Double.isNaN(value)) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Invalid argument to %s(): NaN", name));
}
}

private Class<?> generate(List<Class<?>> javaTypes, Type type, MethodHandle compareMethod)
{
checkCondition(javaTypes.size() <= 127, NOT_SUPPORTED, "Too many arguments for function call %s()", getSignature().getNameSuffix());
Expand Down Expand Up @@ -160,7 +146,6 @@ private Class<?> generate(List<Class<?>> javaTypes, Type type, MethodHandle comp
if (type.getTypeSignature().getBase().equals(StandardTypes.DOUBLE)) {
for (Parameter parameter : parameters) {
body.append(parameter);
body.append(invoke(binder.bind(CHECK_NOT_NAN.bindTo(getSignature().getNameSuffix())), "checkNotNaN"));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ protected AbstractTestFunctions(FeaturesConfig config)
protected AbstractTestFunctions(Session session, FeaturesConfig config)
{
this.session = requireNonNull(session, "session is null");
this.config = requireNonNull(config, "config is null").setLegacyLogFunction(true);
this.config = requireNonNull(config, "config is null")
.setLegacyLogFunction(true)
.setUseNewNanDefinition(true);
}

@BeforeClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1127,9 +1127,6 @@ public void testGreatest()
assertFunction("greatest(1.0, 2.0E0)", DOUBLE, 2.0);
assertDecimalFunction("greatest(5, 4, 3.0, 2)", decimal("0000000005.0"));

// invalid
assertInvalidFunction("greatest(1.5E0, 0.0E0 / 0.0E0)", "Invalid argument to greatest(): NaN");

// argument count limit
tryEvaluateWithAll("greatest(" + Joiner.on(", ").join(nCopies(127, "rand()")) + ")", DOUBLE);
assertNotSupported(
Expand Down Expand Up @@ -1198,15 +1195,14 @@ public void testLeast()
assertFunction("least(1.0, 2.0E0)", DOUBLE, 1.0);
assertDecimalFunction("least(5, 4, 3.0, 2)", decimal("0000000002.0"));

// invalid
assertInvalidFunction("least(1.5E0, 0.0E0 / 0.0E0)", "Invalid argument to least(): NaN");
assertFunction("least(1.5E0, 0.0E0 / 0.0E0)", DOUBLE, 1.5E0);
}

@Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "\\QInvalid argument to greatest(): NaN\\E")
@Test
public void testGreatestWithNaN()
{
functionAssertions.tryEvaluate("greatest(1.5E0, 0.0E0 / 0.0E0)", DOUBLE);
functionAssertions.tryEvaluate("greatest(1.5E0, REAL '0.0' / REAL '0.0')", DOUBLE);
assertFunction("greatest(1.5E0, 0.0E0 / 0.0E0)", DOUBLE, Double.NaN);
assertFunction("greatest(REAL '1.5E0', REAL '0.0' / REAL '0.0')", REAL, Float.NaN);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,29 @@ public void testMaxByN()
"SELECT ARRAY [REAL '0.0', CAST(infinity() AS REAL)], ARRAY[REAL '2.0', CAST(nan() AS REAL)], ARRAY[REAL '1.0', REAL '0.0']");
}

@Test
public void testGreatest()
{
assertQueryWithSameQueryRunner("SELECT GREATEST(1.5E0, nan())", "SELECT nan()");
assertQueryWithSameQueryRunner(
format("SELECT greatest(%s, %s, %s) FROM %s", DOUBLE_NAN_FIRST_COLUMN, DOUBLE_NAN_MIDDLE_COLUMN, DOUBLE_NAN_LAST_COLUMN, DOUBLE_NANS_TABLE_NAME),
"SELECT * FROM (VALUES (nan()), (nan()), (infinity()), (nan()))");
assertQueryWithSameQueryRunner(
format("SELECT greatest(%s, %s, %s) FROM %s", REAL_NAN_FIRST_COLUMN, REAL_NAN_MIDDLE_COLUMN, REAL_NAN_LAST_COLUMN, REAL_NANS_TABLE_NAME),
"SELECT * FROM (VALUES (CAST(nan() AS REAL)), (CAST(nan() AS REAL)), CAST(infinity() AS REAL), (CAST(nan() AS REAL)))");
}

@Test
public void testLeast()
{
assertQueryWithSameQueryRunner(
format("SELECT least(%s, %s, %s) FROM %s", DOUBLE_NAN_FIRST_COLUMN, DOUBLE_NAN_MIDDLE_COLUMN, DOUBLE_NAN_LAST_COLUMN, DOUBLE_NANS_TABLE_NAME),
"SELECT * FROM (VALUES (DOUBLE '0.0'), (DOUBLE '0.0'), (DOUBLE '0.0'), (DOUBLE '-4.0'))");
assertQueryWithSameQueryRunner(
format("SELECT least(%s, %s, %s) FROM %s", REAL_NAN_FIRST_COLUMN, REAL_NAN_MIDDLE_COLUMN, REAL_NAN_LAST_COLUMN, REAL_NANS_TABLE_NAME),
"SELECT * FROM (VALUES (REAL '0.0'), (REAL '0.0'), (REAL'0.0'), REAL'-4.0')");
}

@Test
public void testDoubleSetAgg()
{
Expand Down

0 comments on commit e64a504

Please sign in to comment.