Skip to content

feat:implement sql style 'ends_with' and 'instr' string function #8862

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 2 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,26 @@ async fn test_fn_initcap() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_fn_instr() -> Result<()> {
let expr = instr(col("a"), lit("b"));

let expected = [
"+-------------------------+",
"| instr(test.a,Utf8(\"b\")) |",
"+-------------------------+",
"| 2 |",
"| 2 |",
"| 0 |",
"| 5 |",
"+-------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
#[cfg(feature = "unicode_expressions")]
async fn test_fn_left() -> Result<()> {
Expand Down Expand Up @@ -634,6 +654,26 @@ async fn test_fn_starts_with() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_fn_ends_with() -> Result<()> {
let expr = ends_with(col("a"), lit("DEF"));

let expected = [
"+-------------------------------+",
"| ends_with(test.a,Utf8(\"DEF\")) |",
"+-------------------------------+",
"| true |",
"| false |",
"| false |",
"| false |",
"+-------------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
#[cfg(feature = "unicode_expressions")]
async fn test_fn_strpos() -> Result<()> {
Expand Down
36 changes: 25 additions & 11 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,12 @@ pub enum BuiltinScalarFunction {
DateTrunc,
/// date_bin
DateBin,
/// ends_with
EndsWith,
/// initcap
InitCap,
/// InStr
InStr,
/// left
Left,
/// lpad
Expand Down Expand Up @@ -446,7 +450,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::DatePart => Volatility::Immutable,
BuiltinScalarFunction::DateTrunc => Volatility::Immutable,
BuiltinScalarFunction::DateBin => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
BuiltinScalarFunction::InitCap => Volatility::Immutable,
BuiltinScalarFunction::InStr => Volatility::Immutable,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the function name we implement need to follow postgresql? If so, I will change the name of this function

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spark has both position and instr and they are not quite the same signature:
https://spark.apache.org/docs/latest/api/sql/#position
https://spark.apache.org/docs/latest/api/sql/#instr

BuiltinScalarFunction::Left => Volatility::Immutable,
BuiltinScalarFunction::Lpad => Volatility::Immutable,
BuiltinScalarFunction::Lower => Volatility::Immutable,
Expand Down Expand Up @@ -708,6 +714,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::InitCap => {
utf8_to_str_type(&input_expr_types[0], "initcap")
}
BuiltinScalarFunction::InStr => {
utf8_to_int_type(&input_expr_types[0], "instr")
}
BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"),
BuiltinScalarFunction::Lower => {
utf8_to_str_type(&input_expr_types[0], "lower")
Expand Down Expand Up @@ -795,6 +804,7 @@ impl BuiltinScalarFunction {
true,
)))),
BuiltinScalarFunction::StartsWith => Ok(Boolean),
BuiltinScalarFunction::EndsWith => Ok(Boolean),
BuiltinScalarFunction::Strpos => {
utf8_to_int_type(&input_expr_types[0], "strpos")
}
Expand Down Expand Up @@ -1211,17 +1221,19 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => {
Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
self.volatility(),
)
}

BuiltinScalarFunction::EndsWith
| BuiltinScalarFunction::InStr
| BuiltinScalarFunction::Strpos
| BuiltinScalarFunction::StartsWith => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
self.volatility(),
),

BuiltinScalarFunction::Substr => Signature::one_of(
vec![
Expand Down Expand Up @@ -1473,7 +1485,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Concat => &["concat"],
BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"],
BuiltinScalarFunction::Chr => &["chr"],
BuiltinScalarFunction::EndsWith => &["ends_with"],
BuiltinScalarFunction::InitCap => &["initcap"],
BuiltinScalarFunction::InStr => &["instr"],
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lower => &["lower"],
BuiltinScalarFunction::Lpad => &["lpad"],
Expand Down
4 changes: 4 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input
scalar_expr!(Encode, encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex");
scalar_expr!(Decode, decode, input encoding, "decode the`input`, using the `encoding`. encoding can be base64 or hex");
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(InStr, instr, string substring, "returns the position of the first occurrence of `substring` in `string`");
scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`");
scalar_expr!(Lower, lower, string, "convert the string to lower case");
scalar_expr!(
Expand Down Expand Up @@ -830,6 +831,7 @@ scalar_expr!(SHA512, sha512, string, "SHA-512 hash");
scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index.");
scalar_expr!(StringToArray, string_to_array, string delimiter null_string, "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`");
scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`");
scalar_expr!(Substr, substr, string position, "substring from the `position` to the end");
scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters");
Expand Down Expand Up @@ -1372,6 +1374,7 @@ mod test {
test_scalar_expr!(Gcd, gcd, arg_1, arg_2);
test_scalar_expr!(Lcm, lcm, arg_1, arg_2);
test_scalar_expr!(InitCap, initcap, string);
test_scalar_expr!(InStr, instr, string, substring);
test_scalar_expr!(Left, left, string, count);
test_scalar_expr!(Lower, lower, string);
test_nary_scalar_expr!(Lpad, lpad, string, count);
Expand Down Expand Up @@ -1410,6 +1413,7 @@ mod test {
test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
test_scalar_expr!(StringToArray, string_to_array, expr, delimiter, null_value);
test_scalar_expr!(StartsWith, starts_with, string, characters);
test_scalar_expr!(EndsWith, ends_with, string, characters);
test_scalar_expr!(Strpos, strpos, string, substring);
test_scalar_expr!(Substr, substr, string, position);
test_scalar_expr!(Substr, substring, string, position, count);
Expand Down
143 changes: 142 additions & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,15 @@ pub fn create_physical_fun(
internal_err!("Unsupported data type {other:?} for function initcap")
}
}),
BuiltinScalarFunction::InStr => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::instr::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::instr::<i64>)(args)
}
other => internal_err!("Unsupported data type {other:?} for function instr"),
}),
BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left");
Expand Down Expand Up @@ -765,6 +774,17 @@ pub fn create_physical_fun(
internal_err!("Unsupported data type {other:?} for function starts_with")
}
}),
BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::ends_with::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::ends_with::<i64>)(args)
}
other => {
internal_err!("Unsupported data type {other:?} for function ends_with")
}
}),
BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
Expand Down Expand Up @@ -987,7 +1007,7 @@ mod tests {
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
Int32Array, StringArray, UInt64Array,
Int32Array, Int64Array, StringArray, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
Expand Down Expand Up @@ -1379,6 +1399,95 @@ mod tests {
Utf8,
StringArray
);
test_function!(
InStr,
&[lit("abc"), lit("b")],
Ok(Some(2)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("c")],
Ok(Some(3)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("d")],
Ok(Some(0)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("")],
Ok(Some(1)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("Helloworld"), lit("world")],
Ok(Some(6)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("Helloworld"), lit(ScalarValue::Utf8(None))],
Ok(None),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit(ScalarValue::Utf8(None)), lit("Hello")],
Ok(None),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[
lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))),
lit(ScalarValue::LargeUtf8(Some("world".to_string())))
],
Ok(Some(6)),
i64,
Int64,
Int64Array
);
test_function!(
InStr,
&[
lit(ScalarValue::LargeUtf8(None)),
lit(ScalarValue::LargeUtf8(Some("world".to_string())))
],
Ok(None),
i64,
Int64,
Int64Array
);
test_function!(
InStr,
&[
lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))),
lit(ScalarValue::LargeUtf8(None))
],
Ok(None),
i64,
Int64,
Int64Array
);
#[cfg(feature = "unicode_expressions")]
test_function!(
Left,
Expand Down Expand Up @@ -2497,6 +2606,38 @@ mod tests {
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit("alph"),],
Ok(Some(false)),
bool,
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit("bet"),],
Ok(Some(true)),
bool,
Boolean,
BooleanArray
);
test_function!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯 for testing with NULL

EndsWith,
&[lit(ScalarValue::Utf8(None)), lit("alph"),],
Ok(None),
bool,
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit(ScalarValue::Utf8(None)),],
Ok(None),
bool,
Boolean,
BooleanArray
);
#[cfg(feature = "unicode_expressions")]
test_function!(
Strpos,
Expand Down
Loading