Add support for DECIMAL types to Simple Function API (#9096)
Summary:
**type/Type.h**

	Add P1, P2, P3, P4, S1, S2, S3, S4 types to specify precision and scale parameters for decimal types during function registration.

	Add LongDecimal<P, S> and ShortDecimal<P, S> templates to specify decimal argument and return types during function registration.

```
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      LongDecimal<P1, S1>,
      LongDecimal<P2, S2>>({"plus"}, constraints);
```

**expression/UdfTypeResolver.h**

Define arg_type and out_type for LongDecimal and ShortDecimal.

```
	arg_type<LongDecimal> = int128_t
	out_type<LongDecimal> = int128_t

	arg_type<ShortDecimal> = int64_t
	out_type<ShortDecimal> = int64_t
```

**functions/Registerer.h**

Add an optional ‘constraints’ parameter to the registerFunction template. It specifies the rules for calculating precision and scale for decimal return types.

```
template <template <class> typename Func, typename TReturn, typename... TArgs>
void registerFunction(
    const std::vector<std::string>& aliases = {},
    const std::vector<exec::SignatureVariable>& constraints = {})
```

	Here is how we can specify calculation of precision and scale for the return type of plus(decimal, decimal).

```
  std::vector<exec::SignatureVariable> constraints = {
      exec::SignatureVariable(
          P3::name(),
          fmt::format(
              "min(38, max({a_precision} - {a_scale}, {b_precision} - {b_scale}) + max({a_scale}, {b_scale}) + 1)",
              fmt::arg("a_precision", P1::name()),
              fmt::arg("b_precision", P2::name()),
              fmt::arg("a_scale", S1::name()),
              fmt::arg("b_scale", S2::name())),
          exec::ParameterType::kIntegerParameter),
      exec::SignatureVariable(
          S3::name(),
          fmt::format(
              "max({a_scale}, {b_scale})",
              fmt::arg("a_scale", S1::name()),
              fmt::arg("b_scale", S2::name())),
          exec::ParameterType::kIntegerParameter),
  };
```
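To make the two constraint expressions concrete, here is a minimal standalone sketch that evaluates them for given input precisions and scales. The helper name `plusResultPrecisionScale` is hypothetical and not part of Velox; only the formulas are taken from the constraints above.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Hypothetical helper (not a Velox API): evaluates the constraint
// expressions for plus(decimal(p1, s1), decimal(p2, s2)).
// Returns {p3, s3}.
inline std::pair<int, int> plusResultPrecisionScale(
    int aPrecision,
    int aScale,
    int bPrecision,
    int bScale) {
  // s3 = max(s1, s2)
  const int rScale = std::max(aScale, bScale);
  // p3 = min(38, max(p1 - s1, p2 - s2) + max(s1, s2) + 1)
  const int rPrecision = std::min(
      38, std::max(aPrecision - aScale, bPrecision - bScale) + rScale + 1);
  return {rPrecision, rScale};
}
```

For example, decimal(10, 2) + decimal(12, 4) yields decimal(13, 4): both inputs have 8 integer digits, the larger scale is 4, and one extra digit is reserved for the carry.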

**core/SimpleFunctionMetadata.h**

Extend SimpleFunctionMetadata to store physical types (TypeKind) of input arguments and return type in addition to signature.

The decimal “plus” function has a single signature:

```
(decimal(p1, s1), decimal(p2, s2)) -> decimal(p3, s3)
```

However, it has 5 implementations:

```
	(int64_t, int64_t) -> int64_t
	(int64_t, int64_t) -> int128_t
	(int64_t, int128_t) -> int128_t
	(int128_t, int64_t) -> int128_t
	(int128_t, int128_t) -> int128_t
```

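The registry needs a way to distinguish between these, which is why the metadata now records the TypeKind of each argument and of the return type.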
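The following sketch illustrates why one logical signature maps to several physical implementations. It assumes that decimals with precision up to 18 are stored as int64_t and wider ones (up to 38) as int128_t; the helper names are hypothetical and only the precision formula comes from this commit.

```cpp
#include <algorithm>
#include <cassert>
#include <string>

// Assumption for this sketch: precision <= 18 fits in int64_t (short
// decimal); anything wider, up to 38, needs int128_t (long decimal).
constexpr int kMaxShortPrecision = 18;

inline std::string physicalTypeName(int precision) {
  return precision <= kMaxShortPrecision ? "int64_t" : "int128_t";
}

// Hypothetical helper: describes which physical implementation a call to
// plus(decimal(p1, s1), decimal(p2, s2)) would need.
inline std::string plusImplementationFor(int p1, int s1, int p2, int s2) {
  // Result precision, as defined by the constraint for P3.
  const int p3 =
      std::min(38, std::max(p1 - s1, p2 - s2) + std::max(s1, s2) + 1);
  return "(" + physicalTypeName(p1) + ", " + physicalTypeName(p2) + ") -> " +
      physicalTypeName(p3);
}
```

Because the result precision is never smaller than either input precision, a long input always produces a long result. That is why only 5 of the 8 conceivable combinations appear in the list.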

**expression/SimpleFunctionRegistry.h/cpp**

Allow for storing multiple implementations for a single signature.

```
using SignatureMap = std::unordered_map<
    FunctionSignature,
    std::vector<std::unique_ptr<const FunctionEntry>>>;
using FunctionMap = std::unordered_map<std::string, SignatureMap>;
```

Modify SimpleFunctionRegistry::resolveFunction method to find an implementation with matching signature and matching TypeKinds for arguments and return type.
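The lookup can be pictured with a simplified, self-contained model. Everything here is a stand-in: `Kind`, `Entry`, `resolve`, and `decimalPlusEntries` are hypothetical names, the signature is keyed by a plain string, and priorities and generic (unknown-kind) arguments are left out.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for the physical type kinds of short and long decimals.
enum class Kind { kBigint, kHugeint };

// Stand-in for a registered implementation.
struct Entry {
  std::vector<Kind> argKinds;
  Kind resultKind;
  std::string label;
};

// One signature now maps to several implementations.
using SignatureMap = std::unordered_map<std::string, std::vector<Entry>>;

const std::string kPlusSignature =
    "(decimal(p1, s1), decimal(p2, s2)) -> decimal(p3, s3)";

inline SignatureMap decimalPlusEntries() {
  SignatureMap map;
  auto& entries = map[kPlusSignature];
  entries.push_back(
      {{Kind::kBigint, Kind::kBigint}, Kind::kBigint, "(short, short) -> short"});
  entries.push_back(
      {{Kind::kBigint, Kind::kBigint}, Kind::kHugeint, "(short, short) -> long"});
  entries.push_back(
      {{Kind::kBigint, Kind::kHugeint}, Kind::kHugeint, "(short, long) -> long"});
  entries.push_back(
      {{Kind::kHugeint, Kind::kBigint}, Kind::kHugeint, "(long, short) -> long"});
  entries.push_back(
      {{Kind::kHugeint, Kind::kHugeint}, Kind::kHugeint, "(long, long) -> long"});
  return map;
}

// Returns the first entry whose argument and result kinds match, or nullptr.
inline const Entry* resolve(
    const SignatureMap& map,
    const std::string& signature,
    const std::vector<Kind>& argKinds,
    Kind resultKind) {
  const auto it = map.find(signature);
  if (it == map.end()) {
    return nullptr;
  }
  for (const auto& candidate : it->second) {
    if (candidate.argKinds == argKinds && candidate.resultKind == resultKind) {
      return &candidate;
    }
  }
  return nullptr;
}
```

In the actual change the signature is first bound against the argument types, and the kind check only filters the candidates stored under the bound signature.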

**core/SimpleFunctionMetadata.h**

Introduce optional initializeTypes method for a function to receive input types. Functions that operate on decimal types use this method to get access to precision and scale of the arguments.

```
void initializeTypes(const std::vector<TypePtr>& argTypes)
```
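Since the method is optional, the framework has to detect at compile time whether a function defines it. The sketch below shows one common way to do that in C++17, using `std::void_t` and `if constexpr`. It is an illustration only: `HasInitializeTypes`, `Holder`, `TypeName`, and the two sample structs are hypothetical, and Velox uses its own method-resolver macros for this.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Stand-in for TypePtr in this sketch.
using TypeName = std::string;

// Primary template: assume the method is absent.
template <typename T, typename = void>
struct HasInitializeTypes : std::false_type {};

// Chosen when T has initializeTypes(const std::vector<TypeName>&).
template <typename T>
struct HasInitializeTypes<
    T,
    std::void_t<decltype(std::declval<T&>().initializeTypes(
        std::declval<const std::vector<TypeName>&>()))>> : std::true_type {};

// Wraps a function object and forwards the input types only if the
// function asked for them.
template <typename Fn>
struct Holder {
  explicit Holder(const std::vector<TypeName>& argTypes) {
    if constexpr (HasInitializeTypes<Fn>::value) {
      instance.initializeTypes(argTypes);
    }
  }

  Fn instance{};
};

// A function that wants to see its input types.
struct WithTypes {
  std::size_t numArgs = 0;

  void initializeTypes(const std::vector<TypeName>& argTypes) {
    numArgs = argTypes.size();
  }
};

// A function that does not define the optional method.
struct WithoutTypes {};
```

Functions that do not define the method compile unchanged, because the discarded `if constexpr` branch is never instantiated for them.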

**Example: Decimal ADD**

Here is how a function that adds 2 decimal numbers can be defined. This function supports adding decimal numbers with possibly different precision and scale.

```
template <typename TExec>
struct DecimalAddFunction {
  VELOX_DEFINE_FUNCTION_TYPES(TExec);

  void initializeTypes(const std::vector<TypePtr>& argTypes) {
    auto aType = argTypes[0];
    auto bType = argTypes[1];
    auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
    auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
    auto [rPrecision, rScale] = Addition::computeResultPrecisionScale(
        aPrecision, aScale, bPrecision, bScale);
    aRescale_ = Addition::computeRescaleFactor(aScale, bScale, rScale);
    bRescale_ = Addition::computeRescaleFactor(bScale, aScale, rScale);
  }

  template <typename R, typename A, typename B>
  void call(R& out, const A& a, const B& b) {
    Addition::template apply<R, A, B>(out, a, b, aRescale_, bRescale_);
  }

 private:
  uint8_t aRescale_;
  uint8_t bRescale_;
};
```
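The `Addition` helpers are not shown in this summary. As a rough sketch of the idea, assume the rescale factor is the number of decimal digits by which an unscaled value must be shifted to reach the result scale max(aScale, bScale). All names below are hypothetical, and the sketch uses int64_t only and performs no overflow checks.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Assumption: the number of digits needed to bring a value with scale
// `fromScale` up to the larger of the two input scales.
inline uint8_t rescaleFactor(uint8_t fromScale, uint8_t toScale) {
  return static_cast<uint8_t>(std::max(0, toScale - fromScale));
}

inline int64_t powerOfTen(uint8_t exponent) {
  int64_t result = 1;
  for (uint8_t i = 0; i < exponent; ++i) {
    result *= 10;
  }
  return result;
}

// Adds two unscaled decimal values. The result has scale
// max(aScale, bScale). No overflow checks, for brevity.
inline int64_t addUnscaled(
    int64_t a,
    uint8_t aScale,
    int64_t b,
    uint8_t bScale) {
  const uint8_t aRescale = rescaleFactor(aScale, bScale);
  const uint8_t bRescale = rescaleFactor(bScale, aScale);
  return a * powerOfTen(aRescale) + b * powerOfTen(bRescale);
}
```

For example, 1.23 (unscaled 123, scale 2) plus 4.5678 (unscaled 45678, scale 4) becomes 12300 + 45678 = 57978 at scale 4, that is 5.7978. Computing the rescale factors once in `initializeTypes` keeps this work out of the per-row `call`.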

Registration specifies the rules for calculating the precision and scale of the result from the precision and scale of the inputs. It then registers 5 implementations that cover the valid combinations of short and long decimals for the inputs and the result.

```
  std::vector<exec::SignatureVariable> constraints = {
      exec::SignatureVariable(
          P3::name(),
          fmt::format(
              "min(38, max({a_precision} - {a_scale}, {b_precision} - {b_scale}) + max({a_scale}, {b_scale}) + 1)",
              fmt::arg("a_precision", P1::name()),
              fmt::arg("b_precision", P2::name()),
              fmt::arg("a_scale", S1::name()),
              fmt::arg("b_scale", S2::name())),
          exec::ParameterType::kIntegerParameter),
      exec::SignatureVariable(
          S3::name(),
          fmt::format(
              "max({a_scale}, {b_scale})",
              fmt::arg("a_scale", S1::name()),
              fmt::arg("b_scale", S2::name())),
          exec::ParameterType::kIntegerParameter),
  };

  // (long, long) -> long
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      LongDecimal<P1, S1>,
      LongDecimal<P2, S2>>({"plus"}, constraints);

  // (short, short) -> short
  registerFunction<
      DecimalAddFunction,
      ShortDecimal<P3, S3>,
      ShortDecimal<P1, S1>,
      ShortDecimal<P2, S2>>({"plus"}, constraints);

  // (short, short) -> long
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      ShortDecimal<P1, S1>,
      ShortDecimal<P2, S2>>({"plus"}, constraints);

  // (short, long) -> long
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      ShortDecimal<P1, S1>,
      LongDecimal<P2, S2>>({"plus"}, constraints);

  // (long, short) -> long
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      LongDecimal<P1, S1>,
      ShortDecimal<P2, S2>>({"plus"}, constraints);
```


Differential Revision: D54953663
mbasmanova authored and facebook-github-bot committed Mar 18, 2024
1 parent 901662f commit 82dc279
Showing 12 changed files with 463 additions and 30 deletions.
95 changes: 90 additions & 5 deletions velox/core/SimpleFunctionMetadata.h
@@ -242,6 +242,36 @@ struct TypeAnalysis<Generic<T, comparable, orderable>> {
}
};

template <typename P, typename S>
struct TypeAnalysis<ShortDecimal<P, S>> {
void run(TypeAnalysisResults& results) {
results.stats.concreteCount++;

const auto p = P::name();
const auto s = S::name();
results.out << fmt::format("decimal({},{})", p, s);
results.addVariable(exec::SignatureVariable(
p, std::nullopt, exec::ParameterType::kIntegerParameter));
results.addVariable(exec::SignatureVariable(
s, std::nullopt, exec::ParameterType::kIntegerParameter));
}
};

template <typename P, typename S>
struct TypeAnalysis<LongDecimal<P, S>> {
void run(TypeAnalysisResults& results) {
results.stats.concreteCount++;

const auto p = P::name();
const auto s = S::name();
results.out << fmt::format("decimal({},{})", p, s);
results.addVariable(exec::SignatureVariable(
p, std::nullopt, exec::ParameterType::kIntegerParameter));
results.addVariable(exec::SignatureVariable(
s, std::nullopt, exec::ParameterType::kIntegerParameter));
}
};

template <typename K, typename V>
struct TypeAnalysis<Map<K, V>> {
void run(TypeAnalysisResults& results) {
@@ -329,6 +359,8 @@ class ISimpleFunctionMetadata {
virtual bool isDeterministic() const = 0;
virtual uint32_t priority() const = 0;
virtual const std::shared_ptr<exec::FunctionSignature> signature() const = 0;
virtual TypeKind resultTypeKind() const = 0;
virtual const std::vector<TypeKind>& argTypeKinds() const = 0;
virtual std::string helpMessage(const std::string& name) const = 0;
virtual ~ISimpleFunctionMetadata() = default;
};
@@ -407,10 +439,14 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
}
}

explicit SimpleFunctionMetadata() {
auto analysis = analyzeSignatureTypes();
explicit SimpleFunctionMetadata(
const std::vector<exec::SignatureVariable>& constraints) {
auto analysis = analyzeSignatureTypes(constraints);

buildSignature(analysis);
priority_ = analysis.stats.computePriority();
resultTypeKind_ = analysis.resultTypeKind;
argTypeKinds_ = analysis.argTypeKinds;
}

~SimpleFunctionMetadata() override = default;
@@ -419,6 +455,14 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
return signature_;
}

TypeKind resultTypeKind() const override {
return resultTypeKind_;
}

const std::vector<TypeKind>& argTypeKinds() const override {
return argTypeKinds_;
}

std::string helpMessage(const std::string& name) const final {
// return fmt::format("{}({})", name, signature_->toString());
std::string s{name};
@@ -446,30 +490,52 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
std::string outputType;
std::map<std::string, exec::SignatureVariable> variables;
TypeAnalysisResults::Stats stats;
TypeKind resultTypeKind;
std::vector<TypeKind> argTypeKinds;
};

SignatureTypesAnalysisResults analyzeSignatureTypes() {
SignatureTypesAnalysisResults analyzeSignatureTypes(
const std::vector<exec::SignatureVariable>& constraints) {
std::vector<std::string> argsTypes;

TypeAnalysisResults results;
TypeAnalysis<return_type>().run(results);
std::string outputType = results.typeAsString();

auto resultTypeKind = SimpleTypeTrait<return_type>::typeKind;
std::vector<TypeKind> argTypeKinds;

(
[&]() {
// Clear string representation but keep other collected information
// to accumulate.
results.resetTypeString();
TypeAnalysis<Args>().run(results);
argsTypes.push_back(results.typeAsString());

if constexpr (!isVariadicType<Args>::value) {
argTypeKinds.push_back(SimpleTypeTrait<Args>::typeKind);
}
}(),
...);

for (const auto& constraint : constraints) {
VELOX_CHECK(
!constraint.constraint().empty(),
"Constraint must be set for variable {}",
constraint.name());

results.variablesInformation.erase(constraint.name());
results.variablesInformation.emplace(constraint.name(), constraint);
}

return SignatureTypesAnalysisResults{
std::move(argsTypes),
std::move(outputType),
std::move(results.variablesInformation),
std::move(results.stats)};
std::move(results.stats),
resultTypeKind,
argTypeKinds};
}

void buildSignature(const SignatureTypesAnalysisResults& analysis) {
@@ -497,6 +563,8 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {

exec::FunctionSignaturePtr signature_;
uint32_t priority_;
TypeKind resultTypeKind_;
std::vector<TypeKind> argTypeKinds_;
};

// wraps a UDF object to provide the inheritance
@@ -544,6 +612,7 @@ class UDFHolder final
DECLARE_METHOD_RESOLVER(callNullFree_method_resolver, callNullFree);
DECLARE_METHOD_RESOLVER(callAscii_method_resolver, callAscii);
DECLARE_METHOD_RESOLVER(initialize_method_resolver, initialize);
DECLARE_METHOD_RESOLVER(initializeTypes_method_resolver, initializeTypes);

// Check which flavor of the call() method is provided by the UDF object. UDFs
// are required to provide at least one of the following methods:
@@ -650,6 +719,13 @@ class UDFHolder final
const core::QueryConfig&,
const exec_arg_type<TArgs>*...>::value;

// initializeTypes():
static constexpr bool udf_has_initializeTypes = util::has_method<
Fun,
initializeTypes_method_resolver,
void,
const std::vector<TypePtr>&>::value;

static_assert(
udf_has_call || udf_has_callNullable || udf_has_callNullFree,
"UDF must implement at least one of `call`, `callNullable`, or `callNullFree` functions.\n"
@@ -703,7 +779,9 @@ class UDFHolder final
template <size_t N>
using exec_type_at = typename std::tuple_element<N, exec_arg_types>::type;

explicit UDFHolder() : Metadata(), instance_{} {}
explicit UDFHolder(
const std::vector<exec::SignatureVariable>& constraints = {})
: Metadata(constraints), instance_{} {}

FOLLY_ALWAYS_INLINE void initialize(
const core::QueryConfig& config,
@@ -713,6 +791,13 @@ class UDFHolder final
}
}

FOLLY_ALWAYS_INLINE void initializeTypes(
const std::vector<TypePtr>& argTypes) {
if constexpr (udf_has_initializeTypes) {
return instance_.initializeTypes(argTypes);
}
}

FOLLY_ALWAYS_INLINE bool call(
exec_return_type& out,
const typename exec_resolver<TArgs>::in_type&... args) {
2 changes: 1 addition & 1 deletion velox/expression/ExprCompiler.cpp
@@ -435,7 +435,7 @@ ExprPtr compileRewrittenExpression(
resultType,
folly::join(", ", inputTypes));
auto func_2 = simpleFunctionEntry->createFunction()->createVectorFunction(
getConstantInputs(compiledInputs), config);
inputTypes, getConstantInputs(compiledInputs), config);
result = std::make_shared<Expr>(
resultType,
std::move(compiledInputs),
10 changes: 8 additions & 2 deletions velox/expression/SimpleFunctionAdapter.h
@@ -228,9 +228,10 @@ class SimpleFunctionAdapter : public VectorFunction {

public:
explicit SimpleFunctionAdapter(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const std::vector<VectorPtr>& constantInputs)
: fn_{std::make_unique<FUNC>()} {
: fn_{std::make_unique<FUNC>(std::vector<exec::SignatureVariable>{})} {
if constexpr (FUNC::udf_has_initialize) {
try {
unpackInitialize<0>(config, constantInputs);
@@ -240,6 +241,10 @@ class SimpleFunctionAdapter : public VectorFunction {
initializeException_ = std::current_exception();
}
}

if constexpr (FUNC::udf_has_initializeTypes) {
(*fn_).initializeTypes(inputTypes);
}
}

explicit SimpleFunctionAdapter() {}
@@ -901,10 +906,11 @@ class SimpleFunctionAdapterFactoryImpl : public SimpleFunctionAdapterFactory {
explicit SimpleFunctionAdapterFactoryImpl() {}

std::unique_ptr<VectorFunction> createVectorFunction(
const std::vector<TypePtr>& inputTypes,
const std::vector<VectorPtr>& constantInputs,
const core::QueryConfig& config) const override {
return std::make_unique<SimpleFunctionAdapter<UDFHolder>>(
config, constantInputs);
inputTypes, config, constantInputs);
}
};

53 changes: 45 additions & 8 deletions velox/expression/SimpleFunctionRegistry.cpp
@@ -40,8 +40,9 @@ void SimpleFunctionRegistry::registerFunctionInternal(
const auto sanitizedName = sanitizeName(name);
registeredFunctions_.withWLock([&](auto& map) {
SignatureMap& signatureMap = map[sanitizedName];
signatureMap[*metadata->signature()] =
std::make_unique<const FunctionEntry>(metadata, factory);
// TODO Check type kinds and avoid adding duplicates.
signatureMap[*metadata->signature()].emplace_back(
std::make_unique<const FunctionEntry>(metadata, factory));
});
}

@@ -83,12 +84,48 @@ SimpleFunctionRegistry::resolveFunction(
for (const auto& [candidateSignature, functionEntry] : *signatureMap) {
SignatureBinder binder(candidateSignature, argTypes);
if (binder.tryBind()) {
auto* currentCandidate = functionEntry.get();
if (!selectedCandidate ||
currentCandidate->getMetadata().priority() <
selectedCandidate->getMetadata().priority()) {
selectedCandidate = currentCandidate;
selectedCandidateType = binder.tryResolveReturnType();
for (const auto& currentCandidate : functionEntry) {
const auto& m = currentCandidate->getMetadata();

// Check that TypeKinds of arguments match.
bool match = true;
for (auto i = 0; i < m.argTypeKinds().size(); ++i) {
const auto typeKind = m.argTypeKinds()[i];

// Generic types do not specify TypeKind. Skip the check.
if (typeKind == TypeKind::UNKNOWN) {
continue;
}

if (argTypes[i]->kind() != typeKind) {
LOG(ERROR) << "Type mismatch for argument " << i << " of "
<< name << ": " << argTypes[i]->toString() << " vs "
<< mapTypeKindToName(typeKind);
match = false;
}
}

if (!match) {
continue;
}

if (!selectedCandidate ||
currentCandidate->getMetadata().priority() <
selectedCandidate->getMetadata().priority()) {
auto resultType = binder.tryResolveReturnType();
VELOX_CHECK_NOT_NULL(resultType);

if (m.resultTypeKind() != TypeKind::UNKNOWN &&
resultType->kind() != m.resultTypeKind()) {
LOG(ERROR) << "Type mismatch for result of " << name << ": "
<< resultType->toString() << " vs "
<< mapTypeKindToName(m.resultTypeKind());
continue;
}

selectedCandidate = currentCandidate.get();
selectedCandidateType = resultType;
}
}
}
}
24 changes: 16 additions & 8 deletions velox/expression/SimpleFunctionRegistry.h
@@ -24,8 +24,9 @@
namespace facebook::velox::exec {

template <typename T>
const std::shared_ptr<const T>& singletonUdfMetadata() {
static auto instance = std::make_shared<const T>();
const std::shared_ptr<const T>& singletonUdfMetadata(
const std::vector<exec::SignatureVariable>& constraints = {}) {
static auto instance = std::make_shared<const T>(constraints);
return instance;
}

@@ -52,15 +53,19 @@ struct FunctionEntry {
const FunctionFactory factory_;
};

using SignatureMap =
std::unordered_map<FunctionSignature, std::unique_ptr<const FunctionEntry>>;
using SignatureMap = std::unordered_map<
FunctionSignature,
std::vector<std::unique_ptr<const FunctionEntry>>>;
using FunctionMap = std::unordered_map<std::string, SignatureMap>;

class SimpleFunctionRegistry {
public:
template <typename UDF>
void registerFunction(const std::vector<std::string>& aliases = {}) {
const auto& metadata = singletonUdfMetadata<typename UDF::Metadata>();
void registerFunction(
const std::vector<std::string>& aliases,
const std::vector<exec::SignatureVariable>& constraints) {
const auto& metadata =
singletonUdfMetadata<typename UDF::Metadata>(constraints);
const auto factory = [metadata]() { return CreateUdf<UDF>(); };

if (aliases.empty()) {
@@ -139,9 +144,12 @@ SimpleFunctionRegistry& mutableSimpleFunctions();

// This function should be called once and alone.
template <typename UDFHolder>
void registerSimpleFunction(const std::vector<std::string>& names) {
void registerSimpleFunction(
const std::vector<std::string>& names,
const std::vector<exec::SignatureVariable>& constraints) {
mutableSimpleFunctions()
.registerFunction<SimpleFunctionAdapterFactoryImpl<UDFHolder>>(names);
.registerFunction<SimpleFunctionAdapterFactoryImpl<UDFHolder>>(
names, constraints);
}

} // namespace facebook::velox::exec
14 changes: 14 additions & 0 deletions velox/expression/UdfTypeResolver.h
@@ -101,6 +101,20 @@ struct resolver<Array<V>> {
using out_type = ArrayWriter<V>;
};

template <typename P, typename S>
struct resolver<ShortDecimal<P, S>> {
using in_type = int64_t;
using null_free_in_type = in_type;
using out_type = int64_t;
};

template <typename P, typename S>
struct resolver<LongDecimal<P, S>> {
using in_type = int128_t;
using null_free_in_type = in_type;
using out_type = int128_t;
};

template <>
struct resolver<Varchar> {
using in_type = StringView;
1 change: 1 addition & 0 deletions velox/expression/VectorFunction.h
@@ -182,6 +182,7 @@ class ApplyNeverCalled final : public VectorFunction {
class SimpleFunctionAdapterFactory {
public:
virtual std::unique_ptr<VectorFunction> createVectorFunction(
const std::vector<TypePtr>& inputTypes,
const std::vector<VectorPtr>& constantInputs,
const core::QueryConfig& config) const = 0;
virtual ~SimpleFunctionAdapterFactory() = default;