diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 58dc8f40b577..de522aba4e4f 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -29,8 +29,10 @@ use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; +use arrow_array::cast::AsArray; use arrow_array::{ - Array, FixedSizeListArray, LargeListArray, ListArray, RecordBatchOptions, + Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, + RecordBatchOptions, }; use arrow_schema::DataType; use sqlparser::ast::Ident; @@ -440,6 +442,11 @@ pub fn arrays_into_list_array( )) } +/// Helper function to convert a ListArray into a vector of ArrayRefs. +pub fn list_to_arrays(a: ArrayRef) -> Vec { + a.as_list::().iter().flatten().collect::>() +} + /// Get the base type of a data type. /// /// Example diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index e218b501dcf1..b6068fdff0d5 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -15,17 +15,20 @@ // specific language governing permissions and limitations // under the License. -use crate::make_array::make_array; +use std::any::Any; +use std::collections::VecDeque; +use std::sync::Arc; + use arrow::array::ArrayData; -use arrow_array::{Array, ArrayRef, MapArray, StructArray}; +use arrow_array::{Array, ArrayRef, MapArray, OffsetSizeTrait, StructArray}; use arrow_buffer::{Buffer, ToByteSlice}; use arrow_schema::{DataType, Field, SchemaBuilder}; + use datafusion_common::{exec_err, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::collections::VecDeque; -use std::sync::Arc; + +use crate::make_array::make_array; /// Returns a map created from a key list and a value list pub fn map(keys: Vec, values: Vec) -> Expr { @@ -56,11 +59,11 @@ fn make_map_batch(args: &[ColumnarValue]) -> datafusion_common::Result Ok(array.value(0)), _ => exec_err!("Expected array, got {:?}", value), }, - ColumnarValue::Array(array) => exec_err!("Expected scalar, got {:?}", array), + ColumnarValue::Array(array) => Ok(array.to_owned()), } } @@ -81,6 +84,7 @@ fn make_map_batch_internal( keys: ArrayRef, values: ArrayRef, can_evaluate_to_const: bool, + data_type: DataType, ) -> datafusion_common::Result { if keys.null_count() > 0 { return exec_err!("map key cannot be null"); @@ -90,6 +94,14 @@ fn make_map_batch_internal( return exec_err!("map requires key and value lists to have the same length"); } + if !can_evaluate_to_const { + return if let DataType::LargeList(..) = data_type { + make_map_array_internal::(keys, values) + } else { + make_map_array_internal::(keys, values) + }; + } + let key_field = Arc::new(Field::new("key", keys.data_type().clone(), false)); let value_field = Arc::new(Field::new("value", values.data_type().clone(), true)); let mut entry_struct_buffer: VecDeque<(Arc, ArrayRef)> = VecDeque::new(); @@ -190,7 +202,6 @@ impl ScalarUDFImpl for MapFunc { make_map_batch(args) } } - fn get_element_type(data_type: &DataType) -> datafusion_common::Result<&DataType> { match data_type { DataType::List(element) => Ok(element.data_type()), @@ -202,3 +213,115 @@ fn get_element_type(data_type: &DataType) -> datafusion_common::Result<&DataType ), } } + +/// Helper function to create MapArray from array of values to support arrays for Map scalar function +/// +/// ``` text +/// Format of input KEYS and VALUES column +/// keys values +/// +---------------------+ +---------------------+ +/// | +-----------------+ | | +-----------------+ | +/// | | [k11, k12, k13] | | | | [v11, v12, v13] | | +/// | +-----------------+ | | +-----------------+ | +/// | | | | +/// | +-----------------+ | | +-----------------+ | +/// | | [k21, k22, k23] | | | | [v21, v22, v23] | | +/// | +-----------------+ | | +-----------------+ | +/// | | | | +/// | +-----------------+ | | +-----------------+ | +/// | |[k31, k32, k33] | | | |[v31, v32, v33] | | +/// | +-----------------+ | | +-----------------+ | +/// +---------------------+ +---------------------+ +/// ``` +/// Flattened keys and values array to user create `StructArray`, +/// which serves as inner child for `MapArray` +/// +/// ``` text +/// Flattened Flattened +/// Keys Values +/// +-----------+ +-----------+ +/// | +-------+ | | +-------+ | +/// | | k11 | | | | v11 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k12 | | | | v12 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k13 | | | | v13 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k21 | | | | v21 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k22 | | | | v22 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k23 | | | | v23 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k31 | | | | v31 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k32 | | | | v32 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k33 | | | | v33 | | +/// | +-------+ | | +-------+ | +/// +-----------+ +-----------+ +/// ```text + +fn make_map_array_internal( + keys: ArrayRef, + values: ArrayRef, +) -> datafusion_common::Result { + let mut offset_buffer = vec![O::zero()]; + let mut running_offset = O::zero(); + + let keys = datafusion_common::utils::list_to_arrays::(keys); + let values = datafusion_common::utils::list_to_arrays::(values); + + let mut key_array_vec = vec![]; + let mut value_array_vec = vec![]; + for (k, v) in keys.iter().zip(values.iter()) { + running_offset = running_offset.add(O::usize_as(k.len())); + offset_buffer.push(running_offset); + key_array_vec.push(k.as_ref()); + value_array_vec.push(v.as_ref()); + } + + // concatenate all the arrays + let flattened_keys = arrow::compute::concat(key_array_vec.as_ref())?; + if flattened_keys.null_count() > 0 { + return exec_err!("keys cannot be null"); + } + let flattened_values = arrow::compute::concat(value_array_vec.as_ref())?; + + let fields = vec![ + Arc::new(Field::new("key", flattened_keys.data_type().clone(), false)), + Arc::new(Field::new( + "value", + flattened_values.data_type().clone(), + true, + )), + ]; + + let struct_data = ArrayData::builder(DataType::Struct(fields.into())) + .len(flattened_keys.len()) + .add_child_data(flattened_keys.to_data()) + .add_child_data(flattened_values.to_data()) + .build()?; + + let map_data = ArrayData::builder(DataType::Map( + Arc::new(Field::new( + "entries", + struct_data.data_type().clone(), + false, + )), + false, + )) + .len(keys.len()) + .add_child_data(struct_data) + .add_buffer(Buffer::from_slice_ref(offset_buffer.as_slice())) + .build()?; + Ok(ColumnarValue::Array(Arc::new(MapArray::from(map_data)))) +} diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index eb350c22bb5d..0dc37c68bca4 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -199,25 +199,50 @@ SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), a statement ok create table t as values -('a', 1, 'k1', 10, ['k1', 'k2'], [1, 2]), -('b', 2, 'k3', 30, ['k3'], [3]), -('d', 4, 'k5', 50, ['k5'], [5]); +('a', 1, 'k1', 10, ['k1', 'k2'], [1, 2], 'POST', [[1,2,3]], ['a']), +('b', 2, 'k3', 30, ['k3'], [3], 'PUT', [[4]], ['b']), +('d', 4, 'k5', 50, ['k5'], [5], null, [[1,2]], ['c']); -query error +query ? SELECT make_map(column1, column2, column3, column4) FROM t; -# TODO: support array value -# ---- -# {a: 1, k1: 10} -# {b: 2, k3: 30} -# {d: 4, k5: 50} +---- +{a: 1, k1: 10} +{b: 2, k3: 30} +{d: 4, k5: 50} -query error +query ? SELECT map(column5, column6) FROM t; -# TODO: support array value -# ---- -# {k1:1, k2:2} -# {k3: 3} -# {k5: 5} +---- +{k1: 1, k2: 2} +{k3: 3} +{k5: 5} + +query ? +SELECT map(column8, column9) FROM t; +---- +{[1, 2, 3]: a} +{[4]: b} +{[1, 2]: c} + +query error +SELECT map(column6, column7) FROM t; + +query ? +select Map {column6: column7} from t; +---- +{[1, 2]: POST} +{[3]: PUT} +{[5]: } + +query ? +select Map {column8: column7} from t; +---- +{[[1, 2, 3]]: POST} +{[[4]]: PUT} +{[[1, 2]]: } + +query error +select Map {column7: column8} from t; query ? SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', 30, 'OPTION', 29, 'GET', 27, 'PUT', 25, 'DELETE', 24) AS method_count from t;