Skip to content

Commit 5e8d784

Browse files
committed
feat: support eq_array and to_array_of_size for FSL
1 parent 4b2d106 commit 5e8d784

File tree

1 file changed

+53
-10
lines changed

1 file changed

+53
-10
lines changed

datafusion/common/src/scalar.rs

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,7 +1725,7 @@ impl ScalarValue {
17251725
///
17261726
/// Errors if `self` is
17271727
/// - a decimal that fails be converted to a decimal array of size
1728-
/// - a `Fixedsizelist` that is not supported yet
1728+
/// - a `Fixedsizelist` that fails to be concatenated into an array of size
17291729
/// - a `List` that fails to be concatenated into an array of size
17301730
/// - a `Dictionary` that fails be converted to a dictionary array of size
17311731
pub fn to_array_of_size(&self, size: usize) -> Result<ArrayRef> {
@@ -1846,10 +1846,7 @@ impl ScalarValue {
18461846
.collect::<LargeBinaryArray>(),
18471847
),
18481848
},
1849-
ScalarValue::FixedSizeList(..) => {
1850-
return _not_impl_err!("FixedSizeList is not supported yet")
1851-
}
1852-
ScalarValue::List(arr) => {
1849+
ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => {
18531850
let arrays = std::iter::repeat(arr.as_ref())
18541851
.take(size)
18551852
.collect::<Vec<_>>();
@@ -2324,8 +2321,6 @@ impl ScalarValue {
23242321
///
23252322
/// Errors if
23262323
/// - it fails to downcast `array` to the data type of `self`
2327-
/// - `self` is a `Fixedsizelist`
2328-
/// - `self` is a `List`
23292324
/// - `self` is a `Struct`
23302325
///
23312326
/// # Panics
@@ -2398,10 +2393,10 @@ impl ScalarValue {
23982393
ScalarValue::LargeBinary(val) => {
23992394
eq_array_primitive!(array, index, LargeBinaryArray, val)?
24002395
}
2401-
ScalarValue::FixedSizeList(..) => {
2402-
return _not_impl_err!("FixedSizeList is not supported yet")
2396+
ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => {
2397+
let right = array.slice(index, 1);
2398+
arr == &right
24032399
}
2404-
ScalarValue::List(_) => return _not_impl_err!("List is not supported yet"),
24052400
ScalarValue::Date32(val) => {
24062401
eq_array_primitive!(array, index, Date32Array, val)?
24072402
}
@@ -3103,6 +3098,27 @@ mod tests {
31033098
assert_eq!(&arr, actual_list_arr);
31043099
}
31053100

3101+
#[test]
3102+
fn test_to_array_of_size_for_fsl() {
3103+
let values = Int32Array::from_iter([Some(1), None, Some(2)]);
3104+
let field = Arc::new(Field::new("item", DataType::Int32, true));
3105+
let arr = FixedSizeListArray::new(field.clone(), 3, Arc::new(values), None);
3106+
let sv = ScalarValue::FixedSizeList(Arc::new(arr));
3107+
let actual_arr = sv
3108+
.to_array_of_size(2)
3109+
.expect("Failed to convert to array of size");
3110+
3111+
let expected_values =
3112+
Int32Array::from_iter([Some(1), None, Some(2), Some(1), None, Some(2)]);
3113+
let expected_arr =
3114+
FixedSizeListArray::new(field, 3, Arc::new(expected_values), None);
3115+
3116+
assert_eq!(
3117+
&expected_arr,
3118+
as_fixed_size_list_array(actual_arr.as_ref()).unwrap()
3119+
);
3120+
}
3121+
31063122
#[test]
31073123
fn test_list_to_array_string() {
31083124
let scalars = vec![
@@ -3181,6 +3197,33 @@ mod tests {
31813197
assert_eq!(result, &expected);
31823198
}
31833199

3200+
#[test]
3201+
fn test_list_scalar_eq_to_array() {
3202+
let list_array: ArrayRef =
3203+
Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
3204+
Some(vec![Some(0), Some(1), Some(2)]),
3205+
None,
3206+
Some(vec![None, Some(5)]),
3207+
]));
3208+
3209+
let fsl_array: ArrayRef =
3210+
Arc::new(FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
3211+
vec![
3212+
Some(vec![Some(0), Some(1), Some(2)]),
3213+
None,
3214+
Some(vec![Some(3), None, Some(5)]),
3215+
],
3216+
3,
3217+
));
3218+
3219+
for arr in [list_array, fsl_array] {
3220+
for i in 0..arr.len() {
3221+
let scalar = ScalarValue::List(arr.slice(i, 1));
3222+
assert!(scalar.eq_array(&arr, i).unwrap());
3223+
}
3224+
}
3225+
}
3226+
31843227
#[test]
31853228
fn scalar_add_trait_test() -> Result<()> {
31863229
let float_value = ScalarValue::Float64(Some(123.));

0 commit comments

Comments
 (0)