Skip to content

feat: array_contains #6618

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 7 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,39 @@ query ?
select make_array(x, y) from foo2;
----
[1.0, 1]

# array_contains scalar function #1
query BBB rowsort
select array_contains(make_array(1, 2, 3), make_array(1, 1, 2, 3)), array_contains([1, 2, 3], [1, 1, 2]), array_contains([1, 2, 3], [2, 1, 3, 1]);
----
true true true

# array_contains scalar function #2
query BB rowsort
select array_contains([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 3]]), array_contains([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 4]);
----
true true

# array_contains scalar function #3
query BBB rowsort
select array_contains(make_array(1, 2, 3), make_array(1, 2, 3, 4)), array_contains([1, 2, 3], [1, 1, 4]), array_contains([1, 2, 3], [2, 1, 3, 4]);
----
false false false

# array_contains scalar function #4
query BB rowsort
select array_contains([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 5]]), array_contains([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 5]);
----
false false

# array_contains scalar function #5
query BB rowsort
select array_contains([true, true, false, true, false], [true, false, false]), array_contains([true, false, true], [true, true]);
----
true true

# array_contains scalar function #6
query BB rowsort
select array_contains(make_array(true, true, true), make_array(false, false)), array_contains([false, false, false], [true, true]);
----
false false
6 changes: 6 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ pub enum BuiltinScalarFunction {
ArrayAppend,
/// array_concat
ArrayConcat,
/// array_contains
ArrayContains,
/// array_dims
ArrayDims,
/// array_fill
Expand Down Expand Up @@ -319,6 +321,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArrayAppend => Volatility::Immutable,
BuiltinScalarFunction::ArrayConcat => Volatility::Immutable,
BuiltinScalarFunction::ArrayContains => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
BuiltinScalarFunction::ArrayFill => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
Expand Down Expand Up @@ -460,6 +463,7 @@ impl BuiltinScalarFunction {
"The {self} function can only accept fixed size list as the args."
))),
},
BuiltinScalarFunction::ArrayContains => Ok(Boolean),
BuiltinScalarFunction::ArrayDims => Ok(UInt8),
BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new(
"item",
Expand Down Expand Up @@ -741,6 +745,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayContains => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayLength => {
Expand Down Expand Up @@ -1166,6 +1171,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
// array functions
BuiltinScalarFunction::ArrayAppend => &["array_append"],
BuiltinScalarFunction::ArrayConcat => &["array_concat"],
BuiltinScalarFunction::ArrayContains => &["array_contains"],
BuiltinScalarFunction::ArrayDims => &["array_dims"],
BuiltinScalarFunction::ArrayFill => &["array_fill"],
BuiltinScalarFunction::ArrayLength => &["array_length"],
Expand Down
7 changes: 7 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,13 @@ scalar_expr!(
"appends an element to the end of an array."
);
nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays.");
scalar_expr!(
ArrayContains,
array_contains,
first_array second_array,
"returns true, if each element of the second array appe
aring in the first array, otherwise false."
);
scalar_expr!(
ArrayDims,
array_dims,
Expand Down
126 changes: 124 additions & 2 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use datafusion_common::cast::as_list_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use itertools::Itertools;
use std::sync::Arc;

macro_rules! downcast_vec {
Expand Down Expand Up @@ -1070,6 +1071,70 @@ pub fn array_ndims(args: &[ColumnarValue]) -> Result<ColumnarValue> {
]))))
}

macro_rules! contains {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand this comment -- are you saying that it would be better to use the in_list kernel rather than flattening it?

($FIRST_ARRAY:expr, $SECOND_ARRAY:expr, $ARRAY_TYPE:ident) => {{
let first_array = downcast_arg!($FIRST_ARRAY, $ARRAY_TYPE);
let second_array = downcast_arg!($SECOND_ARRAY, $ARRAY_TYPE);
let mut res = true;
for x in second_array.values().iter().dedup() {
if !first_array.values().contains(x) {
res = false;
}
}

res
}};
}

