Skip to content

Commit

Permalink
Add real operator support for new nan definition
Browse files Browse the repository at this point in the history
This adds support for the new nan definition to =, <>, <, >, <=,>=,
BETWEEN, IN, NOT IN for real types.
  • Loading branch information
rschlussel committed Jun 6, 2024
1 parent aa5de0c commit 554f1c2
Show file tree
Hide file tree
Showing 8 changed files with 452 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static java.lang.Double.doubleToLongBits;
import static java.lang.Float.floatToIntBits;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.util.Locale.ENGLISH;
Expand Down Expand Up @@ -258,4 +259,44 @@ public static int doubleCompare(double a, double b)
long bBits = doubleToLongBits(b);
return Long.compare(aBits, bBits);
}

public static boolean realEquals(float a, float b)
{
// the first check ensures +0 == -0 is true. the second ensures that NaN == NaN is true
// for all other cases a == b and floatToIntBits(a) == floatToIntBits(b) will return
// the same result
// floatToIntBits converts all NaNs to the same representation
return a == b || floatToIntBits(a) == floatToIntBits(b);
}

public static long realHashCode(float value)
{
// canonicalize +0 and -0 to a single value
value = value == -0 ? 0 : value;
// floatToIntBits converts all NaNs to the same representation
return AbstractLongType.hash(floatToIntBits(value));
}

public static int realCompare(float a, float b)
{
// these three ifs can only be true if neither value is NaN
if (a < b) {
return -1;
}
if (a > b) {
return 1;
}
// this check ensure floatCompare(+0, -0) will return 0
// if we just did floatToIntBits comparison, then they
// would not compare as equal
if (a == b) {
return 0;
}

// this ensures that realCompare(NaN, NaN) will return 0
// floatToIntBits converts all NaNs to the same representation
int aBits = floatToIntBits(a);
int bBits = floatToIntBits(b);
return Integer.compare(aBits, bBits);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
import static com.facebook.presto.common.type.TypeUtils.doubleCompare;
import static com.facebook.presto.common.type.TypeUtils.doubleEquals;
import static com.facebook.presto.common.type.TypeUtils.doubleHashCode;
import static com.facebook.presto.common.type.TypeUtils.realCompare;
import static com.facebook.presto.common.type.TypeUtils.realEquals;
import static com.facebook.presto.common.type.TypeUtils.realHashCode;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static java.lang.Double.longBitsToDouble;
import static java.lang.Float.intBitsToFloat;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
Expand Down Expand Up @@ -82,4 +86,30 @@ public void testDoubleCompare()
//0x7ff8123412341234L is a different representation of NaN
assertEquals(doubleCompare(Double.NaN, longBitsToDouble(0x7ff8123412341234L)), 0);
}

@Test
public void testRealHashCode()
{
assertEquals(realHashCode(0), realHashCode(Float.parseFloat("-0")));
// 0x7fc01234 is a different representation of NaN
assertEquals(realHashCode(Float.NaN), realHashCode(intBitsToFloat(0x7fc01234)));
}

@Test
public void testRealEquals()
{
assertTrue(realEquals(0, Float.parseFloat("-0")));
assertTrue(realEquals(Float.NaN, Float.NaN));
// 0x7fc01234 is a different representation of NaN
assertTrue(realEquals(Float.NaN, intBitsToFloat(0x7fc01234)));
}

