Skip to content

Commit

Permalink
Pass input types to simple function's 'initialize' method (#9124)
Browse files Browse the repository at this point in the history
Summary:

This change is part of enabling simple functions to process inputs of decimal type. Such processing requires access to decimal type parameters (precision and scale). This change provides full type information via 'initialize' method.

See #9096 for the end-to-end workflow.

Reviewed By: xiaoxmeng

Differential Revision: D55012189
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Mar 19, 2024
1 parent ac7f58f commit cd6ea36
Show file tree
Hide file tree
Showing 20 changed files with 82 additions and 15 deletions.
16 changes: 15 additions & 1 deletion velox/core/SimpleFunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,9 +647,22 @@ class UDFHolder final
Fun,
initialize_method_resolver,
void,
const std::vector<TypePtr>&,
const core::QueryConfig&,
const exec_arg_type<TArgs>*...>::value;

// TODO Remove
static constexpr bool udf_has_legacy_initialize = util::has_method<
Fun,
initialize_method_resolver,
void,
const core::QueryConfig&,
const exec_arg_type<TArgs>*...>::value;

static_assert(
!udf_has_legacy_initialize,
"Legacy initialize method! Upgrade.");

static_assert(
udf_has_call || udf_has_callNullable || udf_has_callNullFree,
"UDF must implement at least one of `call`, `callNullable`, or `callNullFree` functions.\n"
Expand Down Expand Up @@ -706,10 +719,11 @@ class UDFHolder final
explicit UDFHolder() : Metadata(), instance_{} {}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const typename exec_resolver<TArgs>::in_type*... constantArgs) {
if constexpr (udf_has_initialize) {
return instance_.initialize(config, constantArgs...);
return instance_.initialize(inputTypes, config, constantArgs...);
}
}

Expand Down
2 changes: 2 additions & 0 deletions velox/docs/develop/scalar-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ properties and using it when processing inputs.
const date::time_zone* timeZone_ = nullptr;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<Timestamp>* /*timestamp*/) {
timeZone_ = getTimeZoneFromConfig(config);
Expand Down Expand Up @@ -365,6 +366,7 @@ individual rows.
std::optional<DateTimeUnit> unit_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const arg_type<Timestamp>* /*timestamp*/) {
Expand Down
1 change: 1 addition & 0 deletions velox/examples/SimpleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ struct MyRegexpMatchFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig&,
const arg_type<Varchar>*,
const arg_type<Varchar>* pattern) {
Expand Down
12 changes: 7 additions & 5 deletions velox/expression/SimpleFunctionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,36 +205,38 @@ class SimpleFunctionAdapter : public VectorFunction {

template <int32_t POSITION, typename... Values>
void unpackInitialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const std::vector<VectorPtr>& packed,
const Values*... values) const {
if constexpr (POSITION == FUNC::num_args) {
return (*fn_).initialize(config, values...);
return (*fn_).initialize(inputTypes, config, values...);
} else {
if (packed.at(POSITION) != nullptr) {
SelectivityVector rows(1);
DecodedVector decodedVector(*packed.at(POSITION), rows);
auto oneReader = VectorReader<arg_at<POSITION>>(&decodedVector);
auto oneValue = oneReader[0];

unpackInitialize<POSITION + 1>(config, packed, values..., &oneValue);
unpackInitialize<POSITION + 1>(
inputTypes, config, packed, values..., &oneValue);
} else {
using temp_type = exec_arg_at<POSITION>;
unpackInitialize<POSITION + 1>(
config, packed, values..., (const temp_type*)nullptr);
inputTypes, config, packed, values..., (const temp_type*)nullptr);
}
}
}

