Skip to content

Commit

Permalink
Fix tests; remove polymorphic division function that breaks backward …
Browse files Browse the repository at this point in the history
…compatibility
  • Loading branch information
yashmayya committed Sep 26, 2024
1 parent 15330eb commit ff6a73b
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ public class ArithmeticFunctions {
private ArithmeticFunctions() {
}

@ScalarFunction(names = {"div", "divide"})
public static double divide(double a, double b) {
return a / b;
}

@ScalarFunction(names = {"div", "divide"})
public static double divide(double a, double b, double defaultValue) {
return (b == 0) ? defaultValue : a / b;
}

@ScalarFunction
public static double mod(double a, double b) {
return a % b;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2363,14 +2363,6 @@ public void testCompileTimeExpression() {
result = expression.getLiteral().getLongValue();
Assert.assertTrue(result >= lowerBound && result <= upperBound);

expression = compileToExpression("now() / 1");
Assert.assertNotNull(expression.getFunctionCall());
expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
Assert.assertNotNull(expression.getLiteral());
upperBound = System.currentTimeMillis();
result = expression.getLiteral().getLongValue();
Assert.assertTrue(result >= lowerBound && result <= upperBound);

lowerBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
expression = compileToExpression("to_epoch_hours(now() + 3600000)");
Assert.assertNotNull(expression.getFunctionCall());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ public void testPostAggregationFunction() {
// Plus
PostAggregationFunction function =
new PostAggregationFunction("plus", new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG});
assertEquals(function.getResultType(), ColumnDataType.DOUBLE);
assertEquals(function.invoke(new Object[]{1, 2L}), 3.0);
assertEquals(function.getResultType(), ColumnDataType.LONG);
assertEquals(function.invoke(new Object[]{1, 2L}), 3L);

// Minus
function = new PostAggregationFunction("MINUS", new ColumnDataType[]{ColumnDataType.FLOAT, ColumnDataType.DOUBLE});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2040,54 +2040,61 @@ public void testGroupByUDF(boolean useMultiStageQueryEngine)
assertEquals(row.get(0).asLong(), 16138 * 24);
assertEquals(row.get(1).asLong(), 605);

if (useMultiStageQueryEngine) {
query = "SELECT add(DaysSinceEpoch,add(DaysSinceEpoch,15)), COUNT(*) FROM mytable "
+ "GROUP BY add(DaysSinceEpoch,add(DaysSinceEpoch,15)) ORDER BY COUNT(*) DESC";
} else {
query = "SELECT add(DaysSinceEpoch,DaysSinceEpoch,15), COUNT(*) FROM mytable "
+ "GROUP BY add(DaysSinceEpoch,DaysSinceEpoch,15) ORDER BY COUNT(*) DESC";
}
query = "SELECT arrayLength(DivAirports), COUNT(*) FROM mytable "
+ "GROUP BY arrayLength(DivAirports) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]");
rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asDouble(), 16138.0 + 16138 + 15);
assertEquals(row.get(1).asLong(), 605);
assertEquals(row.get(0).asInt(), 5);
assertEquals(row.get(1).asLong(), 115545);

query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
+ "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
query = "SELECT arrayLength(valueIn(DivAirports,'DFW','ORD')), COUNT(*) FROM mytable GROUP BY "
+ "arrayLength(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]");
rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
assertEquals(rows.size(), 3);
row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asDouble(), 16138.0 - 25);
assertEquals(row.get(1).asLong(), 605);
assertEquals(row.get(0).asInt(), 0);
assertEquals(row.get(1).asLong(), 114895);
row = rows.get(1);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asInt(), 1);
assertEquals(row.get(1).asLong(), 648);
row = rows.get(2);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asInt(), 2);
assertEquals(row.get(1).asLong(), 2);

if (useMultiStageQueryEngine) {
query = "SELECT mult(DaysSinceEpoch,mult(24,3600)), COUNT(*) FROM mytable "
+ "GROUP BY mult(DaysSinceEpoch,mult(24,3600)) ORDER BY COUNT(*) DESC";
if (useMultiStageQueryEngine()) {
query = "SELECT arrayToMV(valueIn(DivAirports,'DFW','ORD')), COUNT(*) FROM mytable "
+ "GROUP BY arrayToMV(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) DESC";
} else {
query = "SELECT mult(DaysSinceEpoch,24,3600), COUNT(*) FROM mytable "
+ "GROUP BY mult(DaysSinceEpoch,24,3600) ORDER BY COUNT(*) DESC";
query = "SELECT valueIn(DivAirports,'DFW','ORD'), COUNT(*) FROM mytable "
+ "GROUP BY valueIn(DivAirports,'DFW','ORD') ORDER BY COUNT(*) DESC";
}
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"STRING\",\"LONG\"]");
rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
assertEquals(rows.size(), 2);
row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asDouble(), 16138.0 * 24 * 3600);
assertEquals(row.get(1).asLong(), 605);
assertEquals(row.get(0).asText(), "ORD");
assertEquals(row.get(1).asLong(), 336);
row = rows.get(1);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asText(), "DFW");
assertEquals(row.get(1).asLong(), 316);

