Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set FunctionSignature returnType optional and refactor Spark function round decimal #10487

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions velox/expression/FunctionSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ std::string FunctionSignature::argumentsToString() const {

std::string FunctionSignature::toString() const {
std::ostringstream out;
out << "(" << argumentsToString() << ") -> " << returnType_.toString();
auto returnTypeString =
returnType_.has_value() ? returnType_.value().toString() : "";
out << "(" << argumentsToString() << ") -> " << returnTypeString;
return out.str();
}

Expand Down Expand Up @@ -143,7 +145,7 @@ void validateBaseTypeAndCollectTypeParams(

void validate(
const std::unordered_map<std::string, SignatureVariable>& variables,
const TypeSignature& returnType,
const std::optional<TypeSignature>& returnType,
const std::vector<TypeSignature>& argumentTypes,
const std::vector<bool>& constantArguments,
const std::vector<TypeSignature>& additionalTypes = {}) {
Expand All @@ -167,8 +169,10 @@ void validate(
}
}

validateBaseTypeAndCollectTypeParams(
variables, returnType, usedVariables, true);
if (returnType.has_value()) {
validateBaseTypeAndCollectTypeParams(
variables, returnType.value(), usedVariables, true);
}

VELOX_USER_CHECK_EQ(
usedVariables.size(),
Expand Down Expand Up @@ -208,7 +212,7 @@ SignatureVariable::SignatureVariable(

FunctionSignature::FunctionSignature(
std::unordered_map<std::string, SignatureVariable> variables,
TypeSignature returnType,
std::optional<TypeSignature> returnType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity)
Expand All @@ -222,7 +226,7 @@ FunctionSignature::FunctionSignature(

FunctionSignature::FunctionSignature(
std::unordered_map<std::string, SignatureVariable> variables,
facebook::velox::exec::TypeSignature returnType,
std::optional<TypeSignature> returnType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity,
Expand All @@ -248,10 +252,9 @@ std::string AggregateFunctionSignature::toString() const {
}

FunctionSignaturePtr FunctionSignatureBuilder::build() {
VELOX_CHECK(returnType_.has_value());
return std::make_shared<FunctionSignature>(
std::move(variables_),
returnType_.value(),
std::move(returnType_),
std::move(argumentTypes_),
std::move(constantArguments_),
variableArity_);
Expand Down
15 changes: 10 additions & 5 deletions velox/expression/FunctionSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,16 @@ class FunctionSignature {
/// can appear zero or more times.
FunctionSignature(
std::unordered_map<std::string, SignatureVariable> variables,
TypeSignature returnType,
std::optional<TypeSignature> returnType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity);

virtual ~FunctionSignature() = default;

const TypeSignature& returnType() const {
return returnType_;
VELOX_DCHECK(returnType_.has_value());
return returnType_.value();
}

const std::vector<TypeSignature>& argumentTypes() const {
Expand All @@ -148,6 +149,10 @@ class FunctionSignature {
constantArguments_.begin(), constantArguments_.end(), folly::identity);
}

/// Returns true if returnType_ holds a value, false if it is std::nullopt.
bool hasReturnType() const {
return returnType_.has_value();
}

/// Returns the stored variableArity_ flag.
bool variableArity() const {
return variableArity_;
}
Expand Down Expand Up @@ -178,7 +183,7 @@ class FunctionSignature {
/// FunctionSignature.
FunctionSignature(
std::unordered_map<std::string, SignatureVariable> variables,
TypeSignature returnType,
std::optional<TypeSignature> returnType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity,
Expand All @@ -189,7 +194,7 @@ class FunctionSignature {

private:
const std::unordered_map<std::string, SignatureVariable> variables_;
const TypeSignature returnType_;
const std::optional<TypeSignature> returnType_;
const std::vector<TypeSignature> argumentTypes_;
const std::vector<bool> constantArguments_;
const bool variableArity_;
Expand Down Expand Up @@ -318,7 +323,7 @@ class FunctionSignatureBuilder {

private:
std::unordered_map<std::string, SignatureVariable> variables_;
std::optional<TypeSignature> returnType_;
std::optional<TypeSignature> returnType_ = std::nullopt;
std::vector<TypeSignature> argumentTypes_;
std::vector<bool> constantArguments_;
bool variableArity_{false};
Expand Down
16 changes: 8 additions & 8 deletions velox/expression/fuzzer/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,14 +443,14 @@ bool isSupportedSignature(
bool enableComplexType) {
// Not supporting lambda functions, or functions using decimal and
// timestamp with time zone types.
return !(
useTypeName(signature, "opaque") ||
useTypeName(signature, "long_decimal") ||
useTypeName(signature, "short_decimal") ||
useTypeName(signature, "decimal") ||
useTypeName(signature, "timestamp with time zone") ||
useTypeName(signature, "interval day to second") ||
(enableComplexType && useTypeName(signature, "unknown")));
return !signature.hasReturnType() ||
!(useTypeName(signature, "opaque") ||
useTypeName(signature, "long_decimal") ||
useTypeName(signature, "short_decimal") ||
useTypeName(signature, "decimal") ||
useTypeName(signature, "timestamp with time zone") ||
useTypeName(signature, "interval day to second") ||
(enableComplexType && useTypeName(signature, "unknown")));
}

/// Returns row numbers for non-null rows among all children in 'data' or null
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_library(
Comparisons.cpp
DecimalArithmetic.cpp
DecimalCompare.cpp
DecimalRound.cpp
Hash.cpp
In.cpp
LeastGreatest.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,39 @@
* limitations under the License.
*/

#include "velox/functions/sparksql/specialforms/DecimalRound.h"
#include "velox/expression/ConstantExpr.h"
#include "velox/expression/DecodedArgs.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/sparksql/DecimalUtil.h"

namespace facebook::velox::functions::sparksql {
namespace {

/// Computes the {precision, scale} of the decimal produced by rounding a
/// DECIMAL(precision, scale) value to 'roundScale' fractional digits.
///
/// Rounding can carry into a new leading digit, so one extra integral digit is
/// always reserved, e.g. 'decimal_round(9.9, 0)' -> '10' and
/// 'decimal_round(99, -1)' -> '100'. The resulting precision is capped at
/// LongDecimalType::kMaxPrecision, so overflow remains possible at the cap.
std::pair<uint8_t, uint8_t>
getResultPrecisionScale(uint8_t precision, uint8_t scale, int32_t roundScale) {
  const int32_t maxPrecision =
      static_cast<int32_t>(LongDecimalType::kMaxPrecision);
  // Integral digits of the input plus one for a possible carry.
  const int32_t minIntegralDigits = precision - scale + 1;

  if (roundScale >= 0) {
    // The result keeps the smaller of the input scale and the requested
    // scale.
    const uint8_t resultScale =
        std::min(static_cast<int32_t>(scale), roundScale);
    const int32_t wantedPrecision = minIntegralDigits + resultScale;
    // The cap may truncate the wanted precision; that overflow risk is
    // accepted.
    return {std::min(wantedPrecision, maxPrecision), resultScale};
  }

  // A negative roundScale adjusts '-roundScale' digits left of the decimal
  // point, which needs at least '-roundScale + 1' digits; the result scale is
  // 0. roundScale is clamped to -maxPrecision before it is negated.
  const int32_t clampedRoundScale = std::max(roundScale, -maxPrecision);
  const int32_t wantedPrecision =
      std::max(minIntegralDigits, -clampedRoundScale + 1);
  // The cap may truncate the wanted precision; that overflow risk is accepted.
  return {std::min(wantedPrecision, maxPrecision), 0};
}

template <typename TResult, typename TInput>
class DecimalRoundFunction : public exec::VectorFunction {
public:
Expand All @@ -37,23 +64,12 @@ class DecimalRoundFunction : public exec::VectorFunction {
inputScale_(inputScale),
resultPrecision_(resultPrecision),
resultScale_(resultScale) {
const auto [p, s] = DecimalRoundCallToSpecialForm::getResultPrecisionScale(
inputPrecision, inputScale, scale);
VELOX_USER_CHECK_EQ(
p,
resultPrecision,
"The result precision of decimal_round is inconsistent with Spark expected.");
VELOX_USER_CHECK_EQ(
s,
resultScale,
"The result scale of decimal_round is inconsistent with Spark expected.");

// Decide the rescale factor of divide and multiply when rounding to a
// negative scale.
auto rescaleFactor = [&](int32_t rescale) {
VELOX_USER_CHECK_GT(
rescale, 0, "A non-negative rescale value is expected.");
return DecimalUtil::kPowersOfTen[std::min(
return velox::DecimalUtil::kPowersOfTen[std::min(
rescale, (int32_t)LongDecimalType::kMaxPrecision)];
};
if (scale_ < 0) {
Expand Down Expand Up @@ -92,19 +108,22 @@ class DecimalRoundFunction : public exec::VectorFunction {
inline TResult applyRound(const TInput& input) const {
if (scale_ >= 0) {
TResult rescaledValue;
const auto status = DecimalUtil::rescaleWithRoundUp<TInput, TResult>(
input,
inputPrecision_,
inputScale_,
resultPrecision_,
resultScale_,
rescaledValue);
const auto status =
velox::DecimalUtil::rescaleWithRoundUp<TInput, TResult>(
input,
inputPrecision_,
inputScale_,
resultPrecision_,
resultScale_,
rescaledValue);
VELOX_DCHECK(status.ok());
return rescaledValue;
} else {
TResult rescaledValue;
bool overflow;
DecimalUtil::divideWithRoundUp<TResult, TInput, int128_t>(
rescaledValue, input, divideFactor_.value(), false, 0, 0);
rescaledValue, input, divideFactor_.value(), 0, overflow);
VELOX_USER_CHECK(!overflow);
rescaledValue *= multiplyFactor_.value();
return rescaledValue;
}
Expand Down Expand Up @@ -137,24 +156,32 @@ class DecimalRoundFunction : public exec::VectorFunction {
std::optional<int128_t> multiplyFactor_ = std::nullopt;
};

std::shared_ptr<exec::VectorFunction> createDecimalRound(
const TypePtr& inputType,
int32_t scale,
const TypePtr& resultType) {
const auto [inputPrecision, inputScale] =
getDecimalPrecisionScale(*inputType);
const auto [resultPrecision, resultScale] =
getDecimalPrecisionScale(*resultType);
std::shared_ptr<exec::VectorFunction> createDecimalRoundFunction(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
int32_t scale = 0;
if (inputArgs.size() > 1) {
VELOX_CHECK(!inputArgs[1].constantValue->isNullAt(0));
scale = inputArgs[1]
.constantValue->template as<ConstantVector<int32_t>>()
->valueAt(0);
}
const auto inputType = inputArgs[0].type;
auto [inputPrecision, inputScale] = getDecimalPrecisionScale(*inputType);
auto [resultPrecision, resultScale] =
getResultPrecisionScale(inputPrecision, inputScale, scale);

if (inputType->isShortDecimal()) {
if (resultType->isShortDecimal()) {
if (resultPrecision <= velox::ShortDecimalType::kMaxPrecision) {
return std::make_shared<DecimalRoundFunction<int64_t, int64_t>>(
scale, inputPrecision, inputScale, resultPrecision, resultScale);
} else {
return std::make_shared<DecimalRoundFunction<int128_t, int64_t>>(
scale, inputPrecision, inputScale, resultPrecision, resultScale);
}
} else {
if (resultType->isShortDecimal()) {
if (resultPrecision <= velox::ShortDecimalType::kMaxPrecision) {
return std::make_shared<DecimalRoundFunction<int64_t, int128_t>>(
scale, inputPrecision, inputScale, resultPrecision, resultScale);
} else {
Expand All @@ -163,87 +190,25 @@ std::shared_ptr<exec::VectorFunction> createDecimalRound(
}
}
}
} // namespace

std::pair<uint8_t, uint8_t>
DecimalRoundCallToSpecialForm::getResultPrecisionScale(
uint8_t precision,
uint8_t scale,
int32_t roundScale) {
// After rounding we may need one more digit in the integral part,
// e.g. 'decimal_round(9.9, 0)' -> '10', 'decimal_round(99, -1)' -> '100'.
const int32_t integralLeastNumDigits = precision - scale + 1;
if (roundScale < 0) {
// Negative scale means we need to adjust `-scale` number of digits before
// the decimal point, which means we need at least `-scale + 1` digits after
// rounding, and the result scale is 0.
const auto newPrecision = std::max(
integralLeastNumDigits,
-std::max(roundScale, -(int32_t)LongDecimalType::kMaxPrecision) + 1);
// We have to accept the risk of overflow as we can't exceed the max
// precision.
return {std::min(newPrecision, (int32_t)LongDecimalType::kMaxPrecision), 0};
}
const uint8_t newScale = std::min((int32_t)scale, roundScale);
// We have to accept the risk of overflow as we cannot exceed the max
// precision.
std::vector<std::shared_ptr<exec::FunctionSignature>> decimalSignature() {
return {
std::min(
integralLeastNumDigits + newScale,
(int32_t)LongDecimalType::kMaxPrecision),
newScale};
exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.argumentType("DECIMAL(a_precision, a_scale)")
.build(),
exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.argumentType("DECIMAL(a_precision, a_scale)")
.constantArgumentType("integer")
.build()};
}
} // namespace

/// Always fails via VELOX_FAIL: the argument types are ignored and no result
/// type is ever returned, because decimal_round does not support type
/// resolution.
TypePtr DecimalRoundCallToSpecialForm::resolveType(
const std::vector<TypePtr>& /*argTypes*/) {
VELOX_FAIL("Decimal round function does not support type resolution.");
}

/// Builds the decimal_round expression from its already-constructed arguments.
///
/// Checks, in order, that 'type' is decimal, that there are one or two
/// arguments, and that the first argument is decimal. When a second argument
/// is present it must be an INTEGER constant expression wrapped in a non-null
/// constant vector; its value becomes the rounding scale.
///
/// @param type Result type of the expression; must be a decimal type.
/// @param args One or two argument expressions; moved into the returned Expr.
/// @param trackCpuUsage Forwarded unchanged to the exec::Expr constructor.
/// @return An exec::Expr named kRoundDecimal wrapping the function created by
/// createDecimalRound.
exec::ExprPtr DecimalRoundCallToSpecialForm::constructSpecialForm(
const TypePtr& type,
std::vector<exec::ExprPtr>&& args,
bool trackCpuUsage,
const core::QueryConfig& /*config*/) {
VELOX_USER_CHECK(
type->isDecimal(),
"The result type of decimal_round should be decimal type.");
VELOX_USER_CHECK_GE(
args.size(), 1, "Decimal_round expects one or two arguments.");
VELOX_USER_CHECK_LE(
args.size(), 2, "Decimal_round expects one or two arguments.");
VELOX_USER_CHECK(
args[0]->type()->isDecimal(),
"The first argument of decimal_round should be of decimal type.");

// The scale stays 0 when only the decimal argument is supplied.
int32_t scale = 0;
if (args.size() > 1) {
VELOX_USER_CHECK_EQ(
args[1]->type()->kind(),
TypeKind::INTEGER,
"The second argument of decimal_round should be of integer type.");
// The scale is read here, at construction time, so the argument must be a
// constant expression.
auto constantExpr = std::dynamic_pointer_cast<exec::ConstantExpr>(args[1]);
VELOX_USER_CHECK_NOT_NULL(
constantExpr,
"The second argument of decimal_round should be constant expression.");
// The constant-encoding check comes before the unchecked cast to
// ConstantVector<int32_t> below.
VELOX_USER_CHECK(
constantExpr->value()->isConstantEncoding(),
"The second argument of decimal_round should be wrapped in constant vector.");
auto constantVector =
constantExpr->value()->asUnchecked<ConstantVector<int32_t>>();
VELOX_USER_CHECK(
!constantVector->isNullAt(0),
"The second argument of decimal_round is non-nullable.");
scale = constantVector->valueAt(0);
}

// args[0]'s type is read here, before 'args' is moved into the Expr below.
auto decimalRound = createDecimalRound(args[0]->type(), scale, type);

return std::make_shared<exec::Expr>(
type,
std::move(args),
std::move(decimalRound),
exec::VectorFunctionMetadata{},
kRoundDecimal,
trackCpuUsage);
}
// Declares the stateful vector function 'udf_decimal_round', pairing the
// signatures returned by decimalSignature() with the
// createDecimalRoundFunction factory.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_round,
decimalSignature(),
createDecimalRoundFunction);
} // namespace facebook::velox::functions::sparksql
Loading
Loading