From 545d0eb5c04abeab6ea0357ddaf028e6994d5d29 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 7 May 2024 18:04:00 -0500 Subject: [PATCH 1/2] move covariance --- datafusion/expr/src/aggregate_function.rs | 10 +- .../expr/src/type_coercion/aggregates.rs | 10 - .../functions-aggregate/src/covariance.rs | 81 ++++ datafusion/functions-aggregate/src/lib.rs | 1 + .../physical-expr/src/aggregate/build_in.rs | 11 - .../physical-expr/src/aggregate/covariance.rs | 372 ------------------ .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 4 - .../proto/src/physical_plan/to_proto.rs | 14 +- .../tests/cases/roundtrip_logical_plan.rs | 3 +- 14 files changed, 93 insertions(+), 424 deletions(-) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index af8a682eff58..0a7607498c61 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -63,8 +63,6 @@ pub enum AggregateFunction { Stddev, /// Standard Deviation (Population) StddevPop, - /// Covariance (Population) - CovariancePop, /// Correlation Correlation, /// Slope from linear regression @@ -126,7 +124,6 @@ impl AggregateFunction { VariancePop => "VAR_POP", Stddev => "STDDEV", StddevPop => "STDDEV_POP", - CovariancePop => "COVAR_POP", Correlation => "CORR", RegrSlope => "REGR_SLOPE", RegrIntercept => "REGR_INTERCEPT", @@ -181,7 +178,6 @@ impl FromStr for AggregateFunction { "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, - "covar_pop" => AggregateFunction::CovariancePop, "stddev" => AggregateFunction::Stddev, "stddev_pop" => AggregateFunction::StddevPop, "stddev_samp" => AggregateFunction::Stddev, @@ -255,9 +251,6 @@ impl AggregateFunction { AggregateFunction::VariancePop => { variance_return_type(&coerced_data_types[0]) } - AggregateFunction::CovariancePop => { - covariance_return_type(&coerced_data_types[0]) - } AggregateFunction::Correlation => { correlation_return_type(&coerced_data_types[0]) } @@ -349,8 +342,7 @@ impl AggregateFunction { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::CovariancePop - | AggregateFunction::Correlation + AggregateFunction::Correlation | AggregateFunction::RegrSlope | AggregateFunction::RegrIntercept | AggregateFunction::RegrCount diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 39726d7d0e62..57c0b6f4edc5 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -183,16 +183,6 @@ pub fn coerce_types( } Ok(vec![Float64, Float64]) } - AggregateFunction::CovariancePop => { - if !is_covariance_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } AggregateFunction::Stddev | AggregateFunction::StddevPop => { if !is_stddev_support_arg_type(&input_types[0]) { return plan_err!( diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index 130b193996b6..1210e1529dbb 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -43,6 +43,14 @@ make_udaf_expr_and_func!( covar_samp_udaf ); +make_udaf_expr_and_func!( + CovariancePopulation, + covar_pop, + y x, + "Computes the population covariance.", + covar_pop_udaf +); + pub struct CovarianceSample { signature: Signature, aliases: Vec, @@ -120,6 +128,79 @@ impl AggregateUDFImpl for CovarianceSample { } } +pub struct CovariancePopulation { + signature: Signature, +} + +impl Debug for CovariancePopulation { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("CovariancePopulation") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for CovariancePopulation { + fn default() -> Self { + Self::new() + } +} + +impl CovariancePopulation { + pub fn new() -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for CovariancePopulation { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "covar_pop" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Covariance requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn state_fields( + &self, + name: &str, + _value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean1"), DataType::Float64, true), + Field::new(format_state_name(name, "mean2"), DataType::Float64, true), + Field::new( + format_state_name(name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(CovarianceAccumulator::try_new( + StatsType::Population, + )?)) + } +} + /// An accumulator to compute covariance /// The algorithm used is an online implementation and numerically stable. It is derived from the following paper /// for calculating variance: diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index d4e4d3a5f328..e76a43e39899 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -75,6 +75,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ first_last::first_value_udaf(), covariance::covar_samp_udaf(), + covariance::covar_pop_udaf(), ]; functions.into_iter().try_for_each(|udf| { diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 36af875473be..145e7feadf8c 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -181,17 +181,6 @@ pub fn create_aggregate_expr( (AggregateFunction::VariancePop, true) => { return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); } - (AggregateFunction::CovariancePop, false) => { - Arc::new(expressions::CovariancePop::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )) - } - (AggregateFunction::CovariancePop, true) => { - return not_impl_err!("COVAR_POP(DISTINCT) aggregations are not available"); - } (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( input_phy_exprs[0].clone(), name, diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index 272f1d8be2b5..639d8a098c01 100644 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -17,111 +17,17 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::sync::Arc; - -use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, compute::cast, datatypes::DataType, - datatypes::Field, }; use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use crate::aggregate::stats::StatsType; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; - -/// COVAR_POP aggregate expression -#[derive(Debug)] -pub struct CovariancePop { - name: String, - expr1: Arc, - expr2: Arc, -} - -impl CovariancePop { - /// Create a new COVAR_POP aggregate function - pub fn new( - expr1: Arc, - expr2: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of covariance just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr1, - expr2, - } - } -} - -impl AggregateExpr for CovariancePop { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CovarianceAccumulator::try_new( - StatsType::Population, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "mean2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "algo_const"), - DataType::Float64, - true, - ), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr1.clone(), self.expr2.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for CovariancePop { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2) - }) - .unwrap_or(false) - } -} /// An accumulator to compute covariance /// The algrithm used is an online implementation and numerically stable. It is derived from the following paper @@ -319,281 +225,3 @@ impl Accumulator for CovarianceAccumulator { std::mem::size_of_val(self) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op2; - use arrow::{array::*, datatypes::*}; - - #[test] - fn covariance_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_f64_5() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - CovariancePop, - ScalarValue::from(0.6022222222222223_f64) - ) - } - - #[test] - fn covariance_f64_6() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![ - 1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64, - ])); - let b = Arc::new(Float64Array::from(vec![ - 4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64, - ])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - CovariancePop, - ScalarValue::from(0.7616666666666666_f64) - ) - } - - #[test] - fn covariance_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32])); - let b: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 5_u32, 6_u32])); - generic_test_op2!( - a, - b, - DataType::UInt32, - DataType::UInt32, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_f32() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32])); - let b: ArrayRef = Arc::new(Float32Array::from(vec![4_f32, 5_f32, 6_f32])); - generic_test_op2!( - a, - b, - DataType::Float32, - DataType::Float32, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_i32_with_nulls_1() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - CovariancePop, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn covariance_i32_with_nulls_2() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(2), - None, - Some(3), - None, - ])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(4), - Some(9), - Some(5), - Some(8), - Some(6), - None, - ])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_pop_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - CovariancePop, - ScalarValue::Float64(None) - ) - } - - #[test] - fn covariance_pop_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![2_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - CovariancePop, - ScalarValue::from(0_f64) - ) - } - - #[test] - fn covariance_f64_merge_1() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - let c = Arc::new(Float64Array::from(vec![1.1_f64, 2.2_f64, 3.3_f64])); - let d = Arc::new(Float64Array::from(vec![4.4_f64, 5.5_f64, 6.6_f64])); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - ]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; - - let agg1 = Arc::new(CovariancePop::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(CovariancePop::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(0.7616666666666666)); - - Ok(()) - } - - #[test] - fn covariance_f64_merge_2() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - let c = Arc::new(Float64Array::from(vec![None])); - let d = Arc::new(Float64Array::from(vec![None])); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - ]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; - - let agg1 = Arc::new(CovariancePop::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(CovariancePop::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(0.6666666666666666)); - - Ok(()) - } - - fn merge( - batch1: &RecordBatch, - batch2: &RecordBatch, - agg1: Arc, - agg2: Arc, - ) -> Result { - let mut accum1 = agg1.create_accumulator()?; - let mut accum2 = agg2.create_accumulator()?; - let expr1 = agg1.expressions(); - let expr2 = agg2.expressions(); - - let values1 = expr1 - .iter() - .map(|e| { - e.evaluate(batch1) - .and_then(|v| v.into_array(batch1.num_rows())) - }) - .collect::>>()?; - let values2 = expr2 - .iter() - .map(|e| { - e.evaluate(batch2) - .and_then(|v| v.into_array(batch2.num_rows())) - }) - .collect::>>()?; - accum1.update_batch(&values1)?; - accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; - accum1.merge_batch(&state2)?; - accum1.evaluate() - } -} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 3efa965d1473..c16b609e2375 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -52,7 +52,6 @@ pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; -pub use crate::aggregate::covariance::CovariancePop; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c057ab8acda7..311a0bf863ad 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -549,7 +549,7 @@ enum AggregateFunction { VARIANCE = 7; VARIANCE_POP = 8; // COVARIANCE = 9; - COVARIANCE_POP = 10; + // COVARIANCE_POP = 10; STDDEV = 11; STDDEV_POP = 12; CORRELATION = 13; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 994703c5fcfb..a1a141735881 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -430,7 +430,6 @@ impl serde::Serialize for AggregateFunction { Self::ArrayAgg => "ARRAY_AGG", Self::Variance => "VARIANCE", Self::VariancePop => "VARIANCE_POP", - Self::CovariancePop => "COVARIANCE_POP", Self::Stddev => "STDDEV", Self::StddevPop => "STDDEV_POP", Self::Correlation => "CORRELATION", @@ -477,7 +476,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG", "VARIANCE", "VARIANCE_POP", - "COVARIANCE_POP", "STDDEV", "STDDEV_POP", "CORRELATION", @@ -553,7 +551,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "VARIANCE" => Ok(AggregateFunction::Variance), "VARIANCE_POP" => Ok(AggregateFunction::VariancePop), - "COVARIANCE_POP" => Ok(AggregateFunction::CovariancePop), "STDDEV" => Ok(AggregateFunction::Stddev), "STDDEV_POP" => Ok(AggregateFunction::StddevPop), "CORRELATION" => Ok(AggregateFunction::Correlation), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index fc23a9ea05f7..706794e38070 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2835,7 +2835,7 @@ pub enum AggregateFunction { Variance = 7, VariancePop = 8, /// COVARIANCE = 9; - CovariancePop = 10, + /// COVARIANCE_POP = 10; Stddev = 11, StddevPop = 12, Correlation = 13, @@ -2881,7 +2881,6 @@ impl AggregateFunction { AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Variance => "VARIANCE", AggregateFunction::VariancePop => "VARIANCE_POP", - AggregateFunction::CovariancePop => "COVARIANCE_POP", AggregateFunction::Stddev => "STDDEV", AggregateFunction::StddevPop => "STDDEV_POP", AggregateFunction::Correlation => "CORRELATION", @@ -2924,7 +2923,6 @@ impl AggregateFunction { "ARRAY_AGG" => Some(Self::ArrayAgg), "VARIANCE" => Some(Self::Variance), "VARIANCE_POP" => Some(Self::VariancePop), - "COVARIANCE_POP" => Some(Self::CovariancePop), "STDDEV" => Some(Self::Stddev), "STDDEV_POP" => Some(Self::StddevPop), "CORRELATION" => Some(Self::Correlation), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 35d4c6409bc1..585bcad7f38c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -428,7 +428,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Variance => Self::Variance, protobuf::AggregateFunction::VariancePop => Self::VariancePop, - protobuf::AggregateFunction::CovariancePop => Self::CovariancePop, protobuf::AggregateFunction::Stddev => Self::Stddev, protobuf::AggregateFunction::StddevPop => Self::StddevPop, protobuf::AggregateFunction::Correlation => Self::Correlation, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 80acd12e4e60..4c29d7551bc6 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -369,7 +369,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Variance => Self::Variance, AggregateFunction::VariancePop => Self::VariancePop, - AggregateFunction::CovariancePop => Self::CovariancePop, AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, @@ -673,9 +672,6 @@ pub fn serialize_expr( AggregateFunction::VariancePop => { protobuf::AggregateFunction::VariancePop } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, AggregateFunction::StddevPop => { protobuf::AggregateFunction::StddevPop diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 3bc71f5f4c90..162a2f28e16b 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -25,12 +25,12 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, - CastExpr, Column, Correlation, Count, CovariancePop, CumeDist, DistinctArrayAgg, - DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, InListExpr, - IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, NegativeExpr, - NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, - RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, Variance, - VariancePop, WindowShift, + CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, + DistinctCount, DistinctSum, FirstValue, Grouping, InListExpr, IsNotNullExpr, + IsNullExpr, LastValue, Literal, Max, Median, Min, NegativeExpr, NotExpr, NthValue, + NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, + RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, Variance, VariancePop, + WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -291,8 +291,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Variance } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::VariancePop - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::CovariancePop } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Stddev } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3800b672b5e2..819e20615685 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -30,7 +30,7 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::functions_aggregate::covariance::covar_samp; +use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -616,6 +616,7 @@ async fn roundtrip_expr_api() -> Result<()> { array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), first_value(vec![lit(1)], false, None, None, None), covar_samp(lit(1.5), lit(2.2), false, None, None, None), + covar_pop(lit(1.5), lit(2.2), true, None, None, None), ]; // ensure expressions created with the expr api can be round tripped From 5b957152153ea3d78bb48a111e3b1f58d345308f Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Thu, 9 May 2024 23:30:01 -0500 Subject: [PATCH 2/2] add sqllogictest --- .../sqllogictest/test_files/aggregate.slt | 180 ++++++++++++++++++ 1 file changed, 180 insertions(+) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index bc677b73fb94..40f78e7f4d24 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1812,6 +1812,186 @@ select avg(c1), arrow_typeof(avg(c1)) from t; statement ok drop table t; +# covariance_f64_1 +statement ok +create table t (c1 double, c2 double) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_f64_2 +statement ok +create table t (c1 double, c2 double) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +1 Float64 + +statement ok +drop table t; + +# covariance_f64_4 +statement ok +create table t (c1 double, c2 double) as values (1.1, 4.1), (2.0, 5.0), (3.0, 6.0); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +0.903333333333 Float64 + +statement ok +drop table t; + +# covariance_f64_5 +statement ok +create table t (c1 double, c2 double) as values (1.1, 4.1), (2.0, 5.0), (3.0, 6.0); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.602222222222 Float64 + +statement ok +drop table t; + +# covariance_f64_6 +statement ok +create table t (c1 double, c2 double) as values (1.0, 4.0), (2.0, 5.0), (3.0, 6.0), (1.1, 4.4), (2.2, 5.5), (3.3, 6.6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.761666666667 Float64 + +statement ok +drop table t; + +# covariance_i32 +statement ok +create table t (c1 int, c2 int) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_u32 +statement ok +create table t (c1 int unsigned, c2 int unsigned) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_f32 +statement ok +create table t (c1 float, c2 float) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_i32_with_nulls_1 +statement ok +create table t (c1 int, c2 int) as values (1, 4), (null, null), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +1 Float64 + +statement ok +drop table t; + +# covariance_i32_with_nulls_2 +statement ok +create table t (c1 int, c2 int) as values (1, 4), (null, 9), (2, 5), (null, 8), (3, 6), (null, null); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_i32_with_nulls_3 +statement ok +create table t (c1 int, c2 int) as values (1, 4), (null, 9), (2, 5), (null, 8), (3, 6), (null, null); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +1 Float64 + +statement ok +drop table t; + +# covariance_i32_all_nulls +statement ok +create table t (c1 int, c2 int) as values (null, null), (null, null); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# covariance_pop_i32_all_nulls +statement ok +create table t (c1 int, c2 int) as values (null, null), (null, null); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# covariance_1_input +statement ok +create table t (c1 double, c2 double) as values (1, 2); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# covariance_pop_1_input +statement ok +create table t (c1 double, c2 double) as values (1, 2); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0 Float64 + +statement ok +drop table t; + # simple_mean query R select mean(c1) from test