@Test
public void testRealCompare()
{
assertEquals(realCompare(0, Float.parseFloat("-0")), 0);
assertEquals(realCompare(Float.NaN, Float.NaN), 0);
// 0x7fc01234 is a different representation of NaN
assertEquals(realCompare(Float.NaN, intBitsToFloat(0x7fc01234)), 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,12 @@
import com.facebook.presto.type.IpPrefixOperators;
import com.facebook.presto.type.KllSketchOperators;
import com.facebook.presto.type.LegacyDoubleComparisonOperators;
import com.facebook.presto.type.LegacyRealComparisonOperators;
import com.facebook.presto.type.LikeFunctions;
import com.facebook.presto.type.LongEnumOperators;
import com.facebook.presto.type.MapParametricType;
import com.facebook.presto.type.QuantileDigestOperators;
import com.facebook.presto.type.RealComparisonOperators;
import com.facebook.presto.type.RealOperators;
import com.facebook.presto.type.SfmSketchOperators;
import com.facebook.presto.type.SmallintOperators;
Expand Down Expand Up @@ -800,20 +802,22 @@ private List<? extends SqlFunction> getBuiltInFunctions(FeaturesConfig featuresC
.scalar(SmallintOperators.SmallintDistinctFromOperator.class)
.scalars(TinyintOperators.class)
.scalar(TinyintOperators.TinyintDistinctFromOperator.class)
.scalars(DoubleOperators.class);
.scalars(DoubleOperators.class)
.scalars(RealOperators.class);

if (featuresConfig.getUseNewNanDefinition()) {
builder.scalars(DoubleComparisonOperators.class)
.scalar(DoubleComparisonOperators.DoubleDistinctFromOperator.class);
.scalar(DoubleComparisonOperators.DoubleDistinctFromOperator.class)
.scalars(RealComparisonOperators.class)
.scalar(RealComparisonOperators.RealDistinctFromOperator.class);
}
else {

builder.scalars(LegacyDoubleComparisonOperators.class)
.scalar(LegacyDoubleComparisonOperators.DoubleDistinctFromOperator.class);
.scalar(LegacyDoubleComparisonOperators.DoubleDistinctFromOperator.class)
.scalars(LegacyRealComparisonOperators.class)
.scalar(LegacyRealComparisonOperators.RealDistinctFromOperator.class);
}
builder.scalars(RealOperators.class)
.scalar(RealOperators.RealDistinctFromOperator.class)
.scalars(VarcharOperators.class)
builder.scalars(VarcharOperators.class)
.scalar(VarcharOperators.VarcharDistinctFromOperator.class)
.scalars(VarbinaryOperators.class)
.scalar(VarbinaryOperators.VarbinaryDistinctFromOperator.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.type;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.AbstractIntType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.function.BlockIndex;
import com.facebook.presto.spi.function.BlockPosition;
import com.facebook.presto.spi.function.IsNull;
import com.facebook.presto.spi.function.ScalarOperator;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;

import static com.facebook.presto.common.function.OperatorType.BETWEEN;
import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.HASH_CODE;
import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.common.type.RealType.REAL;
import static java.lang.Float.floatToIntBits;
import static java.lang.Float.intBitsToFloat;

@Deprecated
public final class LegacyRealComparisonOperators
{
private LegacyRealComparisonOperators()
{
}

@ScalarOperator(EQUAL)
@SqlType(StandardTypes.BOOLEAN)
@SqlNullable
public static Boolean equal(@SqlType(StandardTypes.REAL) long left, @SqlType(StandardTypes.REAL) long right)
{
return intBitsToFloat((int) left) == intBitsToFloat((int) right);
}

@ScalarOperator(NOT_EQUAL)
@SqlType(StandardTypes.BOOLEAN)
@SqlNullable
public static Boolean notEqual(@SqlType(StandardTypes.REAL) long left, @SqlType(StandardTypes.REAL) long right)
{
return intBitsToFloat((int) left) != intBitsToFloat((int) right);
}

@ScalarOperator(LESS_THAN)
@SqlType(StandardTypes.BOOLEAN)
public static boolean lessThan(@SqlType(StandardTypes.REAL) long left, @SqlType(StandardTypes.REAL) long right)
{
return intBitsToFloat((int) left) < intBitsToFloat((int) right);
}

@ScalarOperator(LESS_THAN_OR_EQUAL)
@SqlType(StandardTypes.BOOLEAN)
public static boolean lessThanOrEqual(@SqlType(StandardTypes.REAL) long left, @SqlType(StandardTypes.REAL) long right)
{
return intBitsToFloat((int) left) <= intBitsToFloat((int) right);
}

@ScalarOperator(GREATER_THAN)
@SqlType(StandardTypes.BOOLEAN)
public static boolean greaterThan(@SqlType(StandardTypes.REAL) long left, @SqlType(StandardTypes.REAL) long right)
{
return intBitsToFloat((int) left) > intBitsToFloat((int) right);
}

@ScalarOperator(GREATER_THAN_OR_EQUAL)
@SqlType(StandardTypes.BOOLEAN)
public static boolean greaterThanOrEqual(@SqlType(StandardTypes.REAL) long left, @SqlType(StandardTypes.REAL) long right)
{
return intBitsToFloat((int) left) >= intBitsToFloat((int) right);
}

@ScalarOperator(BETWEEN)
@SqlType(StandardTypes.BOOLEAN)
public static boolean between(@SqlType(StandardTypes.REAL) long value, @SqlType(StandardTypes.REAL) long min, @SqlType(StandardTypes.REAL) long max)
{
return intBitsToFloat((int) min) <= intBitsToFloat((int) value) &&
intBitsToFloat((int) value) <= intBitsToFloat((int) max);
}

@ScalarOperator(HASH_CODE)
@SqlType(StandardTypes.BIGINT)
public static long hashCode(@SqlType(StandardTypes.REAL) long value)
{
return AbstractIntType.hash(floatToIntBits(intBitsToFloat((int) value)));
}

@ScalarOperator(IS_DISTINCT_FROM)
public static class RealDistinctFromOperator
{
@SqlType(StandardTypes.BOOLEAN)
public static boolean isDistinctFrom(
@SqlType(StandardTypes.REAL) long left,
@IsNull boolean leftNull,
@SqlType(StandardTypes.REAL) long right,
@IsNull boolean rightNull)
{
if (leftNull != rightNull) {
return true;
}
if (leftNull) {
return false;
}
float leftFloat = intBitsToFloat((int) left);
float rightFloat = intBitsToFloat((int) right);
if (Float.isNaN(leftFloat) && Float.isNaN(rightFloat)) {
return false;
}
return notEqual(left, right);
}

@SqlType(StandardTypes.BOOLEAN)
public static boolean isDistinctFrom(
@BlockPosition @SqlType(value = StandardTypes.REAL, nativeContainerType = long.class) Block left,
@BlockIndex int leftPosition,
@BlockPosition @SqlType(value = StandardTypes.REAL, nativeContainerType = long.class) Block right,
@BlockIndex int rightPosition)
{
if (left.isNull(leftPosition) != right.isNull(rightPosition)) {
return true;
}
if (left.isNull(leftPosition)) {
return false;
}
return notEqual(REAL.getLong(left, leftPosition), REAL.getLong(right, rightPosition));
}
}
}
Loading

0 comments on commit 554f1c2

Please sign in to comment.