Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for FixedSizeList type in arrow_cast, hashing #8344

Merged
merged 7 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array};
use arrow_buffer::i256;

use crate::cast::{
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array,
as_primitive_array, as_string_array, as_struct_array,
as_boolean_array, as_fixed_size_list_array, as_generic_binary_array,
as_large_list_array, as_list_array, as_primitive_array, as_string_array,
as_struct_array,
};
use crate::error::{DataFusionError, Result, _internal_err};

Expand Down Expand Up @@ -267,6 +268,38 @@ where
Ok(())
}

fn hash_fixed_list_array(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see any test coverage for this new code -- e.g. either unit tests for hashing or a higher level test like GROUP BY <FixedListArray>

Can you either ensure this code is tested somehow, or else perhaps move the hash support to a different PR so we can merge the arrow_cast support ?

array: &FixedSizeListArray,
random_state: &RandomState,
hashes_buffer: &mut [u64],
) -> Result<()> {
let values = array.values().clone();
let value_len = array.value_length();
let offset_size = value_len as usize / array.len();
let nulls = array.nulls();
let mut values_hashes = vec![0u64; values.len()];
create_hashes(&[values], random_state, &mut values_hashes)?;
if let Some(nulls) = nulls {
for i in 0..array.len() {
if nulls.is_valid(i) {
let hash = &mut hashes_buffer[i];
for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size]
{
*hash = combine_hashes(*hash, *values_hash);
}
}
}
} else {
for i in 0..array.len() {
let hash = &mut hashes_buffer[i];
for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size] {
*hash = combine_hashes(*hash, *values_hash);
}
}
}
Ok(())
}

/// Test version of `create_hashes` that produces the same value for
/// all hashes (to test collisions)
///
Expand Down Expand Up @@ -366,6 +399,10 @@ pub fn create_hashes<'a>(
let array = as_large_list_array(array)?;
hash_list_array(array, random_state, hashes_buffer)?;
}
DataType::FixedSizeList(_,_) => {
let array = as_fixed_size_list_array(array)?;
hash_fixed_list_array(array, random_state, hashes_buffer)?;
}
_ => {
// This is internal because we should have caught this before.
return _internal_err!(
Expand Down
24 changes: 20 additions & 4 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ use crate::cast::{
};
use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
use crate::hash_utils::create_hashes;
use crate::utils::{array_into_large_list_array, array_into_list_array};

use crate::utils::{
array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array,
};
use arrow::compute::kernels::numeric::*;
use arrow::util::display::{ArrayFormatter, FormatOptions};
use arrow::{
Expand Down Expand Up @@ -2223,9 +2224,11 @@ impl ScalarValue {
let list_array = as_fixed_size_list_array(array)?;
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(array_into_list_array(nested_array));
let list_size = nested_array.len();
let arr =
Arc::new(array_into_fixed_size_list_array(nested_array, list_size));

ScalarValue::List(arr)
ScalarValue::FixedSizeList(arr)
}
DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?,
DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?,
Expand Down Expand Up @@ -2971,6 +2974,19 @@ impl TryFrom<&DataType> for ScalarValue {
.to_owned()
.into(),
),
// `ScalarValue::FixedSizeList` contains a single element `FixedSizeList`.
DataType::FixedSizeList(field, _) => ScalarValue::FixedSizeList(
new_null_array(
&DataType::FixedSizeList(
Arc::new(Field::new("item", field.data_type().clone(), true)),
1,
),
1,
)
.as_fixed_size_list()
.to_owned()
.into(),
),
DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()),
DataType::Null => ScalarValue::Null,
_ => {
Expand Down
17 changes: 16 additions & 1 deletion datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use arrow::compute;
use arrow::compute::{partition, SortColumn, SortOptions};
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions};
use arrow_array::{
Array, FixedSizeListArray, LargeListArray, ListArray, RecordBatchOptions,
};
use arrow_schema::DataType;
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
Expand Down Expand Up @@ -368,6 +370,19 @@ pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray {
)
}

/// Wrap an array into a single-row `FixedSizeListArray` whose element
/// width is `list_size`.
///
/// The child field is named `"item"` and marked nullable, matching the
/// conventions of `array_into_list_array` / `array_into_large_list_array`.
///
/// # Panics
///
/// Panics if `list_size` does not fit in an `i32` (the width type Arrow
/// uses for fixed-size lists); the previous `as i32` cast would have
/// silently wrapped instead.
pub fn array_into_fixed_size_list_array(
    arr: ArrayRef,
    list_size: usize,
) -> FixedSizeListArray {
    let list_size = i32::try_from(list_size)
        .expect("FixedSizeListArray size must fit in i32");
    FixedSizeListArray::new(
        Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
        list_size,
        arr,
        None,
    )
}

/// Wrap arrays into a single element `ListArray`.
///
/// Example:
Expand Down
17 changes: 17 additions & 0 deletions datafusion/sql/src/expr/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ impl<'a> Parser<'a> {
Token::Dictionary => self.parse_dictionary(),
Token::List => self.parse_list(),
Token::LargeList => self.parse_large_list(),
Token::FixedSizeList => self.parse_fixed_size_list(),
tok => Err(make_error(
self.val,
&format!("finding next type, got unexpected '{tok}'"),
Expand Down Expand Up @@ -177,6 +178,19 @@ impl<'a> Parser<'a> {
))))
}

