Skip to content

Commit 31dbcb1

Browse files
committed
fix columns and add sqllogictest
Signed-off-by: veeupup <code@tanweime.com>
1 parent 36b1cc9 commit 31dbcb1

File tree

2 files changed

+167
-68
lines changed

2 files changed

+167
-68
lines changed

datafusion/core/tests/sqllogictests/test_files/array.slt

Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,55 @@ AS VALUES
140140
(make_array([[1], [2]], [[2], [3]]), make_array([1], [2]))
141141
;
142142

143+
statement ok
144+
CREATE TABLE array_intersect_table_1D
145+
AS VALUES
146+
(make_array(1, 2), make_array(1), make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3)),
147+
(make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33))
148+
;
149+
150+
statement ok
151+
CREATE TABLE array_intersect_table_1D_Float
152+
AS VALUES
153+
(make_array(1.0, 2.0), make_array(1.0), make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)),
154+
(make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33))
155+
;
156+
157+
statement ok
158+
CREATE TABLE array_intersect_table_1D_Boolean
159+
AS VALUES
160+
(make_array(true, true, true), make_array(false), make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)),
161+
(make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true))
162+
;
163+
164+
statement ok
165+
CREATE TABLE array_intersect_table_1D_UTF8
166+
AS VALUES
167+
(make_array('a', 'bc', 'def'), make_array('bc'), make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')),
168+
(make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow'))
169+
;
170+
171+
statement ok
172+
CREATE TABLE array_intersect_table_2D
173+
AS VALUES
174+
(make_array([1,2]), make_array([1,3]), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])),
175+
(make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10]))
176+
;
177+
178+
statement ok
179+
CREATE TABLE array_intersect_table_2D_float
180+
AS VALUES
181+
(make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])),
182+
(make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3]))
183+
;
184+
185+
statement ok
186+
CREATE TABLE array_intersect_table_3D
187+
AS VALUES
188+
(make_array([[1,2]]), make_array([[1]])),
189+
(make_array([[1,2]]), make_array([[1,2]]))
190+
;
191+
143192
statement ok
144193
CREATE TABLE arrays_values_without_nulls
145194
AS VALUES
@@ -1695,14 +1744,74 @@ select array_has_all(make_array(1,2,3), make_array(1,3)),
16951744
----
16961745
true false true false false false true true false false true false true
16971746

1698-
query ????
1747+
query ???
1748+
select array_intersect(column1, column2),
1749+
array_intersect(column3, column4),
1750+
array_intersect(column5, column6)
1751+
from array_intersect_table_1D;
1752+
----
1753+
[1] [1, 3] [1, 3]
1754+
[11] [11, 33] [11, 33]
1755+
1756+
query ???
1757+
select array_intersect(column1, column2),
1758+
array_intersect(column3, column4),
1759+
array_intersect(column5, column6)
1760+
from array_intersect_table_1D_Float;
1761+
----
1762+
[1.0] [1.0, 3.0] []
1763+
[] [2.0] [1.11]
1764+
1765+
query ???
1766+
select array_intersect(column1, column2),
1767+
array_intersect(column3, column4),
1768+
array_intersect(column5, column6)
1769+
from array_intersect_table_1D_Boolean;
1770+
----
1771+
[] [false, true] [false]
1772+
[false] [true] [true]
1773+
1774+
query ???
1775+
select array_intersect(column1, column2),
1776+
array_intersect(column3, column4),
1777+
array_intersect(column5, column6)
1778+
from array_intersect_table_1D_UTF8;
1779+
----
1780+
[bc] [arrow, rust] []
1781+
[] [arrow, datafusion, rust] [arrow, rust]
1782+
1783+
query ??
1784+
select array_intersect(column1, column2),
1785+
array_intersect(column3, column4)
1786+
from array_intersect_table_2D;
1787+
----
1788+
[] [[4, 5], [6, 7]]
1789+
[[3, 4]] [[5, 6, 7], [8, 9, 10]]
1790+
1791+
query ?
1792+
select array_intersect(column1, column2)
1793+
from array_intersect_table_2D_float;
1794+
----
1795+
[[1.1, 2.2], [3.3]]
1796+
[[1.1, 2.2], [3.3]]
1797+
1798+
query ?
1799+
select array_intersect(column1, column2)
1800+
from array_intersect_table_3D;
1801+
----
1802+
[]
1803+
[[[1, 2]]]
1804+
1805+
query ??????
16991806
SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)),
17001807
array_intersect(make_array(1,3,5), make_array(2,4,6)),
17011808
array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
1702-
array_intersect(make_array(true, false), make_array(true))
1809+
array_intersect(make_array(true, false), make_array(true)),
1810+
array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)),
1811+
array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4]))
17031812
;
17041813
----
1705-
[2, 3] [] [cc, aa] [true]
1814+
[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]]
17061815

17071816
query BBBB
17081817
select list_has_all(make_array(1,2,3), make_array(4,5,6)),
@@ -1843,6 +1952,27 @@ drop table array_has_table_2D_float;
18431952
statement ok
18441953
drop table array_has_table_3D;
18451954

