From 2002953722a94a830e0cf9adcff4149e0a4217a6 Mon Sep 17 00:00:00 2001 From: benitakbritto <> Date: Fri, 24 Jun 2022 11:19:38 -0700 Subject: [PATCH] Implement Presto aggregate function map_union (#1827) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/1827 Test Plan: ## Unit Test ``` [~/fbsource/fbcode]$ buck test velox/functions/prestosql/aggregates/tests:test ``` ## End-to-end Test ``` [~/fbsource/fbcode/presto_cpp/public_tld/java/presto-native-tests]$ JAVA_HOME=/usr/local/fb-jdk-8 mvn -Dmaven.compiler.fork=true -Dmaven.compiler.executable=/usr/local/fb-jdk-8/bin/javac -Djava.net.preferIPv6Addresses=true -DPRESTO_SERVER=$HOME/fbsource/fbcode/buck-out/gen/presto_cpp/main/presto_server -DDATA_DIR=/tmp -Duser.timezone=America/Bahia_Banderas -Dtest=TestHiveAggregationQueries test ``` Reviewed By: kagamiori Differential Revision: D37088301 Pulled By: benitakbritto fbshipit-source-id: 209ee7b1ad2e41608226215f3bcb788d687af8e4 --- velox/docs/functions/aggregate.rst | 6 + .../prestosql/aggregates/AggregateNames.h | 1 + .../prestosql/aggregates/CMakeLists.txt | 3 + .../prestosql/aggregates/MapAggAggregate.cpp | 219 +----------------- .../prestosql/aggregates/MapAggregateBase.cpp | 176 ++++++++++++++ .../prestosql/aggregates/MapAggregateBase.h | 115 +++++++++ .../aggregates/MapUnionAggregate.cpp | 74 ++++++ .../prestosql/aggregates/tests/CMakeLists.txt | 1 + .../tests/MapUnionAggregationTest.cpp | 195 ++++++++++++++++ 9 files changed, 574 insertions(+), 216 deletions(-) create mode 100644 velox/functions/prestosql/aggregates/MapAggregateBase.cpp create mode 100644 velox/functions/prestosql/aggregates/MapAggregateBase.h create mode 100644 velox/functions/prestosql/aggregates/MapUnionAggregate.cpp create mode 100644 velox/functions/prestosql/aggregates/tests/MapUnionAggregationTest.cpp diff --git a/velox/docs/functions/aggregate.rst b/velox/docs/functions/aggregate.rst index e2e612d4da67..621d37b43b73 100644 --- a/velox/docs/functions/aggregate.rst +++ b/velox/docs/functions/aggregate.rst @@ -104,6 +104,12 @@ Map Aggregate Functions Returns a map created from the input ``key`` / ``value`` pairs. +.. function:: map_union(map(K,V)) -> map(K,V) + + Returns the union of all the input ``maps``. + If a ``key`` is found in multiple input ``maps``, + that ``key’s`` ``value`` in the resulting ``map`` comes from an arbitrary input ``map``. + Approximate Aggregate Functions ------------------------------- diff --git a/velox/functions/prestosql/aggregates/AggregateNames.h b/velox/functions/prestosql/aggregates/AggregateNames.h index 8724c8c49c21..0621d68761b4 100644 --- a/velox/functions/prestosql/aggregates/AggregateNames.h +++ b/velox/functions/prestosql/aggregates/AggregateNames.h @@ -37,6 +37,7 @@ const char* const kCovarSamp = "covar_samp"; const char* const kEvery = "every"; const char* const kHistogram = "histogram"; const char* const kMapAgg = "map_agg"; +const char* const kMapUnion = "map_union"; const char* const kMax = "max"; const char* const kMaxBy = "max_by"; const char* const kMerge = "merge"; diff --git a/velox/functions/prestosql/aggregates/CMakeLists.txt b/velox/functions/prestosql/aggregates/CMakeLists.txt index fb545f710c4e..6cffd933c3c5 100644 --- a/velox/functions/prestosql/aggregates/CMakeLists.txt +++ b/velox/functions/prestosql/aggregates/CMakeLists.txt @@ -31,6 +31,9 @@ add_library( ChecksumAggregate.cpp HistogramAggregate.cpp MapAggAggregate.cpp + MapAggregateBase.h + MapAggregateBase.cpp + MapUnionAggregate.cpp MinMaxAggregates.cpp MinMaxByAggregates.cpp CountAggregate.cpp diff --git a/velox/functions/prestosql/aggregates/MapAggAggregate.cpp b/velox/functions/prestosql/aggregates/MapAggAggregate.cpp index c259186c6a83..1f0ec4f82058 100644 --- a/velox/functions/prestosql/aggregates/MapAggAggregate.cpp +++ b/velox/functions/prestosql/aggregates/MapAggAggregate.cpp @@ -13,95 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "velox/exec/ContainerRowSerde.h" -#include "velox/expression/FunctionSignature.h" -#include "velox/functions/prestosql/aggregates/AggregateNames.h" -#include "velox/functions/prestosql/aggregates/ValueList.h" +#include "velox/functions/prestosql/aggregates/MapAggregateBase.h" namespace facebook::velox::aggregate { namespace { - -struct MapAccumulator { - ValueList keys; - ValueList values; -}; - // See documentation at // https://prestodb.io/docs/current/functions/aggregate.html -class MapAggAggregate : public exec::Aggregate { +class MapAggAggregate : public aggregate::MapAggregateBase { public: - explicit MapAggAggregate(TypePtr resultType) : Aggregate(resultType) {} - - int32_t accumulatorFixedWidthSize() const override { - return sizeof(MapAccumulator); - } - - bool isFixedSize() const override { - return false; - } - - void initializeNewGroups( - char** groups, - folly::Range indices) override { - for (auto index : indices) { - new (groups[index] + offset_) MapAccumulator(); - } - } - - void finalize(char** groups, int32_t numGroups) override { - for (auto i = 0; i < numGroups; i++) { - value(groups[i])->keys.finalize(allocator_); - value(groups[i])->values.finalize(allocator_); - } - } - - void extractValues(char** groups, int32_t numGroups, VectorPtr* result) - override { - auto mapVector = (*result)->as(); - VELOX_CHECK(mapVector); - mapVector->resize(numGroups); - - auto mapKeys = mapVector->mapKeys(); - auto mapValues = mapVector->mapValues(); - - auto numElements = countElements(groups, numGroups); - mapKeys->resize(numElements); - mapValues->resize(numElements); - - auto* rawNulls = getRawNulls(mapVector); - vector_size_t offset = 0; - for (int32_t i = 0; i < numGroups; ++i) { - char* group = groups[i]; - clearNull(rawNulls, i); - - auto accumulator = value(group); - auto mapSize = accumulator->keys.size(); - if (mapSize) { - ValueListReader keysReader(accumulator->keys); - ValueListReader valuesReader(accumulator->values); - for (auto index = 0; index < mapSize; ++index) { - keysReader.next(*mapKeys, offset + index); - valuesReader.next(*mapValues, offset + index); - } - mapVector->setOffsetAndSize(i, offset, mapSize); - offset += mapSize; - } else { - mapVector->setOffsetAndSize(i, offset, 0); - } - } - - // canonicalize requires a singly referenced MapVector. std::move - // inside the cast does not clear *result, so we clear this - // manually. - auto mapVectorPtr = std::static_pointer_cast(std::move(*result)); - *result = nullptr; - *result = removeDuplicates(mapVectorPtr); - } - - void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) - override { - extractValues(groups, numGroups, result); - } + explicit MapAggAggregate(TypePtr resultType) : MapAggregateBase(resultType) {} void addRawInput( char** groups, @@ -123,29 +43,6 @@ class MapAggAggregate : public exec::Aggregate { }); } - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - decodedIntermediate_.decode(*args[0], rows); - - auto mapVector = decodedIntermediate_.base()->as(); - auto& mapKeys = mapVector->mapKeys(); - auto& mapValues = mapVector->mapValues(); - rows.applyToSelected([&](vector_size_t row) { - auto group = groups[row]; - auto accumulator = value(group); - - auto decodedRow = decodedIntermediate_.index(row); - auto offset = mapVector->offsetAt(decodedRow); - auto size = mapVector->sizeAt(decodedRow); - auto tracker = trackRowSize(group); - accumulator->keys.appendRange(mapKeys, offset, size, allocator_); - accumulator->values.appendRange(mapValues, offset, size, allocator_); - }); - } - void addSingleGroupRawInput( char* group, const SelectivityVector& rows, @@ -166,116 +63,6 @@ class MapAggAggregate : public exec::Aggregate { } }); } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedIntermediate_.decode(*args[0], rows); - - auto accumulator = value(group); - auto mapVector = decodedIntermediate_.base()->as(); - auto& keys = accumulator->keys; - auto& values = accumulator->values; - - auto& mapKeys = mapVector->mapKeys(); - auto& mapValues = mapVector->mapValues(); - rows.applyToSelected([&](vector_size_t row) { - auto decodedRow = decodedIntermediate_.index(row); - auto offset = mapVector->offsetAt(decodedRow); - auto size = mapVector->sizeAt(decodedRow); - keys.appendRange(mapKeys, offset, size, allocator_); - values.appendRange(mapValues, offset, size, allocator_); - }); - } - - void destroy(folly::Range groups) override { - for (auto group : groups) { - auto accumulator = value(group); - accumulator->keys.free(allocator_); - accumulator->values.free(allocator_); - } - } - - private: - vector_size_t countElements(char** groups, int32_t numGroups) const { - vector_size_t size = 0; - for (int32_t i = 0; i < numGroups; ++i) { - size += value(groups[i])->keys.size(); - } - return size; - } - - VectorPtr removeDuplicates(MapVectorPtr& mapVector) const { - MapVector::canonicalize(mapVector); - - auto offsets = mapVector->rawOffsets(); - auto sizes = mapVector->rawSizes(); - auto mapKeys = mapVector->mapKeys(); - - auto numRows = mapVector->size(); - auto numElements = mapKeys->size(); - - BufferPtr newSizes; - vector_size_t* rawNewSizes = nullptr; - - BufferPtr elementIndices; - vector_size_t* rawElementIndices = nullptr; - - // Check for duplicate keys - for (vector_size_t row = 0; row < numRows; row++) { - auto offset = offsets[row]; - auto size = sizes[row]; - auto duplicateCnt = 0; - for (vector_size_t i = 1; i < size; i++) { - if (mapKeys->equalValueAt(mapKeys.get(), offset + i, offset + i - 1)) { - // duplicate key - duplicateCnt++; - if (!rawNewSizes) { - newSizes = - allocateSizes(mapVector->mapKeys()->size(), mapVector->pool()); - rawNewSizes = newSizes->asMutable(); - - elementIndices = allocateIndices( - mapVector->mapKeys()->size(), mapVector->pool()); - rawElementIndices = elementIndices->asMutable(); - - memcpy(rawNewSizes, sizes, row * sizeof(vector_size_t)); - std::iota(rawElementIndices, rawElementIndices + offset + i, 0); - } - } else if (rawNewSizes) { - rawElementIndices[offset + i - duplicateCnt] = offset + i; - } - } - if (rawNewSizes) { - rawNewSizes[row] = size - duplicateCnt; - } - }; - - if (rawNewSizes) { - return std::make_shared( - mapVector->pool(), - mapVector->type(), - mapVector->nulls(), - mapVector->size(), - mapVector->offsets(), - newSizes, - BaseVector::wrapInDictionary( - BufferPtr(nullptr), elementIndices, numElements, mapKeys), - BaseVector::wrapInDictionary( - BufferPtr(nullptr), - elementIndices, - numElements, - mapVector->mapValues())); - } else { - return mapVector; - } - } - - DecodedVector decodedKeys_; - DecodedVector decodedValues_; - DecodedVector decodedIntermediate_; }; bool registerMapAggAggregate(const std::string& name) { diff --git a/velox/functions/prestosql/aggregates/MapAggregateBase.cpp b/velox/functions/prestosql/aggregates/MapAggregateBase.cpp new file mode 100644 index 000000000000..f286eb4b05ed --- /dev/null +++ b/velox/functions/prestosql/aggregates/MapAggregateBase.cpp @@ -0,0 +1,176 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/prestosql/aggregates/MapAggregateBase.h" +#include + +namespace facebook::velox::aggregate { + +void MapAggregateBase::extractValues( + char** groups, + int32_t numGroups, + VectorPtr* result) { + auto mapVector = (*result)->as(); + VELOX_CHECK(mapVector); + mapVector->resize(numGroups); + auto mapKeys = mapVector->mapKeys(); + auto mapValues = mapVector->mapValues(); + auto numElements = countElements(groups, numGroups); + mapKeys->resize(numElements); + mapValues->resize(numElements); + + auto* rawNulls = getRawNulls(mapVector); + vector_size_t offset = 0; + + for (int32_t i = 0; i < numGroups; ++i) { + char* group = groups[i]; + clearNull(rawNulls, i); + + auto accumulator = value(group); + auto mapSize = accumulator->keys.size(); + if (mapSize) { + ValueListReader keysReader(accumulator->keys); + ValueListReader valuesReader(accumulator->values); + for (auto index = 0; index < mapSize; ++index) { + keysReader.next(*mapKeys, offset + index); + valuesReader.next(*mapValues, offset + index); + } + mapVector->setOffsetAndSize(i, offset, mapSize); + offset += mapSize; + } else { + mapVector->setOffsetAndSize(i, offset, 0); + } + } + + // Canonicalize requires a singly referenced MapVector. std::move + // inside the cast does not clear *result, so we clear this + // manually. + auto mapVectorPtr = std::static_pointer_cast(std::move(*result)); + *result = nullptr; + *result = removeDuplicates(mapVectorPtr); +} + +VectorPtr MapAggregateBase::removeDuplicates(MapVectorPtr& mapVector) const { + MapVector::canonicalize(mapVector); + + auto offsets = mapVector->rawOffsets(); + auto sizes = mapVector->rawSizes(); + auto mapKeys = mapVector->mapKeys(); + + auto numRows = mapVector->size(); + auto numElements = mapKeys->size(); + + BufferPtr newSizes; + vector_size_t* rawNewSizes = nullptr; + + BufferPtr elementIndices; + vector_size_t* rawElementIndices = nullptr; + + // Check for duplicate keys. + for (vector_size_t row = 0; row < numRows; row++) { + auto offset = offsets[row]; + auto size = sizes[row]; + auto duplicateCnt = 0; + for (vector_size_t i = 1; i < size; i++) { + if (mapKeys->equalValueAt(mapKeys.get(), offset + i, offset + i - 1)) { + // Duplicate key found. + duplicateCnt++; + if (!rawNewSizes) { + newSizes = allocateSizes(numElements, mapVector->pool()); + rawNewSizes = newSizes->asMutable(); + + elementIndices = allocateIndices(numElements, mapVector->pool()); + rawElementIndices = elementIndices->asMutable(); + + memcpy(rawNewSizes, sizes, row * sizeof(vector_size_t)); + std::iota(rawElementIndices, rawElementIndices + numElements, 0); + } + } else if (rawNewSizes) { + rawElementIndices[offset + i - duplicateCnt] = offset + i; + } + } + if (rawNewSizes) { + rawNewSizes[row] = size - duplicateCnt; + } + }; + + if (rawNewSizes) { + return std::make_shared( + mapVector->pool(), + mapVector->type(), + mapVector->nulls(), + mapVector->size(), + mapVector->offsets(), + newSizes, + BaseVector::wrapInDictionary( + BufferPtr(nullptr), elementIndices, numElements, mapKeys), + BaseVector::wrapInDictionary( + BufferPtr(nullptr), + elementIndices, + numElements, + mapVector->mapValues())); + } else { + return mapVector; + } +} + +void MapAggregateBase::addMapInputToAccumulator( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) { + decodedMaps_.decode(*args[0], rows); + auto mapVector = decodedMaps_.base()->as(); + + VELOX_CHECK_NOT_NULL(mapVector); + auto& mapKeys = mapVector->mapKeys(); + auto& mapValues = mapVector->mapValues(); + rows.applyToSelected([&](vector_size_t row) { + auto group = groups[row]; + auto accumulator = value(group); + + auto decodedRow = decodedMaps_.index(row); + auto offset = mapVector->offsetAt(decodedRow); + auto size = mapVector->sizeAt(decodedRow); + auto tracker = trackRowSize(group); + accumulator->keys.appendRange(mapKeys, offset, size, allocator_); + accumulator->values.appendRange(mapValues, offset, size, allocator_); + }); +} + +void MapAggregateBase::addSingleGroupMapInputToAccumulator( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) { + decodedMaps_.decode(*args[0], rows); + auto mapVector = decodedMaps_.base()->as(); + + auto accumulator = value(group); + auto& keys = accumulator->keys; + auto& values = accumulator->values; + + VELOX_CHECK_NOT_NULL(mapVector); + auto& mapKeys = mapVector->mapKeys(); + auto& mapValues = mapVector->mapValues(); + rows.applyToSelected([&](vector_size_t row) { + auto decodedRow = decodedMaps_.index(row); + auto offset = mapVector->offsetAt(decodedRow); + auto size = mapVector->sizeAt(decodedRow); + keys.appendRange(mapKeys, offset, size, allocator_); + values.appendRange(mapValues, offset, size, allocator_); + }); +} +} // namespace facebook::velox::aggregate diff --git a/velox/functions/prestosql/aggregates/MapAggregateBase.h b/velox/functions/prestosql/aggregates/MapAggregateBase.h new file mode 100644 index 000000000000..7e57c849426f --- /dev/null +++ b/velox/functions/prestosql/aggregates/MapAggregateBase.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/exec/ContainerRowSerde.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/functions/prestosql/aggregates/ValueList.h" + +namespace facebook::velox::aggregate { +struct MapAccumulator { + ValueList keys; + ValueList values; +}; + +class MapAggregateBase : public exec::Aggregate { + public: + explicit MapAggregateBase(TypePtr resultType) : Aggregate(resultType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(MapAccumulator); + } + + bool isFixedSize() const override { + return false; + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + for (auto index : indices) { + new (groups[index] + offset_) MapAccumulator(); + } + } + + void finalize(char** groups, int32_t numGroups) override { + for (auto i = 0; i < numGroups; i++) { + value(groups[i])->keys.finalize(allocator_); + value(groups[i])->values.finalize(allocator_); + } + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override; + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + extractValues(groups, numGroups, result); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + addMapInputToAccumulator(groups, rows, args, false); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + addSingleGroupMapInputToAccumulator(group, rows, args, false); + } + + void destroy(folly::Range groups) override { + for (auto group : groups) { + auto accumulator = value(group); + accumulator->keys.free(allocator_); + accumulator->values.free(allocator_); + } + } + + protected: + vector_size_t countElements(char** groups, int32_t numGroups) const { + vector_size_t size = 0; + for (int32_t i = 0; i < numGroups; ++i) { + size += value(groups[i])->keys.size(); + } + return size; + } + + VectorPtr removeDuplicates(MapVectorPtr& mapVector) const; + + void addMapInputToAccumulator( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown); + + void addSingleGroupMapInputToAccumulator( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown); + + DecodedVector decodedKeys_; + DecodedVector decodedValues_; + DecodedVector decodedMaps_; +}; +} // namespace facebook::velox::aggregate diff --git a/velox/functions/prestosql/aggregates/MapUnionAggregate.cpp b/velox/functions/prestosql/aggregates/MapUnionAggregate.cpp new file mode 100644 index 000000000000..273e3cf721d8 --- /dev/null +++ b/velox/functions/prestosql/aggregates/MapUnionAggregate.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/prestosql/aggregates/MapAggregateBase.h" + +namespace facebook::velox::aggregate { +namespace { +// See documentation at +// https://prestodb.io/docs/current/functions/aggregate.html +class MapUnionAggregate : public aggregate::MapAggregateBase { + public: + explicit MapUnionAggregate(TypePtr resultType) + : MapAggregateBase(resultType) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + addMapInputToAccumulator(groups, rows, args, false); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + addSingleGroupMapInputToAccumulator(group, rows, args, false); + } +}; + +bool registerMapUnionAggregate(const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .typeVariable("K") + .typeVariable("V") + .returnType("map(K,V)") + .intermediateType("map(K,V)") + .argumentType("map(K,V)") + .build()}; + + exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step /*step*/, + const std::vector& argTypes, + const TypePtr& resultType) -> std::unique_ptr { + VELOX_CHECK_EQ( + argTypes.size(), + 1, + "{} ({}): unexpected number of arguments", + name); + return std::make_unique(resultType); + }); + return true; +} + +static bool FB_ANONYMOUS_VARIABLE(g_AggregateFunction) = + registerMapUnionAggregate(kMapUnion); +} // namespace +} // namespace facebook::velox::aggregate diff --git a/velox/functions/prestosql/aggregates/tests/CMakeLists.txt b/velox/functions/prestosql/aggregates/tests/CMakeLists.txt index 7e3265d829a4..83ed096f7eac 100644 --- a/velox/functions/prestosql/aggregates/tests/CMakeLists.txt +++ b/velox/functions/prestosql/aggregates/tests/CMakeLists.txt @@ -36,6 +36,7 @@ add_executable( PrestoHasherTest.cpp SumTest.cpp MapAggTest.cpp + MapUnionAggregationTest.cpp ValueListTest.cpp VarianceAggregationTest.cpp) diff --git a/velox/functions/prestosql/aggregates/tests/MapUnionAggregationTest.cpp b/velox/functions/prestosql/aggregates/tests/MapUnionAggregationTest.cpp new file mode 100644 index 000000000000..211812710a2a --- /dev/null +++ b/velox/functions/prestosql/aggregates/tests/MapUnionAggregationTest.cpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/prestosql/aggregates/tests/AggregationTestBase.h" + +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace facebook::velox::aggregate::test { + +namespace { + +class MapUnionTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + // Tests only single batches of input. + disableSpill(); + } +}; + +/** + * This test checks single, partial, intermediate, final aggregates + * with and without local exchange when + * there are no duplicates in the keys of the map. + * + * Takes as input a table that contains 2 columns. + * First column: five 0's followed by five 1's. + * Second column: A map of size one which contains + * consecutive numbers, where the key is NULL for + * every 4th entry (num % 4 == 0) and the value is NULL + * for every 7th entry (num % 7 == 0). + * + * The expected output is to GROUP BY the first column + * and the map size is to be of size 3 for the first row + * while of size 4 for the second row, where each map + * has a list of consecutive numbers and the value of the + * map is NULL for every 7th entry (num % 7 == 0). + */ +TEST_F(MapUnionTest, groupByWithoutDuplicates) { + auto inputVectors = {makeRowVector( + {makeFlatVector(10, [](vector_size_t row) { return row / 5; }), + makeMapVector( + 10, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t row) { return row; }, + [&](vector_size_t row) { return row + 0.05; }, + nullEvery(4), + nullEvery(7))})}; + + auto expectedResult = {makeRowVector( + {makeFlatVector({0, 1}), + makeMapVector( + 2, + [&](vector_size_t row) { return row == 0 ? 3 : 4; }, + [&](vector_size_t row) { return row; }, + [&](vector_size_t row) { return row + 0.05; }, + nullptr, + nullEvery(7))})}; + + testAggregations(inputVectors, {"c0"}, {"map_union(c1)"}, expectedResult); +} + +/** + * This test checks single, partial, intermediate, final aggregates + * with and without local exchange when + * there are duplicates in the keys of the map. + * + * Takes as input a table that contains 2 columns. + * First column: five 0's followed by five 1's + * Second column: A map of size one which contains + * (Key, Value) as (1, 1.05). + * + * The expected output is to GROUP BY the first column + * and for each row, the map size is to be of size 1 which + * contains (Key, Value) as (1, 1.05). + */ +TEST_F(MapUnionTest, groupByWithDuplicates) { + auto inputVectors = {makeRowVector( + {makeFlatVector(10, [](vector_size_t row) { return row / 5; }), + makeMapVector( + 10, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t /*row*/) { return 1.05; })})}; + auto expectedResult = {makeRowVector( + {makeFlatVector({0, 1}), + makeMapVector( + 2, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t /*row*/) { return 1.05; })})}; + + testAggregations(inputVectors, {"c0"}, {"map_union(c1)"}, expectedResult); +} + +/** + * This test checks single, partial, intermediate, final aggregates + * with and without local exchange when input is empty. + */ +TEST_F(MapUnionTest, groupByNoData) { + auto inputVectors = {makeRowVector( + {makeFlatVector({}), makeMapVector({})})}; + auto expectedResult = inputVectors; + + testAggregations(inputVectors, {"c0"}, {"map_union(c1)"}, expectedResult); +} + +/** + * This test checks global aggregate when + * with and without local exchange when + * there are no duplicates in the keys of the map. + * + * Takes as input a table that contains 1 column i.e. + * a map of size one which contains consecutive numbers, + * where the key is NULL for every 4th entry + * (num % 4 == 0) and the value is NULL + * for every 7th entry (num % 7 == 0). + * + * The expected output is a map of all the non-NULL keys. + */ +TEST_F(MapUnionTest, globalWithoutDuplicates) { + auto inputVectors = {makeRowVector({makeMapVector( + 10, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t row) { return row; }, + [&](vector_size_t row) { return row + 0.05; }, + nullEvery(4), + nullEvery(7))})}; + auto expectedResult = {makeRowVector({makeMapVector( + 1, + [&](vector_size_t /*row*/) { return 7; }, + [&](vector_size_t row) { return row; }, + [&](vector_size_t row) { return row + 0.05; }, + nullptr, + nullEvery(7))})}; + + testAggregations(inputVectors, {}, {"map_union(c0)"}, expectedResult); +} + +/** + * This test checks global aggregate when + * with and without local exchange when + * there are duplicates in the keys of the map. + * + * Takes as input a table that contains 1 column i.e. + * a map of size one which contains + * (Key, Value) as (1, 1.05). + * + * The expected output a map which + * contains (Key, Value) as (1, 1.05). + */ +TEST_F(MapUnionTest, globalWithDuplicates) { + auto inputVectors = {makeRowVector({makeMapVector( + 10, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t /*row*/) { return 1.05; })})}; + auto expectedResult = {makeRowVector({makeMapVector( + 1, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t /*row*/) { return 1; }, + [&](vector_size_t /*row*/) { return 1.05; })})}; + + testAggregations(inputVectors, {}, {"map_union(c0)"}, expectedResult); +} + +/** + * This test checks global aggregate when + * the input is empty. + */ +TEST_F(MapUnionTest, globalNoData) { + auto inputVectors = {makeRowVector({makeMapVector( + 1, + [&](vector_size_t /*row*/) { return 0; }, + [&](vector_size_t /*row*/) { return 0; }, + [&](vector_size_t /*row*/) { return 0; })})}; + auto expectedResult = inputVectors; + + testAggregations(inputVectors, {}, {"map_union(c0)"}, expectedResult); +} +} // namespace +} // namespace facebook::velox::aggregate::test