query = "SELECT div(DaysSinceEpoch,2), COUNT(*) FROM mytable "
+ "GROUP BY div(DaysSinceEpoch,2) ORDER BY COUNT(*) DESC";
Expand All @@ -2101,62 +2108,92 @@ public void testGroupByUDF(boolean useMultiStageQueryEngine)
assertEquals(row.size(), 2);
assertEquals(row.get(0).asDouble(), 16138.0 / 2);
assertEquals(row.get(1).asLong(), 605);
}

query = "SELECT arrayLength(DivAirports), COUNT(*) FROM mytable "
+ "GROUP BY arrayLength(DivAirports) ORDER BY COUNT(*) DESC";
@Test
public void testGroupByUDFV1() throws Exception {
setUseMultiStageQueryEngine(false);
String query = "SELECT add(DaysSinceEpoch,DaysSinceEpoch,15), COUNT(*) FROM mytable "
+ "GROUP BY add(DaysSinceEpoch,DaysSinceEpoch,15) ORDER BY COUNT(*) DESC";
JsonNode response = postQuery(query);
JsonNode resultTable = response.get("resultTable");
JsonNode dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]");
JsonNode rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
JsonNode row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asDouble(), 16138.0 + 16138 + 15);
assertEquals(row.get(1).asLong(), 605);

query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
+ "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]");
rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asInt(), 5);
assertEquals(row.get(1).asLong(), 115545);
assertEquals(row.get(0).asDouble(), 16138.0 - 25);
assertEquals(row.get(1).asLong(), 605);

query = "SELECT arrayLength(valueIn(DivAirports,'DFW','ORD')), COUNT(*) FROM mytable GROUP BY "
+ "arrayLength(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) DESC";
query = "SELECT mult(DaysSinceEpoch,24,3600), COUNT(*) FROM mytable "
+ "GROUP BY mult(DaysSinceEpoch,24,3600) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]");
rows = resultTable.get("rows");
assertEquals(rows.size(), 3);
assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asInt(), 0);
assertEquals(row.get(1).asLong(), 114895);
row = rows.get(1);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asInt(), 1);
assertEquals(row.get(1).asLong(), 648);
row = rows.get(2);
assertEquals(row.get(0).asDouble(), 16138.0 * 24 * 3600);
assertEquals(row.get(1).asLong(), 605);
}

@Test
public void testGroupByUDFV2() throws Exception {
setUseMultiStageQueryEngine(true);
String query = "SELECT add(DaysSinceEpoch,add(DaysSinceEpoch,15)), COUNT(*) FROM mytable "
+ "GROUP BY add(DaysSinceEpoch,add(DaysSinceEpoch,15)) ORDER BY COUNT(*) DESC";
JsonNode response = postQuery(query);
JsonNode resultTable = response.get("resultTable");
JsonNode dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]");
JsonNode rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
JsonNode row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asInt(), 2);
assertEquals(row.get(1).asLong(), 2);
assertEquals(row.get(0).asInt(), 16138 + 16138 + 15);
assertEquals(row.get(1).asLong(), 605);

if (useMultiStageQueryEngine()) {
query = "SELECT arrayToMV(valueIn(DivAirports,'DFW','ORD')), COUNT(*) FROM mytable "
+ "GROUP BY arrayToMV(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) DESC";
} else {
query = "SELECT valueIn(DivAirports,'DFW','ORD'), COUNT(*) FROM mytable "
+ "GROUP BY valueIn(DivAirports,'DFW','ORD') ORDER BY COUNT(*) DESC";
}
query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
+ "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"STRING\",\"LONG\"]");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]");
rows = resultTable.get("rows");
assertEquals(rows.size(), 2);
assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asText(), "ORD");
assertEquals(row.get(1).asLong(), 336);
row = rows.get(1);
assertEquals(row.get(0).asInt(), 16138 - 25);
assertEquals(row.get(1).asLong(), 605);

query = "SELECT mult(DaysSinceEpoch,mult(24,3600)), COUNT(*) FROM mytable "
+ "GROUP BY mult(DaysSinceEpoch,mult(24,3600)) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]");
rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
assertEquals(row.get(0).asText(), "DFW");
assertEquals(row.get(1).asLong(), 316);
assertEquals(row.get(0).asInt(), 16138 * 24 * 3600);
assertEquals(row.get(1).asLong(), 605);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ public static PinotOperatorTable instance() {
Pair.of(SqlStdOperatorTable.GREATER_THAN, List.of("GREATER_THAN")),
Pair.of(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, List.of("GREATER_THAN_OR_EQUAL")),
Pair.of(SqlStdOperatorTable.LESS_THAN, List.of("LESS_THAN")),
Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, List.of("LESS_THAN_OR_EQUAL"))
Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, List.of("LESS_THAN_OR_EQUAL")),
Pair.of(SqlStdOperatorTable.PLUS, List.of("ADD")),
Pair.of(SqlStdOperatorTable.MINUS, List.of("SUB")),
Pair.of(SqlStdOperatorTable.MULTIPLY, List.of("MULT", "TIMES"))
);

/**
Expand Down
Loading

0 comments on commit ff6a73b

Please sign in to comment.