1955+
statement ok
1956+
drop table array_intersect_table_1D;
1957+
1958+
statement ok
1959+
drop table array_intersect_table_1D_Float;
1960+
1961+
statement ok
1962+
drop table array_intersect_table_1D_Boolean;
1963+
1964+
statement ok
1965+
drop table array_intersect_table_1D_UTF8;
1966+
1967+
statement ok
1968+
drop table array_intersect_table_2D;
1969+
1970+
statement ok
1971+
drop table array_intersect_table_2D_float;
1972+
1973+
statement ok
1974+
drop table array_intersect_table_3D;
1975+
18461976
statement ok
18471977
drop table arrays_values_without_nulls;
18481978

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 34 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,83 +1829,52 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
18291829
let first_array = as_list_array(&args[0])?;
18301830
let second_array = as_list_array(&args[1])?;
18311831

1832-
let dt = match (first_array.value_type(), second_array.value_type()) {
1833-
// (DataType::List(_), DataType::List(_)) => concat_internal(args)?,
1834-
(DataType::Utf8, DataType::Utf8) => DataType::Utf8,
1835-
(DataType::LargeUtf8, DataType::LargeUtf8) => DataType::LargeUtf8,
1836-
(DataType::Boolean, DataType::Boolean) => DataType::Boolean,
1837-
(DataType::Float32, DataType::Float32) => DataType::Float32,
1838-
(DataType::Float64, DataType::Float64) => DataType::Float64,
1839-
(DataType::Int8, DataType::Int8) => DataType::Int8,
1840-
(DataType::Int16, DataType::Int16) => DataType::Int16,
1841-
(DataType::Int32, DataType::Int32) => DataType::Int32,
1842-
(DataType::Int64, DataType::Int64) => DataType::Int64,
1843-
(DataType::UInt8, DataType::UInt8) => DataType::UInt8,
1844-
(DataType::UInt16, DataType::UInt16) => DataType::UInt16,
1845-
(DataType::UInt32, DataType::UInt32) => DataType::UInt32,
1846-
(DataType::UInt64, DataType::UInt64) => DataType::UInt64,
1847-
// (DataType::Null, _) => return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1)),
1848-
(first_value_dt, second_value_dt) =>
1849-
return Err(DataFusionError::NotImplemented(format!(
1850-
"array_intersect is not implemented for '{first_value_dt:?}' and '{second_value_dt:?}'",
1851-
)))
1852-
};
1832+
if first_array.value_type() != second_array.value_type() {
1833+
return Err(DataFusionError::NotImplemented(format!(
1834+
"array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'",
1835+
)));
1836+
}
1837+
let dt = first_array.value_type().clone();
18531838

18541839
let mut offsets = vec![0];
1855-
18561840
let mut tmp_values = vec![];
18571841

18581842
let mut converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
18591843
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
1860-
match (first_arr, second_arr) {
1861-
(Some(first_arr), Some(second_arr)) => {
1862-
let l_values = converter.convert_columns(&[first_arr])?;
1863-
let r_values = converter.convert_columns(&[second_arr])?;
1864-
1865-
let mut values_set = HashSet::new();
1866-
1867-
for (l_w, r_w) in first_array
1868-
.offsets()
1869-
.windows(2)
1870-
.zip(second_array.offsets().windows(2))
1871-
{
1872-
let l_slice = l_w[0]..l_w[1];
1873-
let r_slice = r_w[0]..r_w[1];
1874-
1875-
l_slice.for_each(|i| {
1876-
values_set.insert(l_values.row(i as usize));
1877-
});
1878-
1879-
let mut rows = vec![];
1880-
for i in r_slice {
1881-
let idx = i as usize;
1882-
if values_set.contains(&r_values.row(idx)) {
1883-
rows.push(r_values.row(idx));
1884-
}
1885-
}
1844+
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
1845+
let l_values = converter.convert_columns(&[first_arr])?;
1846+
let r_values = converter.convert_columns(&[second_arr])?;
18861847

1887-
offsets.push(rows.len() as i32);
1888-
let tmp_value = converter.convert_rows(rows)?;
1889-
tmp_values.push(
1890-
tmp_value
1891-
.get(0)
1892-
.ok_or_else(|| {
1893-
DataFusionError::Internal(format!(
1894-
"array_intersect: failed to get value from rows"
1895-
))
1896-
})?
1897-
.clone(),
1898-
);
1899-
values_set.clear();
1900-
}
1848+
let mut values_set = HashSet::with_capacity(l_values.num_rows());
1849+
for l_val in l_values.iter() {
1850+
values_set.insert(l_val);
19011851
}
1902-
_ => {
1903-
todo!()
1852+
let mut rows = Vec::with_capacity(r_values.num_rows());
1853+
for r_val in r_values.iter().sorted().dedup() {
1854+
if values_set.contains(&r_val) {
1855+
rows.push(r_val);
1856+
}
19041857
}
1858+
1859+
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
1860+
DataFusionError::Internal(format!("offsets should not be empty"))
1861+
})?;
1862+
offsets.push(last_offset + rows.len() as i32);
1863+
let tmp_value = converter.convert_rows(rows)?;
1864+
tmp_values.push(
1865+
tmp_value
1866+
.get(0)
1867+
.ok_or_else(|| {
1868+
DataFusionError::Internal(format!(
1869+
"array_intersect: failed to get value from rows"
1870+
))
1871+
})?
1872+
.clone(),
1873+
);
19051874
}
19061875
}
19071876

1908-
let field = Arc::new(Field::new("item_list", dt, true));
1877+
let field = Arc::new(Field::new("item", dt, true));
19091878
let offsets = OffsetBuffer::new(offsets.into());
19101879
let tmp_values_ref = tmp_values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
19111880
let values = concat(&tmp_values_ref)?;

0 commit comments

Comments
 (0)