Skip to content

Commit

Permalink
Add support for DECIMAL types to Simple Function API (facebookincubat…
Browse files Browse the repository at this point in the history
…or#9096)

Summary:
Use the new functionality to re-write decimal plus, minus, multiply, divide, between, negate, floor and round.

**type/Type.h**

	Add P1, P2, P3, P4, S1, S2, S3, S4 types to specify precision and scale parameters for decimal types during function registration.

	Add LongDecimal<P, S> and ShortDecimal<P, S> templates to specify decimal argument and return types during function registration.

```
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      LongDecimal<P1, S1>,
      LongDecimal<P2, S2>>({"plus"}, constraints);
```

**expression/UdfTypeResolver.h**

Define arg_type and out_type for LongDecimal and ShortDecimal.

```
	arg_type<LongDecimal> = int128_t
	out_type<LongDecimal> = int128_t

	arg_type<ShortDecimal> = int64_t
	out_type<ShortDecimal> = int64_t
```

**functions/Registerer.h**

	Add optional ‘constraints’ parameter to registerFunction template. This allows specifying rules for calculating precision and scale for decimal return types.

```
template <template <class> typename Func, typename TReturn, typename... TArgs>
void registerFunction(
    const std::vector<std::string>& aliases = {},
    const std::vector<exec::SignatureVariable>& constraints = {})
```

	Here is how we can specify calculation of precision and scale for the return type of plus(decimal, decimal).

```
  std::vector<exec::SignatureVariable> constraints = {
      exec::SignatureVariable(
          P3::name(),
          fmt::format(
              "min(38, max({a_precision} - {a_scale}, {b_precision} - {b_scale}) + max({a_scale}, {b_scale}) + 1)",
              fmt::arg("a_precision", P1::name()),
              fmt::arg("b_precision", P2::name()),
              fmt::arg("a_scale", S1::name()),
              fmt::arg("b_scale", S2::name())),
          exec::ParameterType::kIntegerParameter),
      exec::SignatureVariable(
          S3::name(),
          fmt::format(
              "max({a_scale}, {b_scale})",
              fmt::arg("a_scale", S1::name()),
              fmt::arg("b_scale", S2::name())),
          exec::ParameterType::kIntegerParameter),
  };
```

**core/SimpleFunctionMetadata.h**

Extend SimpleFunctionMetadata to store physical types (TypeKind) of input arguments and return type in addition to signature.

Decimal “plus” function has a single signature:

```
(decimal(p1, s1), decimal(p2, s2)) -> decimal(p3, s3)
```

But 5 implementations:

```
	(int64_t, int64_t) -> int64_t
	(int64_t, int64_t) -> int128_t
	(int64_t, int128_t) -> int128_t
	(int128_t, int64_t) -> int128_t
	(int128_t, int128_t) -> int128_t
```

We need a way to distinguish between these.

**expression/SimpleFunctionRegistry.h/cpp**

Allow for storing multiple implementations for a single signature.

```
using SignatureMap = std::unordered_map<
    FunctionSignature,
    std::vector<std::unique_ptr<const FunctionEntry>>>;
using FunctionMap = std::unordered_map<std::string, SignatureMap>;
```

Modify SimpleFunctionRegistry::resolveFunction method to find an implementation with matching signature and matching TypeKinds for arguments and return type.

**core/SimpleFunctionMetadata.h**

Add 'inputTypes' parameter to 'initialize' method. Functions that operate on decimal types use this parameter to get access to precision and scale of the arguments. Landed separately.

```
  void initialize(
      const std::vector<TypePtr>& inputTypes,
      const core::QueryConfig& config,
      ...)
```

**Example: Decimal Plus**

Here is how a function that adds 2 decimal numbers can be defined. This function supports adding decimal numbers with possibly different precision and scale.

```
/// Simple function that adds two decimal values whose precision and scale may
/// differ. Rescale factors are derived once from the argument types in
/// initialize() and then reused by every call().
template <typename TExec>
struct DecimalPlusFunction {
  VELOX_DEFINE_FUNCTION_TYPES(TExec);

  /// Derives the rescale factors for both arguments.
  ///
  /// @param inputTypes Types of the two arguments; the precision and scale of
  /// each are read via getDecimalPrecisionScale.
  template <typename A, typename B>
  void initialize(
      const std::vector<TypePtr>& inputTypes,
      const core::QueryConfig& /*config*/,
      A* /*a*/,
      B* /*b*/) {
    // Bind by reference: the types are only dereferenced here, so there is no
    // need to copy the TypePtr handles.
    const auto& aType = inputTypes[0];
    const auto& bType = inputTypes[1];
    auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
    auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
    auto [rPrecision, rScale] = Addition::computeResultPrecisionScale(
        aPrecision, aScale, bPrecision, bScale);
    aRescale_ = Addition::computeRescaleFactor(aScale, bScale, rScale);
    bRescale_ = Addition::computeRescaleFactor(bScale, aScale, rScale);
  }

  /// Writes the sum of 'a' and 'b' to 'out' by delegating to Addition::apply
  /// with the rescale factors computed in initialize().
  template <typename R, typename A, typename B>
  void call(R& out, const A& a, const B& b) {
    Addition::template apply<R, A, B>(out, a, b, aRescale_, bRescale_);
  }

 private:
  // Default-initialized so the members never hold indeterminate values, even
  // if call() were reached before initialize().
  uint8_t aRescale_{0};
  uint8_t bRescale_{0};
};
```

The registration specifies rules for calculating the precision and scale of the result from the precision and scale of the inputs. It then registers 5 implementations covering the supported combinations of short and long decimals in the inputs and the result.

```
  std::vector<exec::SignatureVariable> constraints = {
      exec::SignatureVariable(
          P3::name(),
          fmt::format(
              "min(38, max({a_precision} - {a_scale}, {b_precision} - {b_scale}) + max({a_scale}, {b_scale}) + 1)",
              fmt::arg("a_precision", P1::name()),
              fmt::arg("b_precision", P2::name()),
              fmt::arg("a_scale", S1::name()),
              fmt::arg("b_scale", S2::name())),
          exec::ParameterType::kIntegerParameter),
      exec::SignatureVariable(
          S3::name(),
          fmt::format(
              "max({a_scale}, {b_scale})",
              fmt::arg("a_scale", S1::name()),
              fmt::arg("b_scale", S2::name())),
          exec::ParameterType::kIntegerParameter),
  };

  // (long, long) -> long
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      LongDecimal<P1, S1>,
      LongDecimal<P2, S2>>({"plus"}, constraints);

  // (short, short) -> short
  registerFunction<
      DecimalAddFunction,
      ShortDecimal<P3, S3>,
      ShortDecimal<P1, S1>,
      ShortDecimal<P2, S2>>({"plus"}, constraints);

  // (short, short) -> long
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      ShortDecimal<P1, S1>,
      ShortDecimal<P2, S2>>({"plus"}, constraints);

  // (short, long) -> long
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      ShortDecimal<P1, S1>,
      LongDecimal<P2, S2>>({"plus"}, constraints);

  // (long, short) -> long
  registerFunction<
      DecimalAddFunction,
      LongDecimal<P3, S3>,
      LongDecimal<P1, S1>,
      ShortDecimal<P2, S2>>({"plus"}, constraints);
```

Pull Request resolved: facebookincubator#9096

Reviewed By: xiaoxmeng

Differential Revision: D54953663
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Mar 19, 2024
1 parent c571bcd commit 7e68798
Show file tree
Hide file tree
Showing 17 changed files with 1,135 additions and 875 deletions.
126 changes: 119 additions & 7 deletions velox/core/SimpleFunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,14 @@ struct TypeAnalysisResults {
}
}

