Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass input types to simple function's 'initialize' method #9124

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion velox/core/SimpleFunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,9 +647,22 @@ class UDFHolder final
Fun,
initialize_method_resolver,
void,
const std::vector<TypePtr>&,
const core::QueryConfig&,
const exec_arg_type<TArgs>*...>::value;

// TODO Remove
static constexpr bool udf_has_legacy_initialize = util::has_method<
Fun,
initialize_method_resolver,
void,
const core::QueryConfig&,
const exec_arg_type<TArgs>*...>::value;

static_assert(
!udf_has_legacy_initialize,
"Legacy initialize method! Upgrade.");

static_assert(
udf_has_call || udf_has_callNullable || udf_has_callNullFree,
"UDF must implement at least one of `call`, `callNullable`, or `callNullFree` functions.\n"
Expand Down Expand Up @@ -706,10 +719,11 @@ class UDFHolder final
explicit UDFHolder() : Metadata(), instance_{} {}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const typename exec_resolver<TArgs>::in_type*... constantArgs) {
if constexpr (udf_has_initialize) {
return instance_.initialize(config, constantArgs...);
return instance_.initialize(inputTypes, config, constantArgs...);
}
}

Expand Down
2 changes: 2 additions & 0 deletions velox/docs/develop/scalar-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ properties and using it when processing inputs.
const date::time_zone* timeZone_ = nullptr;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<Timestamp>* /*timestamp*/) {
timeZone_ = getTimeZoneFromConfig(config);
Expand Down Expand Up @@ -365,6 +366,7 @@ individual rows.
std::optional<DateTimeUnit> unit_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const arg_type<Timestamp>* /*timestamp*/) {
Expand Down
1 change: 1 addition & 0 deletions velox/examples/SimpleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ struct MyRegexpMatchFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig&,
const arg_type<Varchar>*,
const arg_type<Varchar>* pattern) {
Expand Down
12 changes: 7 additions & 5 deletions velox/expression/SimpleFunctionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,36 +205,38 @@ class SimpleFunctionAdapter : public VectorFunction {

template <int32_t POSITION, typename... Values>
void unpackInitialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const std::vector<VectorPtr>& packed,
const Values*... values) const {
if constexpr (POSITION == FUNC::num_args) {
return (*fn_).initialize(config, values...);
return (*fn_).initialize(inputTypes, config, values...);
} else {
if (packed.at(POSITION) != nullptr) {
SelectivityVector rows(1);
DecodedVector decodedVector(*packed.at(POSITION), rows);
auto oneReader = VectorReader<arg_at<POSITION>>(&decodedVector);
auto oneValue = oneReader[0];

unpackInitialize<POSITION + 1>(config, packed, values..., &oneValue);
unpackInitialize<POSITION + 1>(
inputTypes, config, packed, values..., &oneValue);
} else {
using temp_type = exec_arg_at<POSITION>;
unpackInitialize<POSITION + 1>(
config, packed, values..., (const temp_type*)nullptr);
inputTypes, config, packed, values..., (const temp_type*)nullptr);
}
}
}

