Skip to content

Commit 6b56383

Browse files
committed
support from and to is negative
1 parent d606eef commit 6b56383

File tree

2 files changed

+62
-30
lines changed

2 files changed

+62
-30
lines changed

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ where
661661
let adjusted_zero_index = if index < 0 {
662662
// array_slice in duckdb with negative to_index is python-like, so index itself is exclusive
663663
if let Ok(index) = index.try_into() {
664-
index + len - O::usize_as(1)
664+
index + len
665665
} else {
666666
return exec_err!("array_slice got invalid index: {}", index);
667667
}
@@ -730,29 +730,28 @@ where
730730
} else if stride.is_negative() {
731731
// return empty array
732732
offsets.push(offsets[row_index]);
733-
break;
733+
continue;
734734
}
735-
let mut index = start;
735+
let mut index = start + from;
736736
let mut cnt = 0;
737737
let stride: O = stride.try_into().map_err(|_| {
738738
internal_datafusion_err!(
739739
"array_slice got invalid stride: {}",
740740
stride
741741
)
742742
})?;
743-
while index <= to {
744-
let start = (start + index).to_usize().ok_or_else(|| {
745-
internal_datafusion_err!(
746-
"array_slice got invalid index: {:?}",
747-
start + index
748-
)
749-
})?;
750-
mutable.extend(0, start, start + 1);
743+
while index <= start + to {
744+
mutable.extend(
745+
0,
746+
index.to_usize().unwrap(),
747+
index.to_usize().unwrap() + 1,
748+
);
751749
index += stride;
752750
cnt += 1;
753751
}
754752
offsets.push(offsets[row_index] + O::usize_as(cnt));
755753
} else {
754+
// stride is default to 1
756755
mutable.extend(
757756
0,
758757
(start + from).to_usize().unwrap(),
@@ -761,8 +760,36 @@ where
761760
offsets.push(offsets[row_index] + (to - from + O::usize_as(1)));
762761
}
763762
} else {
763+
let stride = if let Some(stride) = stride {
764+
if !stride.is_negative() {
765+
return exec_err!(
766+
"array_slice got invalid stride: {}, because start index < end index",
767+
stride
768+
);
769+
}
770+
stride
771+
} else {
772+
// return empty array
773+
offsets.push(offsets[row_index]);
774+
continue;
775+
};
776+
let stride: O = stride.try_into().map_err(|_| {
777+
internal_datafusion_err!("array_slice got invalid stride: {}", stride)
778+
})?;
779+
780+
let mut index = start + from;
781+
let mut cnt = 0;
782+
while index >= start + to {
783+
mutable.extend(
784+
0,
785+
index.to_usize().unwrap(),
786+
index.to_usize().unwrap() + 1,
787+
);
788+
index += stride;
789+
cnt += 1;
790+
}
764791
// invalid range, return empty array
765-
offsets.push(offsets[row_index]);
792+
offsets.push(offsets[row_index] + O::usize_as(cnt));
766793
}
767794
} else {
768795
// invalid range, return empty array

datafusion/sqllogictest/test_files/array.slt

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 1, 5, -1), array_slice(make_array(
12271227
query error Execution error: array_slice got invalid stride: 0, it cannot be 0
12281228
select array_slice(make_array(1, 2, 3, 4, 5), 1, 5, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 5, 0);
12291229

1230+
query ??
1231+
select array_slice(make_array(1, 2, 3, 4, 5), 5, 1, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 5, 1, -2);
1232+
----
1233+
[5, 3, 1] [o, l, h]
1234+
12301235
query ??
12311236
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2);
12321237
----
@@ -1335,12 +1340,12 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NU
13351340
query ??
13361341
select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3);
13371342
----
1338-
[1] [h, e]
1343+
[1, 2] [h, e, l]
13391344

13401345
query ??
13411346
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, -3);
13421347
----
1343-
[1] [h, e]
1348+
[1, 2] [h, e, l]
13441349