// String representaion of the type in the FunctionSignatureBuilder.
/// String representation of the type in the FunctionSignatureBuilder.
std::ostringstream out;

// Set of generic variables used in the type.
/// Physical type, e.g. BIGINT() for Date and ARRAY(BIGINT()) for
/// Array<Date>. UNKNOWN() if type is generic or opaque.
TypePtr physicalType;

/// Set of generic variables used in the type.
std::map<std::string, exec::SignatureVariable> variablesInformation;

std::string typeAsString() {
Expand Down Expand Up @@ -219,6 +223,13 @@ struct TypeAnalysis {
results.stats.concreteCount++;
results.out << detail::strToLowerCopy(
std::string(SimpleTypeTrait<T>::name));
if constexpr (
SimpleTypeTrait<T>::typeKind == TypeKind::OPAQUE ||
SimpleTypeTrait<T>::typeKind == TypeKind::UNKNOWN) {
results.physicalType = UNKNOWN();
} else {
results.physicalType = createScalarType(SimpleTypeTrait<T>::typeKind);
}
}
};

Expand All @@ -239,6 +250,39 @@ struct TypeAnalysis<Generic<T, comparable, orderable>> {
comparable));
}
results.stats.hasGeneric = true;
results.physicalType = UNKNOWN();
}
};