public:
SimpleFunctionAdapter(
const std::vector<TypePtr>& /*inputTypes*/,
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const std::vector<VectorPtr>& constantInputs)
: fn_{std::make_unique<FUNC>()} {
if constexpr (FUNC::udf_has_initialize) {
try {
unpackInitialize<0>(config, constantInputs);
unpackInitialize<0>(inputTypes, config, constantInputs);
} catch (const VeloxRuntimeError&) {
throw;
} catch (const std::exception&) {
Expand Down
3 changes: 3 additions & 0 deletions velox/expression/tests/SimpleFunctionInitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct NonDefaultWithArrayInitFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/,
const arg_type<velox::Array<int32_t>>* second) {
Expand Down Expand Up @@ -131,6 +132,7 @@ struct NonDefaultWithMapInitFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/,
const arg_type<velox::Map<int32_t, int64_t>>* second) {
Expand Down Expand Up @@ -200,6 +202,7 @@ struct InitAlwaysThrowsFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/) {
VELOX_USER_FAIL("Unconditional throw!");
Expand Down
1 change: 1 addition & 0 deletions velox/expression/tests/SimpleFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,7 @@ struct ConstantArgumentFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/,
const arg_type<int32_t>* /*second*/,
Expand Down
4 changes: 3 additions & 1 deletion velox/functions/lib/Re2Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ struct Re2RegexpReplace {
std::optional<RE2> re_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*string*/,
const arg_type<Varchar>* pattern,
Expand All @@ -282,12 +283,13 @@ struct Re2RegexpReplace {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<Varchar>* string,
const arg_type<Varchar>* pattern) {
StringView emptyReplacement;

initialize(config, string, pattern, &emptyReplacement);
initialize(inputTypes, config, string, pattern, &emptyReplacement);
}

FOLLY_ALWAYS_INLINE bool call(
Expand Down
1 change: 1 addition & 0 deletions velox/functions/lib/TimeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ struct InitSessionTimezone {
const date::time_zone* timeZone_{nullptr};

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Timestamp>* /*timestamp*/) {
timeZone_ = getTimeZoneFromConfig(config);
Expand Down
20 changes: 19 additions & 1 deletion velox/functions/prestosql/DateTimeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,21 @@ struct DateFunction : public TimestampWithTimezoneSupport<T> {
const date::time_zone* timeZone_ = nullptr;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* date) {
timeZone_ = getTimeZoneFromConfig(config);
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Timestamp>* timestamp) {
timeZone_ = getTimeZoneFromConfig(config);
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<TimestampWithTimezone>* timestampWithTimezone) {
timeZone_ = getTimeZoneFromConfig(config);
Expand Down Expand Up @@ -820,6 +823,7 @@ struct DateTruncFunction : public TimestampWithTimezoneSupport<T> {
std::optional<DateTimeUnit> unit_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const arg_type<Timestamp>* /*timestamp*/) {
Expand All @@ -831,6 +835,7 @@ struct DateTruncFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* unitString,
const arg_type<Date>* /*date*/) {
Expand All @@ -840,6 +845,7 @@ struct DateTruncFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* unitString,
const arg_type<TimestampWithTimezone>* /*timestamp*/) {
Expand Down Expand Up @@ -997,6 +1003,7 @@ struct DateAddFunction : public TimestampWithTimezoneSupport<T> {
std::optional<DateTimeUnit> unit_ = std::nullopt;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const int64_t* /*value*/,
Expand All @@ -1008,6 +1015,7 @@ struct DateAddFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* unitString,
const int64_t* /*value*/,
Expand Down Expand Up @@ -1103,6 +1111,7 @@ struct DateDiffFunction : public TimestampWithTimezoneSupport<T> {
std::optional<DateTimeUnit> unit_ = std::nullopt;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const arg_type<Timestamp>* /*timestamp1*/,
Expand All @@ -1115,6 +1124,7 @@ struct DateDiffFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* unitString,
const arg_type<Date>* /*date1*/,
Expand All @@ -1125,6 +1135,7 @@ struct DateDiffFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* unitString,
const arg_type<TimestampWithTimezone>* /*timestampWithTimezone1*/,
Expand Down Expand Up @@ -1195,6 +1206,7 @@ struct DateFormatFunction : public TimestampWithTimezoneSupport<T> {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Timestamp>* /*timestamp*/,
const arg_type<Varchar>* formatString) {
Expand All @@ -1206,6 +1218,7 @@ struct DateFormatFunction : public TimestampWithTimezoneSupport<T> {
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<TimestampWithTimezone>* /*timestamp*/,
const arg_type<Varchar>* formatString) {
Expand Down Expand Up @@ -1271,6 +1284,7 @@ struct DateParseFunction {
bool isConstFormat_ = false;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* formatString) {
Expand Down Expand Up @@ -1313,6 +1327,7 @@ struct FormatDateTimeFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Timestamp>* /*timestamp*/,
const arg_type<Varchar>* formatString) {
Expand Down Expand Up @@ -1383,6 +1398,7 @@ struct ParseDateTimeFunction {
bool isConstFormat_ = false;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* format) {
Expand Down Expand Up @@ -1427,7 +1443,9 @@ struct CurrentDateFunction {

const date::time_zone* timeZone_ = nullptr;

FOLLY_ALWAYS_INLINE void initialize(const core::QueryConfig& config) {
FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config) {
timeZone_ = getTimeZoneFromConfig(config);
}

Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/HyperLogLogFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct EmptyApproxSetWithMaxErrorFunction {
std::string serialized_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const double* maxStandardError) {
VELOX_USER_CHECK_NOT_NULL(
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/prestosql/MapSubset.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace facebook::velox::functions {
VELOX_DEFINE_FUNCTION_TYPES(TExec); \
\
void initialize( \
const std::vector<TypePtr>& /*inputTypes*/, \
const core::QueryConfig& /*config*/, \
const arg_type<Map<TType, Generic<T1>>>* /*inputMap*/, \
const arg_type<Array<TType>>* keys) { \
Expand Down Expand Up @@ -98,6 +99,7 @@ struct MapSubsetVarcharFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Map<Varchar, Generic<T1>>>* /*inputMap*/,
const arg_type<Array<Varchar>>* keys) {
Expand Down
7 changes: 7 additions & 0 deletions velox/functions/sparksql/DateTimeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ struct UnixTimestampParseFunction {
// unix_timestamp(input);
// If format is not specified, assume kDefaultFormat.
FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/) {
format_ = buildJodaDateTimeFormatter(kDefaultFormat_);
Expand Down Expand Up @@ -173,6 +174,7 @@ struct UnixTimestampParseWithFormatFunction
// unix_timestamp(input, format):
// If format is constant, compile it just once per batch.
FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* format) {
Expand Down Expand Up @@ -227,6 +229,7 @@ struct FromUnixtimeFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<int64_t>* /*unixtime*/,
const arg_type<Varchar>* format) {
Expand Down Expand Up @@ -270,6 +273,7 @@ struct ToUtcTimestampFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* timezone) {
Expand Down Expand Up @@ -299,6 +303,7 @@ struct FromUtcTimestampFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* timezone) {
Expand Down Expand Up @@ -329,6 +334,7 @@ struct GetTimestampFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*input*/,
const arg_type<Varchar>* format) {
Expand Down Expand Up @@ -608,6 +614,7 @@ struct NextDayFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Date>* /*startDate*/,
const arg_type<Varchar>* dayOfWeek) {
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/In.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct InFunctionOuter {
VELOX_DEFINE_FUNCTION_TYPES(TExecCtx);

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<TInput>* /*searchTerm*/,
const arg_type<velox::Array<TInput>>* searchElements) {
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/MightContain.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct BloomFilterMightContainFunction {
using Allocator = std::allocator<uint64_t>;

void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig&,
const arg_type<Varbinary>* serialized,
const arg_type<int64_t>*) {
Expand Down
Loading
Loading