/// Array_contains SQL function
pub fn array_contains(args: &[ArrayRef]) -> Result<ArrayRef> {
fn concat_inner_lists(arg: ArrayRef) -> Result<ArrayRef> {
match arg.data_type() {
DataType::List(field) => match field.data_type() {
DataType::List(..) => {
concat_inner_lists(array_concat(&[as_list_array(&arg)?
.values()
.clone()])?)
}
_ => Ok(as_list_array(&arg)?.values().clone()),
},
data_type => Err(DataFusionError::NotImplemented(format!(
"Array is not type '{data_type:?}'."
))),
}
}

let concat_first_array = concat_inner_lists(args[0].clone())?.clone();
let concat_second_array = concat_inner_lists(args[1].clone())?.clone();

let res = match (concat_first_array.data_type(), concat_second_array.data_type()) {
(DataType::Utf8, DataType::Utf8) => contains!(concat_first_array, concat_second_array, StringArray),
(DataType::LargeUtf8, DataType::LargeUtf8) => contains!(concat_first_array, concat_second_array, LargeStringArray),
(DataType::Boolean, DataType::Boolean) => {
let first_array = downcast_arg!(concat_first_array, BooleanArray);
let second_array = downcast_arg!(concat_second_array, BooleanArray);
compute::bool_or(first_array) == compute::bool_or(second_array)
}
(DataType::Float32, DataType::Float32) => contains!(concat_first_array, concat_second_array, Float32Array),
(DataType::Float64, DataType::Float64) => contains!(concat_first_array, concat_second_array, Float64Array),
(DataType::Int8, DataType::Int8) => contains!(concat_first_array, concat_second_array, Int8Array),
(DataType::Int16, DataType::Int16) => contains!(concat_first_array, concat_second_array, Int16Array),
(DataType::Int32, DataType::Int32) => contains!(concat_first_array, concat_second_array, Int32Array),
(DataType::Int64, DataType::Int64) => contains!(concat_first_array, concat_second_array, Int64Array),
(DataType::UInt8, DataType::UInt8) => contains!(concat_first_array, concat_second_array, UInt8Array),
(DataType::UInt16, DataType::UInt16) => contains!(concat_first_array, concat_second_array, UInt16Array),
(DataType::UInt32, DataType::UInt32) => contains!(concat_first_array, concat_second_array, UInt32Array),
(DataType::UInt64, DataType::UInt64) => contains!(concat_first_array, concat_second_array, UInt64Array),
(first_array_data_type, second_array_data_type) => {
return Err(DataFusionError::NotImplemented(format!(
"Array_contains is not implemented for types '{first_array_data_type:?}' and '{second_array_data_type:?}'."
)))
}
};

Ok(Arc::new(BooleanArray::from(vec![res])))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1588,7 +1653,7 @@ mod tests {

#[test]
fn test_array_ndims() {
// array_ndims([1, 2]) = 1
// array_ndims([1, 2, 3, 4]) = 1
let list_array = return_array();

let array = array_ndims(&[list_array])
Expand All @@ -1602,7 +1667,7 @@ mod tests {

#[test]
fn test_nested_array_ndims() {
// array_ndims([[1, 2], [3, 4]]) = 2
// array_ndims([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2
let list_array = return_nested_array();

let array = array_ndims(&[list_array])
Expand All @@ -1614,6 +1679,63 @@ mod tests {
assert_eq!(result, &UInt8Array::from(vec![2]));
}

#[test]
fn test_array_contains() {
// array_contains([1, 2, 3, 4], array_append([1, 2, 3, 4], 3)) = t
let first_array = return_array().into_array(1);
let second_array = array_append(&[
first_array.clone(),
Arc::new(Int64Array::from(vec![Some(3)])),
])
.expect("failed to initialize function array_contains");

let arr = array_contains(&[first_array.clone(), second_array])
.expect("failed to initialize function array_contains");
let result = as_boolean_array(&arr);

assert_eq!(result, &BooleanArray::from(vec![true]));

// array_contains([1, 2, 3, 4], array_append([1, 2, 3, 4], 5)) = f
let second_array = array_append(&[
first_array.clone(),
Arc::new(Int64Array::from(vec![Some(5)])),
])
.expect("failed to initialize function array_contains");

let arr = array_contains(&[first_array.clone(), second_array])
.expect("failed to initialize function array_contains");
let result = as_boolean_array(&arr);

assert_eq!(result, &BooleanArray::from(vec![false]));
}

#[test]
fn test_nested_array_contains() {
// array_contains([[1, 2, 3, 4], [5, 6, 7, 8]], array_append([1, 2, 3, 4], 3)) = t
let first_array = return_nested_array().into_array(1);
let array = return_array().into_array(1);
let second_array =
array_append(&[array.clone(), Arc::new(Int64Array::from(vec![Some(3)]))])
.expect("failed to initialize function array_contains");

let arr = array_contains(&[first_array.clone(), second_array])
.expect("failed to initialize function array_contains");
let result = as_boolean_array(&arr);

assert_eq!(result, &BooleanArray::from(vec![true]));

// array_contains([[1, 2, 3, 4], [5, 6, 7, 8]], array_append([1, 2, 3, 4], 9)) = f
let second_array =
array_append(&[array.clone(), Arc::new(Int64Array::from(vec![Some(9)]))])
.expect("failed to initialize function array_contains");

let arr = array_contains(&[first_array.clone(), second_array])
.expect("failed to initialize function array_contains");
let result = as_boolean_array(&arr);

assert_eq!(result, &BooleanArray::from(vec![false]));
}

fn return_array() -> ColumnarValue {
let args = [
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayConcat => {
Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args))
}
BuiltinScalarFunction::ArrayContains => {
Arc::new(|args| make_scalar_function(array_expressions::array_contains)(args))
}
BuiltinScalarFunction::ArrayDims => Arc::new(array_expressions::array_dims),
BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill),
BuiltinScalarFunction::ArrayLength => Arc::new(array_expressions::array_length),
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ enum ScalarFunction {
ArrayToString = 97;
Cardinality = 98;
TrimArray = 99;
ArrayContains = 100;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 11 additions & 6 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ use datafusion_common::{
};
use datafusion_expr::expr::Placeholder;
use datafusion_expr::{
abs, acos, acosh, array, array_append, array_concat, array_dims, array_fill,
array_length, array_ndims, array_position, array_positions, array_prepend,
array_remove, array_replace, array_to_string, ascii, asin, asinh, atan, atan2, atanh,
bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce,
concat_expr, concat_ws_expr, cos, cosh, date_bin, date_part, date_trunc, degrees,
digest, exp,
abs, acos, acosh, array, array_append, array_concat, array_contains, array_dims,
array_fill, array_length, array_ndims, array_position, array_positions,
array_prepend, array_remove, array_replace, array_to_string, ascii, asin, asinh,
atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length,
chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, date_bin, date_part,
date_trunc, degrees, digest, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
Expand Down Expand Up @@ -450,6 +450,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ToTimestamp => Self::ToTimestamp,
ScalarFunction::ArrayAppend => Self::ArrayAppend,
ScalarFunction::ArrayConcat => Self::ArrayConcat,
ScalarFunction::ArrayContains => Self::ArrayContains,
ScalarFunction::ArrayDims => Self::ArrayDims,
ScalarFunction::ArrayFill => Self::ArrayFill,
ScalarFunction::ArrayLength => Self::ArrayLength,
Expand Down Expand Up @@ -1192,6 +1193,10 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::ArrayContains => Ok(array_contains(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ArrayFill => Ok(array_fill(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp,
BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend,
BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat,
BuiltinScalarFunction::ArrayContains => Self::ArrayContains,
BuiltinScalarFunction::ArrayDims => Self::ArrayDims,
BuiltinScalarFunction::ArrayFill => Self::ArrayFill,
BuiltinScalarFunction::ArrayLength => Self::ArrayLength,
Expand Down
Loading