/// Type analysis for ShortDecimal<P, S>: appends "decimal(<P>,<S>)" to the
/// signature string, adds P and S as integer signature variables without a
/// constraint, and sets the physical type to BIGINT().
template <typename P, typename S>
struct TypeAnalysis<ShortDecimal<P, S>> {
  void run(TypeAnalysisResults& results) {
    ++results.stats.concreteCount;

    const auto precisionName = P::name();
    const auto scaleName = S::name();
    results.out << fmt::format("decimal({},{})", precisionName, scaleName);

    // Precision first, then scale; neither carries a constraint here.
    exec::SignatureVariable precisionVariable(
        precisionName, std::nullopt, exec::ParameterType::kIntegerParameter);
    results.addVariable(std::move(precisionVariable));

    exec::SignatureVariable scaleVariable(
        scaleName, std::nullopt, exec::ParameterType::kIntegerParameter);
    results.addVariable(std::move(scaleVariable));

    results.physicalType = BIGINT();
  }
};

/// Type analysis for LongDecimal<P, S>: appends "decimal(<P>,<S>)" to the
/// signature string, adds P and S as integer signature variables without a
/// constraint, and sets the physical type to HUGEINT().
template <typename P, typename S>
struct TypeAnalysis<LongDecimal<P, S>> {
  void run(TypeAnalysisResults& results) {
    ++results.stats.concreteCount;

    const auto precisionName = P::name();
    const auto scaleName = S::name();
    results.out << fmt::format("decimal({},{})", precisionName, scaleName);

    // Precision first, then scale; neither carries a constraint here.
    exec::SignatureVariable precisionVariable(
        precisionName, std::nullopt, exec::ParameterType::kIntegerParameter);
    results.addVariable(std::move(precisionVariable));

    exec::SignatureVariable scaleVariable(
        scaleName, std::nullopt, exec::ParameterType::kIntegerParameter);
    results.addVariable(std::move(scaleVariable));

    results.physicalType = HUGEINT();
  }
};

Expand All @@ -248,9 +292,12 @@ struct TypeAnalysis<Map<K, V>> {
results.stats.concreteCount++;
results.out << "map(";
TypeAnalysis<K>().run(results);
auto keyType = results.physicalType;
results.out << ", ";
TypeAnalysis<V>().run(results);
auto valueType = results.physicalType;
results.out << ")";
results.physicalType = MAP(keyType, valueType);
}
};

Expand All @@ -273,6 +320,7 @@ struct TypeAnalysis<Variadic<V>> {
results.addVariable(std::move(variable));
}
results.out << tmp.typeAsString();
results.physicalType = tmp.physicalType;
}
};

Expand All @@ -283,6 +331,7 @@ struct TypeAnalysis<Array<V>> {
results.out << "array(";
TypeAnalysis<V>().run(results);
results.out << ")";
results.physicalType = ARRAY(results.physicalType);
}
};

Expand All @@ -296,6 +345,7 @@ struct TypeAnalysis<Row<T...>> {
void run(TypeAnalysisResults& results) {
results.stats.concreteCount++;
results.out << "row(";
std::vector<TypePtr> fieldTypes;
// This expression applies the lambda for each row child type.
bool first = true;
(
Expand All @@ -305,9 +355,11 @@ struct TypeAnalysis<Row<T...>> {
}
first = false;
TypeAnalysis<T>().run(results);
fieldTypes.push_back(results.physicalType);
}(),
...);
results.out << ")";
results.physicalType = ROW(std::move(fieldTypes));
}
};

Expand All @@ -316,20 +368,29 @@ struct TypeAnalysis<CustomType<T>> {
/// Appends the custom type's name to the signature string and takes the
/// physical type from an analysis of the underlying T::type.
void run(TypeAnalysisResults& results) {
  // Analyze the underlying type into a scratch object: only its physical
  // type is copied into 'results'.
  TypeAnalysisResults underlying;
  TypeAnalysis<typename T::type> underlyingAnalysis;
  underlyingAnalysis.run(underlying);

  ++results.stats.concreteCount;
  results.out << T::typeName;
  results.physicalType = underlying.physicalType;
}
};

class ISimpleFunctionMetadata {
public:
virtual ~ISimpleFunctionMetadata() = default;

// Returns the return type of the function if it is independent of the input
// types; otherwise returns null.
virtual TypePtr tryResolveReturnType() const = 0;
virtual std::string getName() const = 0;
virtual bool isDeterministic() const = 0;
virtual uint32_t priority() const = 0;
virtual const std::shared_ptr<exec::FunctionSignature> signature() const = 0;
virtual const TypePtr& resultPhysicalType() const = 0;
virtual const std::vector<TypePtr>& argPhysicalTypes() const = 0;
virtual bool physicalSignatureEquals(
const ISimpleFunctionMetadata& other) const = 0;
virtual std::string helpMessage(const std::string& name) const = 0;
virtual ~ISimpleFunctionMetadata() = default;
};