public:
SimpleFunctionAdapter(
const std::vector<TypePtr>& /*inputTypes*/,
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const std::vector<VectorPtr>& constantInputs)
: fn_{std::make_unique<FUNC>()} {
if constexpr (FUNC::udf_has_initialize) {
try {
unpackInitialize<0>(config, constantInputs);
unpackInitialize<0>(inputTypes, config, constantInputs);
} catch (const VeloxRuntimeError&) {
throw;
} catch (const std::exception&) {
Expand Down
3 changes: 3 additions & 0 deletions velox/expression/tests/SimpleFunctionInitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct NonDefaultWithArrayInitFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/,
const arg_type<velox::Array<int32_t>>* second) {
Expand Down Expand Up @@ -131,6 +132,7 @@ struct NonDefaultWithMapInitFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/,
const arg_type<velox::Map<int32_t, int64_t>>* second) {
Expand Down Expand Up @@ -200,6 +202,7 @@ struct InitAlwaysThrowsFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/) {
VELOX_USER_FAIL("Unconditional throw!");
Expand Down
1 change: 1 addition & 0 deletions velox/expression/tests/SimpleFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,7 @@ struct ConstantArgumentFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/,
const arg_type<int32_t>* /*second*/,
Expand Down
4 changes: 3 additions & 1 deletion velox/functions/lib/Re2Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ struct Re2RegexpReplace {
std::optional<RE2> re_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*string*/,
const arg_type<Varchar>* pattern,
Expand All @@ -282,12 +283,13 @@ struct Re2RegexpReplace {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<Varchar>* string,
const arg_type<Varchar>* pattern) {
StringView emptyReplacement;

initialize(config, string, pattern, &emptyReplacement);
initialize(inputTypes, config, string, pattern, &emptyReplacement);
}

FOLLY_ALWAYS_INLINE bool call(
Expand Down
1 change: 1 addition & 0 deletions velox/functions/lib/TimeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ struct InitSessionTimezone {
const date::time_zone* timeZone_{nullptr};

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Timestamp>* /*timestamp*/) {
timeZone_ = getTimeZoneFromConfig(config);
Expand Down
20 changes: 19 additions & 1 deletion velox/functions/prestosql/DateTimeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,21 @@ struct DateFunction : public TimestampWithTimezoneSupport<T> {
const date::time_zone* timeZone_ = nullptr;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* date) {
timeZone_ = getTimeZoneFromConfig(config);
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Timestamp>* timestamp) {
timeZone_ = getTimeZoneFromConfig(config);
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<TimestampWithTimezone>* timestampWithTimezone) {
timeZone_ = getTimeZoneFromConfig(config);
Expand Down Expand Up @@ -820,6 +823,7 @@ struct DateTruncFunction : public TimestampWithTimezoneSupport<T> {
std::optional<DateTimeUnit> unit_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const arg_type<Timestamp>* /*timestamp*/) {
Expand All @@ -831,6 +835,7 @@ struct DateTruncFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* unitString,
const arg_type<Date>* /*date*/) {
Expand All @@ -840,6 +845,7 @@ struct DateTruncFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* unitString,
const arg_type<TimestampWithTimezone>* /*timestamp*/) {
Expand Down Expand Up @@ -997,6 +1003,7 @@ struct DateAddFunction : public TimestampWithTimezoneSupport<T> {
std::optional<DateTimeUnit> unit_ = std::nullopt;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const int64_t* /*value*/,
Expand All @@ -1008,6 +1015,7 @@ struct DateAddFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* unitString,
const int64_t* /*value*/,
Expand Down Expand Up @@ -1103,6 +1111,7 @@ struct DateDiffFunction : public TimestampWithTimezoneSupport<T> {
std::optional<DateTimeUnit> unit_ = std::nullopt;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const arg_type<Timestamp>* /*timestamp1*/,
Expand All @@ -1115,6 +1124,7 @@ struct DateDiffFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* unitString,
const arg_type<Date>* /*date1*/,
Expand All @@ -1125,6 +1135,7 @@ struct DateDiffFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const arg_type<TimestampWithTimezone>* /*timestampWithTimezone1*/,
Expand Down Expand Up @@ -1195,6 +1206,7 @@ struct DateFormatFunction : public TimestampWithTimezoneSupport<T> {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Timestamp>* /*timestamp*/,
const arg_type<Varchar>* formatString) {
Expand All @@ -1206,6 +1218,7 @@ struct DateFormatFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<TimestampWithTimezone>* /*timestamp*/,
const arg_type<Varchar>* formatString) {
Expand Down Expand Up @@ -1271,6 +1284,7 @@ struct DateParseFunction {
bool isConstFormat_ = false;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* formatString) {
Expand Down Expand Up @@ -1313,6 +1327,7 @@ struct FormatDateTimeFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Timestamp>* /*timestamp*/,
const arg_type<Varchar>* formatString) {
Expand Down Expand Up @@ -1383,6 +1398,7 @@ struct ParseDateTimeFunction {
bool isConstFormat_ = false;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* format) {
Expand Down Expand Up @@ -1427,7 +1443,9 @@ struct CurrentDateFunction {

const date::time_zone* timeZone_ = nullptr;

FOLLY_ALWAYS_INLINE void initialize(const core::QueryConfig& config) {
FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config) {
timeZone_ = getTimeZoneFromConfig(config);
}

Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/HyperLogLogFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct EmptyApproxSetWithMaxErrorFunction {
std::string serialized_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const double* maxStandardError) {
VELOX_USER_CHECK_NOT_NULL(
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/prestosql/MapSubset.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace facebook::velox::functions {
VELOX_DEFINE_FUNCTION_TYPES(TExec); \
\
void initialize( \
const std::vector<TypePtr>& /*inputTypes*/, \
const core::QueryConfig& /*config*/, \
const arg_type<Map<TType, Generic<T1>>>* /*inputMap*/, \
const arg_type<Array<TType>>* keys) { \
Expand Down Expand Up @@ -98,6 +99,7 @@ struct MapSubsetVarcharFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Map<Varchar, Generic<T1>>>* /*inputMap*/,
const arg_type<Array<Varchar>>* keys) {
Expand Down
7 changes: 7 additions & 0 deletions velox/functions/sparksql/DateTimeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ struct UnixTimestampParseFunction {
// unix_timestamp(input);
// If format is not specified, assume kDefaultFormat.
FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/) {
format_ = buildJodaDateTimeFormatter(kDefaultFormat_);
Expand Down Expand Up @@ -173,6 +174,7 @@ struct UnixTimestampParseWithFormatFunction
// unix_timestamp(input, format):
// If format is constant, compile it just once per batch.
FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* format) {
Expand Down Expand Up @@ -227,6 +229,7 @@ struct FromUnixtimeFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<int64_t>* /*unixtime*/,
const arg_type<Varchar>* format) {
Expand Down Expand Up @@ -270,6 +273,7 @@ struct ToUtcTimestampFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* timezone) {
Expand Down Expand Up @@ -299,6 +303,7 @@ struct FromUtcTimestampFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* timezone) {
Expand Down Expand Up @@ -329,6 +334,7 @@ struct GetTimestampFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* format) {
Expand Down Expand Up @@ -608,6 +614,7 @@ struct NextDayFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Date>* /*startDate*/,
const arg_type<Varchar>* dayOfWeek) {
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/In.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct InFunctionOuter {
VELOX_DEFINE_FUNCTION_TYPES(TExecCtx);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<TInput>* /*searchTerm*/,
const arg_type<velox::Array<TInput>>* searchElements) {
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/MightContain.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct BloomFilterMightContainFunction {
using Allocator = std::allocator<uint64_t>;

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig&,
const arg_type<Varbinary>* serialized,
const arg_type<int64_t>*) {
Expand Down
Loading

0 comments on commit cd6ea36

Please sign in to comment.