13451350
# array_slice scalar function #13 (with negative number and NULL)
13461351
query error
@@ -1360,34 +1365,34 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NU
13601365
query ??
13611366
select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1);
13621367
----
1363-
[2, 3, 4] [l, l]
1368+
[2, 3, 4, 5] [l, l, o]
13641369

13651370
query ??
13661371
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -1);
13671372
----
1368-
[2, 3, 4] [l, l]
1373+
[2, 3, 4, 5] [l, l, o]
13691374

13701375
# array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array))
13711376
query ??
13721377
select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1);
13731378
----
1374-
[1, 2, 3, 4] [h, e, l, l]
1379+
[1, 2, 3, 4, 5] [h, e, l, l, o]
13751380

13761381
query ??
13771382
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -5, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -5, -1);
13781383
----
1379-
[1, 2, 3, 4] [h, e, l, l]
1384+
[1, 2, 3, 4, 5] [h, e, l, l, o]
13801385

13811386
# array_slice scalar function #17 (with negative indexes; first index = second index)
13821387
query ??
13831388
select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3);
13841389
----
1385-
[] []
1390+
[2] [l]
13861391

13871392
query ??
13881393
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -3);
13891394
----
1390-
[] []
1395+
[2] [l]
13911396

13921397
# array_slice scalar function #18 (with negative indexes; first index > second_index)
13931398
query ??
@@ -1415,24 +1420,24 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7
14151420
query ??
14161421
select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1);
14171422
----
1418-
[[1, 2, 3, 4, 5]] []
1423+
[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] [[6, 7, 8]]
14191424

14201425
query ??
14211426
select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), -2, -1), array_slice(arrow_cast(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), 'LargeList(List(Int64))'), -1, -1);
14221427
----
1423-
[[1, 2, 3, 4, 5]] []
1428+
[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] [[6, 7, 8]]
14241429

14251430

14261431
# array_slice scalar function #21 (with first positive index and last negative index)
14271432
query ??
14281433
select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2);
14291434
----
1430-
[2] [e, l]
1435+
[2, 3] [e, l, l]
14311436

14321437
query ??
14331438
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, -3), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, -2);
14341439
----
1435-
[2] [e, l]
1440+
[2, 3] [e, l, l]
14361441

14371442
# array_slice scalar function #22 (with first negative index and last positive index)
14381443
query ??
@@ -1461,7 +1466,7 @@ query ?
14611466
select array_slice(column1, column2, column3) from slices;
14621467
----
14631468
[]
1464-
[12, 13, 14, 15, 16]
1469+
[12, 13, 14, 15, 16, 17]
14651470
[]
14661471
[]
14671472
[]
@@ -1472,7 +1477,7 @@ query ?
14721477
select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices;
14731478
----
14741479
[]
1475-
[12, 13, 14, 15, 16]
1480+
[12, 13, 14, 15, 16, 17]
14761481
[]
14771482
[]
14781483
[]
@@ -1485,9 +1490,9 @@ query ???
14851490
select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(column1, 3, column3), array_slice(column1, column2, 5) from slices;
14861491
----
14871492
[1] [] [, 2, 3, 4, 5]
1488-
[] [13, 14, 15, 16] [12, 13, 14, 15]
1493+
[2] [13, 14, 15, 16, 17] [12, 13, 14, 15]
14891494
[] [] [21, 22, 23, , 25]
1490-
[] [33] []
1495+
[] [33, 34] []
14911496
[4, 5] [] []
14921497
[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45]
14931498
[5] [, 54, 55, 56, 57, 58, 59, 60] [55]
@@ -1496,9 +1501,9 @@ query ???
14961501
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices;
14971502
----
14981503
[1] [] [, 2, 3, 4, 5]
1499-
[] [13, 14, 15, 16] [12, 13, 14, 15]
1504+
[2] [13, 14, 15, 16, 17] [12, 13, 14, 15]
15001505
[] [] [21, 22, 23, , 25]
1501-
[] [33] []
1506+
[] [33, 34] []
15021507
[4, 5] [] []
15031508
[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45]
15041509
[5] [, 54, 55, 56, 57, 58, 59, 60] [55]

0 commit comments

Comments
 (0)