Skip to content

Introduce array_except function #8135

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 5 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ pub enum BuiltinScalarFunction {
ArrayIntersect,
/// array_union
ArrayUnion,
/// array_except
ArrayExcept,
/// cardinality
Cardinality,
/// construct an array from columns
Expand Down Expand Up @@ -394,6 +396,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayHas => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
BuiltinScalarFunction::ArrayNdims => Volatility::Immutable,
BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable,
Expand Down Expand Up @@ -601,6 +604,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Range => {
Ok(List(Arc::new(Field::new("item", Int64, true))))
}
BuiltinScalarFunction::ArrayExcept => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::Cardinality => Ok(UInt64),
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
Expand Down Expand Up @@ -887,6 +891,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayHasAll
| BuiltinScalarFunction::ArrayHasAny
Expand Down Expand Up @@ -1521,6 +1526,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
"list_element",
"list_extract",
],
BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"],
BuiltinScalarFunction::Flatten => &["flatten"],
BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"],
BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"],
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,12 @@ scalar_expr!(
array element,
"extracts the element with the index n from the array."
);
scalar_expr!(
ArrayExcept,
array_except,
first_array second_array,
"Returns an array of the elements that appear in the first array but not in the second."
);
scalar_expr!(
ArrayLength,
array_length,
Expand Down
80 changes: 79 additions & 1 deletion datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! Array expressions

use std::any::type_name;
use std::collections::HashSet;
use std::sync::Arc;

use arrow::array::*;
Expand All @@ -38,7 +39,6 @@ use datafusion_common::{
};

use itertools::Itertools;
use std::collections::HashSet;

macro_rules! downcast_arg {
($ARG:expr, $ARRAY_TYPE:ident) => {{
Expand Down Expand Up @@ -629,6 +629,84 @@ pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
define_array_slice(list_array, key, key, true)
}

fn general_except<OffsetSize: OffsetSizeTrait>(
l: &GenericListArray<OffsetSize>,
r: &GenericListArray<OffsetSize>,
field: &FieldRef,
) -> Result<GenericListArray<OffsetSize>> {
let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;

let l_values = l.values().to_owned();
let r_values = r.values().to_owned();
let l_values = converter.convert_columns(&[l_values])?;
let r_values = converter.convert_columns(&[r_values])?;

let mut offsets = Vec::<OffsetSize>::with_capacity(l.len() + 1);
offsets.push(OffsetSize::usize_as(0));

let mut rows = Vec::with_capacity(l_values.num_rows());
let mut dedup = HashSet::new();

for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) {
let l_slice = l_w[0].as_usize()..l_w[1].as_usize();
let r_slice = r_w[0].as_usize()..r_w[1].as_usize();
for i in r_slice {
let right_row = r_values.row(i);
dedup.insert(right_row);
}
for i in l_slice {
let left_row = l_values.row(i);
if dedup.insert(left_row) {
rows.push(left_row);
}
}

offsets.push(OffsetSize::usize_as(rows.len()));
dedup.clear();
}

if let Some(values) = converter.convert_rows(rows)?.get(0) {
Ok(GenericListArray::<OffsetSize>::new(
field.to_owned(),
OffsetBuffer::new(offsets.into()),
values.to_owned(),
l.nulls().cloned(),
))
} else {
internal_err!("array_except failed to convert rows")
}
}

pub fn array_except(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return internal_err!("array_except needs two arguments");
}

let array1 = &args[0];
let array2 = &args[1];

match (array1.data_type(), array2.data_type()) {
(DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()),
(DataType::List(field), DataType::List(_)) => {
check_datatypes("array_except", &[&array1, &array2])?;
let list1 = array1.as_list::<i32>();
let list2 = array2.as_list::<i32>();
let result = general_except::<i32>(list1, list2, field)?;
Ok(Arc::new(result))
}
(DataType::LargeList(field), DataType::LargeList(_)) => {
check_datatypes("array_except", &[&array1, &array2])?;
let list1 = array1.as_list::<i64>();
let list2 = array2.as_list::<i64>();
let result = general_except::<i64>(list1, list2, field)?;
Ok(Arc::new(result))
}
(dt1, dt2) => {
internal_err!("array_except got unexpected types: {dt1:?} and {dt2:?}")
}
}
}

pub fn array_slice(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
let key = as_int64_array(&args[1])?;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayElement => {
Arc::new(|args| make_scalar_function(array_expressions::array_element)(args))
}
BuiltinScalarFunction::ArrayExcept => {
Arc::new(|args| make_scalar_function(array_expressions::array_except)(args))
}
BuiltinScalarFunction::ArrayLength => {
Arc::new(|args| make_scalar_function(array_expressions::array_length)(args))
}
Expand Down
5 changes: 3 additions & 2 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,9 @@ enum ScalarFunction {
ArrayUnion = 120;
OverLay = 121;
Range = 122;
ArrayPopFront = 123;
Levenshtein = 124;
ArrayExcept = 123;
ArrayPopFront = 124;
Levenshtein = 125;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 16 additions & 11 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ use datafusion_common::{
};
use datafusion_expr::{
abs, acos, acosh, array, array_append, array_concat, array_dims, array_element,
array_has, array_has_all, array_has_any, array_intersect, array_length, array_ndims,
array_position, array_positions, array_prepend, array_remove, array_remove_all,
array_remove_n, array_repeat, array_replace, array_replace_all, array_replace_n,
array_slice, array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh,
bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce,
concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin,
date_part, date_trunc, decode, degrees, digest, encode, exp,
array_except, array_has, array_has_all, array_has_any, array_intersect, array_length,
array_ndims, array_position, array_positions, array_prepend, array_remove,
array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all,
array_replace_n, array_slice, array_to_string, arrow_typeof, ascii, asin, asinh,
atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length,
chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date,
current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left,
levenshtein, ln, log, log10, log2,
Expand Down Expand Up @@ -465,6 +465,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrayAppend => Self::ArrayAppend,
ScalarFunction::ArrayConcat => Self::ArrayConcat,
ScalarFunction::ArrayEmpty => Self::ArrayEmpty,
ScalarFunction::ArrayExcept => Self::ArrayExcept,
ScalarFunction::ArrayHasAll => Self::ArrayHasAll,
ScalarFunction::ArrayHasAny => Self::ArrayHasAny,
ScalarFunction::ArrayHas => Self::ArrayHas,
Expand Down Expand Up @@ -1352,6 +1353,10 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::ArrayExcept => Ok(array_except(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ArrayHasAll => Ok(array_has_all(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
Expand All @@ -1364,6 +1369,10 @@ pub fn parse_expr(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ArrayIntersect => Ok(array_intersect(
Copy link
Contributor Author

@jayzhan211 jayzhan211 Nov 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing from #8081. Not sure why CI from #8081 does not catch this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to enhance the check during CI? maybe this is a little bug?

parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ArrayPosition => Ok(array_position(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
Expand Down Expand Up @@ -1415,10 +1424,6 @@ pub fn parse_expr(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ArrayIntersect => Ok(array_intersect(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::Range => Ok(gen_range(
args.to_owned()
.iter()
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend,
BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat,
BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty,
BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept,
BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll,
BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny,
BuiltinScalarFunction::ArrayHas => Self::ArrayHas,
Expand Down
Loading