template <typename T, typename = int32_t>
Expand Down Expand Up @@ -402,10 +463,14 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
}
}

explicit SimpleFunctionMetadata() {
auto analysis = analyzeSignatureTypes();
explicit SimpleFunctionMetadata(
const std::vector<exec::SignatureVariable>& constraints) {
auto analysis = analyzeSignatureTypes(constraints);

buildSignature(analysis);
priority_ = analysis.stats.computePriority();
resultPhysicalType_ = analysis.resultPhysicalType;
argPhysicalTypes_ = analysis.argPhysicalTypes;
}

~SimpleFunctionMetadata() override = default;
Expand All @@ -414,6 +479,33 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
return signature_;
}

/// Returns the physical type of the function's result, as recorded from the
/// signature type analysis when this metadata object was constructed.
const TypePtr& resultPhysicalType() const override {
return resultPhysicalType_;
}

/// Returns the physical types of the function's arguments, one entry per
/// argument, as recorded from the signature type analysis when this metadata
/// object was constructed.
const std::vector<TypePtr>& argPhysicalTypes() const override {
return argPhysicalTypes_;
}

/// Returns true if 'other' has the same physical signature as this function:
/// the result physical types match according to kindEquals, the argument
/// counts are equal, and every pair of corresponding argument physical types
/// matches according to kindEquals.
bool physicalSignatureEquals(
    const ISimpleFunctionMetadata& other) const override {
  if (!resultPhysicalType_->kindEquals(other.resultPhysicalType())) {
    return false;
  }

  // Fetch once instead of making a virtual call per loop iteration.
  const auto& otherArgTypes = other.argPhysicalTypes();
  const size_t numArgs = argPhysicalTypes_.size();
  if (numArgs != otherArgTypes.size()) {
    return false;
  }

  // Use an unsigned index to avoid a signed/unsigned comparison with size().
  for (size_t i = 0; i < numArgs; ++i) {
    if (!argPhysicalTypes_[i]->kindEquals(otherArgTypes[i])) {
      return false;
    }
  }

  return true;
}

std::string helpMessage(const std::string& name) const final {
// return fmt::format("{}({})", name, signature_->toString());
std::string s{name};
Expand Down Expand Up @@ -441,14 +533,19 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
std::string outputType;
std::map<std::string, exec::SignatureVariable> variables;
TypeAnalysisResults::Stats stats;
TypePtr resultPhysicalType;
std::vector<TypePtr> argPhysicalTypes;
};

SignatureTypesAnalysisResults analyzeSignatureTypes() {
SignatureTypesAnalysisResults analyzeSignatureTypes(
const std::vector<exec::SignatureVariable>& constraints) {
std::vector<std::string> argsTypes;

TypeAnalysisResults results;
TypeAnalysis<return_type>().run(results);
std::string outputType = results.typeAsString();
const auto resultPhysicalType = results.physicalType;
std::vector<TypePtr> argPhysicalTypes;

(
[&]() {
Expand All @@ -457,14 +554,27 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
results.resetTypeString();
TypeAnalysis<Args>().run(results);
argsTypes.push_back(results.typeAsString());
argPhysicalTypes.push_back(results.physicalType);
}(),
...);

for (const auto& constraint : constraints) {
VELOX_CHECK(
!constraint.constraint().empty(),
"Constraint must be set for variable {}",
constraint.name());

results.variablesInformation.erase(constraint.name());
results.variablesInformation.emplace(constraint.name(), constraint);
}

return SignatureTypesAnalysisResults{
std::move(argsTypes),
std::move(outputType),
std::move(results.variablesInformation),
std::move(results.stats)};
std::move(results.stats),
resultPhysicalType,
argPhysicalTypes};
}

void buildSignature(const SignatureTypesAnalysisResults& analysis) {
Expand Down Expand Up @@ -492,6 +602,8 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {

exec::FunctionSignaturePtr signature_;
uint32_t priority_;
TypePtr resultPhysicalType_;
std::vector<TypePtr> argPhysicalTypes_;
};

// wraps a UDF object to provide the inheritance
Expand Down
Loading

0 comments on commit 7e68798

Please sign in to comment.