diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 43cb0c3e0a50..e05cdf00dcda 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -319,4 +319,19 @@ mod tests { .unwrap(); assert_eq!(return_type, DataType::Date32); } + + #[test] + fn test_coalesce_return_types_dictionary() { + let coalesce = BuiltinScalarFunction::Coalesce; + let return_type = coalesce + .return_type(&[ + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Utf8, + ]) + .unwrap(); + assert_eq!( + return_type, + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + ); + } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 34b607d0884d..37eeb7d464b8 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -314,8 +314,13 @@ fn coerced_from<'a>( // match Dictionary first match (type_into, type_from) { // coerced dictionary first - (cur_type, Dictionary(_, value_type)) | (Dictionary(_, value_type), cur_type) - if coerced_from(cur_type, value_type).is_some() => + (_, Dictionary(_, value_type)) + if coerced_from(type_into, value_type).is_some() => + { + Some(type_into.clone()) + } + (Dictionary(_, value_type), _) + if coerced_from(value_type, type_from).is_some() => { Some(type_into.clone()) } @@ -624,4 +629,20 @@ mod tests { Ok(()) } + + #[test] + fn test_coerced_from_dictionary() { + let type_into = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32)); + let type_from = DataType::Int64; + assert_eq!(coerced_from(&type_into, &type_from), None); + + let type_from = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32)); + let type_into = DataType::Int64; + assert_eq!( + coerced_from(&type_into, &type_from), + Some(type_into.clone()) + ); + } }