/// Parses a `FixedSizeList(<size>, <type>)` specification, e.g.
/// `FixedSizeList(3, Int64)`, into the corresponding `DataType`.
fn parse_fixed_size_list(&mut self) -> Result<DataType> {
    // Grammar: '(' i32 ',' type ')'
    self.expect_token(Token::LParen)?;
    let list_size = self.parse_i32("FixedSizeList")?;
    self.expect_token(Token::Comma)?;
    let value_type = self.parse_next_type()?;
    self.expect_token(Token::RParen)?;
    let item_field = Field::new("item", value_type, true);
    Ok(DataType::FixedSizeList(Arc::new(item_field), list_size))
}

/// Parses the next timeunit
fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> {
match self.next_token()? {
Expand Down Expand Up @@ -508,6 +522,7 @@ impl<'a> Tokenizer<'a> {

"List" => Token::List,
"LargeList" => Token::LargeList,
"FixedSizeList" => Token::FixedSizeList,

"Second" => Token::TimeUnit(TimeUnit::Second),
"Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
Expand Down Expand Up @@ -598,6 +613,7 @@ enum Token {
DoubleQuotedString(String),
List,
LargeList,
FixedSizeList,
}

impl Display for Token {
Expand All @@ -606,6 +622,7 @@ impl Display for Token {
Token::SimpleType(t) => write!(f, "{t}"),
Token::List => write!(f, "List"),
Token::LargeList => write!(f, "LargeList"),
Token::FixedSizeList => write!(f, "FixedSizeList"),
Token::Timestamp => write!(f, "Timestamp"),
Token::Time32 => write!(f, "Time32"),
Token::Time64 => write!(f, "Time64"),
Expand Down
33 changes: 32 additions & 1 deletion datafusion/sqllogictest/test_files/arrow_typeof.slt
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,35 @@ LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, di
query T
select arrow_typeof(arrow_cast(make_array([1, 2, 3]), 'LargeList(LargeList(Int64))'));
----
LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })

## FixedSizeList

query ?
select arrow_cast(null, 'FixedSizeList(1, Int64)');
----
NULL

#TODO: arrow-rs doesn't support it yet
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be casting [1] (not 1)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that `List` supports casting a scalar Utf8 value into a single-element list. Therefore, I think `FixedSizeList` should also support it.

select arrow_cast('1', 'LargeList(Int64)');
----
[1]

#query ?
#select arrow_cast('1', 'FixedSizeList(1, Int64)');
#----
#[1]

query error DataFusion error: Optimizer rule 'simplify_expressions' failed
select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(4, Int64)');

query ?
select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)');
----
[1, 2, 3]

query T
select arrow_typeof(arrow_cast(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 'FixedSizeList(3, Int64)'));
----
FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3)

query ?
select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)');
----
[1, 2, 3]