Implement GetIndexedField for map-typed columns
swgillespie committed Oct 14, 2023
1 parent a86ee16 commit 545548d
Showing 7 changed files with 76 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -53,6 +53,7 @@ arrow-array = { version = "47.0.0", default-features = false, features = ["chron
arrow-buffer = { version = "47.0.0", default-features = false }
arrow-flight = { version = "47.0.0", features = ["flight-sql-experimental"] }
arrow-schema = { version = "47.0.0", default-features = false }
arrow-ord = { version = "47.0.0", default-features = false }
parquet = { version = "47.0.0", features = ["arrow", "async", "object_store"] }
sqlparser = { version = "0.38.0", features = ["visitor"] }
chrono = { version = "0.4.31", default-features = false }
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default.

Binary file added datafusion/core/tests/data/parquet_map.parquet
Binary file not shown.
15 changes: 14 additions & 1 deletion datafusion/expr/src/field_util.rs
@@ -45,6 +45,19 @@ impl GetFieldAccessSchema {
match self {
Self::NamedStructField{ name } => {
match (data_type, name) {
(DataType::Map(fields, _), _) => {
match fields.data_type() {
DataType::Struct(fields) if fields.len() == 2 => {
// Arrow's MapArray is essentially a ListArray of structs with two columns. They are
// often named "key" and "value", but we don't require any specific naming here;
// instead, we assume that the second column is the "value" column both here and in
// execution.
let value_field = fields.get(1).expect("fields should have exactly two members");
Ok(Field::new("map", value_field.data_type().clone(), true))
},
_ => plan_err!("Map fields must contain a Struct with exactly 2 fields"),
}
}
(DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
if s.is_empty() {
plan_err!(
@@ -58,7 +71,7 @@ impl GetFieldAccessSchema {
(DataType::Struct(_), _) => plan_err!(
"Only utf8 strings are valid as an indexed field in a struct"
),
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, or `Map` types, got {other}"),
}
}
Self::ListIndex{ key_dt } => {
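
The hunk above is the planning-time half of the feature: indexing a Map-typed column resolves to the map's value type, regardless of how the key/value struct fields are named. The following standalone sketch of that rule is not part of the commit; it targets arrow-schema 47 and uses a hypothetical helper name, map_value_field.

use std::sync::Arc;
use arrow_schema::{DataType, Field, Fields};

// Hypothetical helper: the schema of the field produced by `map_col['some_key']`.
fn map_value_field(map_type: &DataType) -> Result<Field, String> {
    match map_type {
        DataType::Map(entries, _keys_sorted) => match entries.data_type() {
            DataType::Struct(fields) if fields.len() == 2 => {
                // By convention the second struct column holds the values.
                Ok(Field::new("map", fields[1].data_type().clone(), true))
            }
            _ => Err("Map fields must contain a Struct with exactly 2 fields".to_string()),
        },
        other => Err(format!("not a Map type: {other}")),
    }
}

fn main() {
    // Map<Utf8, Int64>, as a Parquet MAP column surfaces in Arrow.
    let entries = Field::new(
        "entries",
        DataType::Struct(Fields::from(vec![
            Field::new("key", DataType::Utf8, false),
            Field::new("value", DataType::Int64, true),
        ])),
        false,
    );
    let map_type = DataType::Map(Arc::new(entries), false);
    assert_eq!(map_value_field(&map_type).unwrap().data_type(), &DataType::Int64);
}
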
1 change: 1 addition & 0 deletions datafusion/physical-expr/Cargo.toml
@@ -44,6 +44,7 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"]
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-ord = { workspace = true }
arrow-schema = { workspace = true }
base64 = { version = "0.21", optional = true }
blake2 = { version = "^0.10.2", optional = true }
15 changes: 13 additions & 2 deletions datafusion/physical-expr/src/expressions/get_indexed_field.rs
Expand Up @@ -18,16 +18,19 @@
//! get field of a `ListArray`

use crate::PhysicalExpr;
use arrow::array::Array;
use datafusion_common::exec_err;

use crate::array_expressions::{array_element, array_slice};
use crate::physical_expr::down_cast_any_ref;
use arrow::{
array::{Array, Scalar, StringArray},
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
use datafusion_common::{cast::as_struct_array, DataFusionError, Result, ScalarValue};
use datafusion_common::{
cast::{as_map_array, as_struct_array},
DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue};
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
@@ -183,6 +186,14 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let array = self.arg.evaluate(batch)?.into_array(batch.num_rows());
match &self.field {
GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) {
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
let map_array = as_map_array(array.as_ref())?;
let key_scalar = Scalar::new(StringArray::from(vec![k.clone()]));
let keys = arrow_ord::cmp::eq(&key_scalar, map_array.keys())?;
let entries = arrow::compute::filter(map_array.entries(), &keys)?;
let entries_struct_array = as_struct_array(entries.as_ref())?;
Ok(ColumnarValue::Array(entries_struct_array.column(1).clone()))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = as_struct_array(&array)?;
match as_struct_array.column_by_name(k) {
Expand Down
46 changes: 46 additions & 0 deletions datafusion/sqllogictest/test_files/map.slt
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

statement ok
CREATE EXTERNAL TABLE data
STORED AS PARQUET
LOCATION '../core/tests/data/parquet_map.parquet';

query I
SELECT SUM(ints['bytes']) FROM data;
----
5636785

query I
SELECT SUM(ints['bytes']) FROM data WHERE strings['method'] == 'GET';
----
649668

query TI
SELECT strings['method'] AS method, COUNT(*) as count FROM data GROUP BY method ORDER BY count DESC;
----
POST 41
HEAD 33
PATCH 30
OPTION 29
GET 27
PUT 25
DELETE 24

query T
SELECT strings['not_found'] FROM data LIMIT 1;
----

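For completeness, here is a hedged sketch (not part of the commit) of running the same kind of query through the DataFusion API, assuming the datafusion and tokio crates; the table name and parquet path mirror the test above and may need adjusting for your working directory.

use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_parquet(
        "data",
        "datafusion/core/tests/data/parquet_map.parquet",
        ParquetReadOptions::default(),
    )
    .await?;

    // Bracket indexing on the map-typed columns `strings` and `ints`.
    let df = ctx
        .sql("SELECT strings['method'] AS method, ints['bytes'] AS bytes FROM data LIMIT 5")
        .await?;
    df.show().await?;
    Ok(())
}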