Fix regression by reverting "Materialize dictionaries in group keys" #8740

Merged 4 commits on Jan 8, 2024
15 changes: 12 additions & 3 deletions datafusion/core/tests/path_partition.rs
@@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> {
assert_eq!(min_limit, resulting_limit);

let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
let month = match s {
ScalarValue::Utf8(Some(month)) => month,
s => panic!("Expected month as Utf8 found {s:?}"),
let month = match extract_as_utf(&s) {
Some(month) => month,
s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
};

let sql_on_partition_boundary = format!(
@@ -191,6 +191,15 @@ async fn parquet_distinct_partition_col() -> Result<()> {
Ok(())
}

fn extract_as_utf(v: &ScalarValue) -> Option<String> {
if let ScalarValue::Dictionary(_, v) = v {
if let ScalarValue::Utf8(v) = v.as_ref() {
return v.clone();
}
}
None
}

#[tokio::test]
async fn csv_filter_with_file_col() -> Result<()> {
let ctx = SessionContext::new();
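For context, here is a minimal sketch of why the test now unwraps a `Dictionary` scalar via the `extract_as_utf` helper above: with the materialization reverted, partition columns come back dictionary-encoded, so `ScalarValue::try_from_array` yields `ScalarValue::Dictionary(_, Utf8)` rather than a plain `Utf8`. The key type and the value used below are assumptions for illustration only.

```rust
use arrow_schema::DataType;
use datafusion_common::ScalarValue;

/// Unwrap a Utf8 value stored inside a Dictionary scalar (mirrors the test helper).
fn extract_as_utf(v: &ScalarValue) -> Option<String> {
    if let ScalarValue::Dictionary(_, inner) = v {
        if let ScalarValue::Utf8(s) = inner.as_ref() {
            return s.clone();
        }
    }
    None
}

fn main() {
    // The kind of scalar a dictionary-encoded partition column produces
    // (key type and month value are illustrative).
    let month = ScalarValue::Dictionary(
        Box::new(DataType::UInt16),
        Box::new(ScalarValue::Utf8(Some("10".to_string()))),
    );
    assert_eq!(extract_as_utf(&month), Some("10".to_string()));

    // A plain Utf8 scalar is intentionally not unwrapped by this helper.
    assert_eq!(extract_as_utf(&ScalarValue::Utf8(Some("10".into()))), None);
}
```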
27 changes: 23 additions & 4 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,18 +17,22 @@

use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::compute::cast;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::ArrayRef;
use arrow_schema::SchemaRef;
use arrow_array::{Array, ArrayRef};
use arrow_schema::{DataType, SchemaRef};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::Result;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_physical_expr::EmitTo;
use hashbrown::raw::RawTable;

/// A [`GroupValues`] making use of [`Rows`]
pub struct GroupValuesRows {
/// The output schema
schema: SchemaRef,

/// Converter for the group values
row_converter: RowConverter,

@@ -75,6 +79,7 @@ impl GroupValuesRows {
let map = RawTable::with_capacity(0);

Ok(Self {
schema,
row_converter,
map,
map_size: 0,
@@ -165,7 +170,7 @@ impl GroupValues for GroupValuesRows {
.take()
.expect("Can not emit from empty rows");

let output = match emit_to {
let mut output = match emit_to {
EmitTo::All => {
let output = self.row_converter.convert_rows(&group_values)?;
group_values.clear();
@@ -198,6 +203,20 @@ impl GroupValues for GroupValuesRows {
}
};

// TODO: Materialize dictionaries in group keys (#7647)
for (field, array) in self.schema.fields.iter().zip(&mut output) {
let expected = field.data_type();
if let DataType::Dictionary(_, v) = expected {
let actual = array.data_type();
if v.as_ref() != actual {
return Err(DataFusionError::Internal(format!(
"Converted group rows expected dictionary of {v} got {actual}"
)));
}
*array = cast(array.as_ref(), expected)?;
}
}

self.group_values = Some(group_values);
Ok(output)
}
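A minimal sketch (assuming the same arrow crates imported above) of what the new emit-time loop does: the `RowConverter` materializes dictionary group keys as their value type, so any emitted array whose output field expects a dictionary is cast back with `arrow::compute::cast`. The concrete key and value types here are illustrative.

```rust
use std::sync::Arc;

use arrow::compute::cast;
use arrow::error::ArrowError;
use arrow_array::{Array, ArrayRef, StringArray};
use arrow_schema::DataType;

fn main() -> Result<(), ArrowError> {
    // What `convert_rows` hands back for a group key declared as
    // Dictionary(Int32, Utf8): the values themselves, as a plain Utf8 array.
    let materialized: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "a"]));

    let expected = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

    // Cast back so the emitted batch matches the aggregate's declared output schema.
    let as_dict = cast(materialized.as_ref(), &expected)?;
    assert_eq!(as_dict.data_type(), &expected);
    Ok(())
}
```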
35 changes: 3 additions & 32 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -36,7 +36,6 @@ use crate::{
use arrow::array::ArrayRef;
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_schema::DataType;
use datafusion_common::stats::Precision;
use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
use datafusion_execution::TaskContext;
@@ -254,9 +253,6 @@ pub struct AggregateExec {
limit: Option<usize>,
/// Input plan, could be a partial aggregate or the input to the aggregate
pub input: Arc<dyn ExecutionPlan>,
/// Original aggregation schema, could be different from `schema` before dictionary group
/// keys get materialized
original_schema: SchemaRef,
/// Schema after the aggregate is applied
schema: SchemaRef,
/// Input schema before any aggregation is applied. For partial aggregate this will be the
@@ -287,19 +283,15 @@ impl AggregateExec {
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
let original_schema = create_schema(
let schema = create_schema(
&input.schema(),
&group_by.expr,
&aggr_expr,
group_by.contains_null(),
mode,
)?;

let schema = Arc::new(materialize_dict_group_keys(
&original_schema,
group_by.expr.len(),
));
let original_schema = Arc::new(original_schema);
let schema = Arc::new(schema);
AggregateExec::try_new_with_schema(
mode,
group_by,
@@ -308,7 +300,6 @@ impl AggregateExec {
input,
input_schema,
schema,
original_schema,
)
}

@@ -329,7 +320,6 @@ impl AggregateExec {
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
schema: SchemaRef,
original_schema: SchemaRef,
) -> Result<Self> {
let input_eq_properties = input.equivalence_properties();
// Get GROUP BY expressions:
@@ -382,7 +372,6 @@ impl AggregateExec {
aggr_expr,
filter_expr,
input,
original_schema,
schema,
input_schema,
projection_mapping,
@@ -693,7 +682,7 @@ impl ExecutionPlan for AggregateExec {
children[0].clone(),
self.input_schema.clone(),
self.schema.clone(),
self.original_schema.clone(),
//self.original_schema.clone(),
)?;
me.limit = self.limit;
Ok(Arc::new(me))
@@ -800,24 +789,6 @@ fn create_schema(
Ok(Schema::new(fields))
}

/// returns schema with dictionary group keys materialized as their value types
/// The actual conversion happens in `RowConverter` and we don't do unnecessary
/// conversion back into dictionaries
fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema {
let fields = schema
.fields
.iter()
.enumerate()
.map(|(i, field)| match field.data_type() {
DataType::Dictionary(_, value_data_type) if i < group_count => {
Field::new(field.name(), *value_data_type.clone(), field.is_nullable())
}
_ => Field::clone(field),
})
.collect::<Vec<_>>();
Schema::new(fields)
}

fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
let group_fields = schema.fields()[0..group_count].to_vec();
Arc::new(Schema::new(group_fields))
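As a usage sketch of the `group_schema` helper shown above: with `materialize_dict_group_keys` removed, the group schema is simply the first `group_count` fields of the aggregate's output schema, dictionary types included, and that is what the `RowConverter` is built from. The field names and types below are made up for illustration.

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema, SchemaRef};

// Same shape as the helper in aggregates/mod.rs: take the leading group fields.
fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
    let group_fields = schema.fields()[0..group_count].to_vec();
    Arc::new(Schema::new(group_fields))
}

fn main() {
    // Output schema of an aggregate with one dictionary-typed group key.
    let agg_schema = Schema::new(vec![
        Field::new(
            "tag_id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        ),
        Field::new("COUNT(value)", DataType::Int64, true),
    ]);

    let groups = group_schema(&agg_schema, 1);
    // The dictionary type is kept rather than being materialized to Utf8 up front.
    assert!(matches!(
        groups.field(0).data_type(),
        DataType::Dictionary(_, _)
    ));
}
```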
4 changes: 1 addition & 3 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -324,9 +324,7 @@ impl GroupedHashAggregateStream {
.map(create_group_accumulator)
.collect::<Result<_>>()?;

// we need to use original schema so RowConverter in group_values below
// will do the proper conversion of dictionaries into value types
let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len());
let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
let spill_expr = group_schema
.fields
.into_iter()
10 changes: 5 additions & 5 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -2469,11 +2469,11 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
query T
select arrow_typeof(x_dict) from value_dict group by x_dict;
----
Int32
Int32
Int32
Int32
Int32
Dictionary(Int64, Int32)
Dictionary(Int64, Int32)
Dictionary(Int64, Int32)
Dictionary(Int64, Int32)
Dictionary(Int64, Int32)

statement ok
drop table value
81 changes: 80 additions & 1 deletion datafusion/sqllogictest/test_files/dictionary.slt
@@ -169,7 +169,7 @@ order by date_bin('30 minutes', time) DESC

# Reproducer for https://github.com/apache/arrow-datafusion/issues/8738
# This query should work correctly
query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Utf8 got Dictionary\(Int32, Utf8\)
query P?TT rowsort
SELECT
"data"."timestamp" as "time",
"data"."tag_id",
@@ -201,3 +201,82 @@ ORDER BY
"time",
"data"."tag_id"
;
----
2023-12-20T00:00:00 1000 f1 32.0
2023-12-20T00:00:00 1000 f2 foo
2023-12-20T00:10:00 1000 f1 32.0
2023-12-20T00:10:00 1000 f2 foo
2023-12-20T00:20:00 1000 f1 32.0
2023-12-20T00:20:00 1000 f2 foo
2023-12-20T00:30:00 1000 f1 32.0
2023-12-20T00:30:00 1000 f2 foo
2023-12-20T00:40:00 1000 f1 32.0
2023-12-20T00:40:00 1000 f2 foo
2023-12-20T00:50:00 1000 f1 32.0
2023-12-20T00:50:00 1000 f2 foo
2023-12-20T01:00:00 1000 f1 32.0
2023-12-20T01:00:00 1000 f2 foo
2023-12-20T01:10:00 1000 f1 32.0
2023-12-20T01:10:00 1000 f2 foo
2023-12-20T01:20:00 1000 f1 32.0
2023-12-20T01:20:00 1000 f2 foo
2023-12-20T01:30:00 1000 f1 32.0
2023-12-20T01:30:00 1000 f2 foo


# deterministic sort (so we can avoid rowsort)
query P?TT
SELECT
"data"."timestamp" as "time",
"data"."tag_id",
"data"."field",
"data"."value"
FROM (
(
SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value"
FROM "m2"
WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00'
AND "m2"."f5" IS NOT NULL
AND "m2"."type" IN ('active')
AND "m2"."tag_id" IN ('1000')
) UNION (
SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value"
FROM "m1"
WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00'
AND "m1"."f1" IS NOT NULL
AND "m1"."tag_id" IN ('1000')
) UNION (
SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value"
FROM "m1"
WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00'
AND "m1"."f2" IS NOT NULL
AND "m1"."tag_id" IN ('1000')
)
) as "data"
ORDER BY
"time",
"data"."tag_id",
"data"."field",
"data"."value"
;
----
2023-12-20T00:00:00 1000 f1 32.0
2023-12-20T00:00:00 1000 f2 foo
2023-12-20T00:10:00 1000 f1 32.0
2023-12-20T00:10:00 1000 f2 foo
2023-12-20T00:20:00 1000 f1 32.0
2023-12-20T00:20:00 1000 f2 foo
2023-12-20T00:30:00 1000 f1 32.0
2023-12-20T00:30:00 1000 f2 foo
2023-12-20T00:40:00 1000 f1 32.0
2023-12-20T00:40:00 1000 f2 foo
2023-12-20T00:50:00 1000 f1 32.0
2023-12-20T00:50:00 1000 f2 foo
2023-12-20T01:00:00 1000 f1 32.0
2023-12-20T01:00:00 1000 f2 foo
2023-12-20T01:10:00 1000 f1 32.0
2023-12-20T01:10:00 1000 f2 foo
2023-12-20T01:20:00 1000 f1 32.0
2023-12-20T01:20:00 1000 f2 foo
2023-12-20T01:30:00 1000 f1 32.0
2023-12-20T01:30:00 1000 f2 foo