-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Support compute return types from argument values (not just their DataTypes) #8985
Changes from 2 commits
4e06013
17a2c91
02f2284
0c9acdd
56b71ae
3dbc0c7
491a4a1
468b38f
5772d9f
59b3958
f195fba
21d495f
b2e8457
4efb395
a9546ee
040d319
93b72ee
e0add48
653577f
7993af8
a3b9648
2121770
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use std::any::Any; | ||
|
||
use arrow_schema::{Field, Schema}; | ||
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; | ||
|
||
use datafusion::error::Result; | ||
use datafusion::prelude::*; | ||
use datafusion_common::{ | ||
internal_err, DFSchema, DataFusionError, ScalarValue, ToDFSchema, | ||
}; | ||
use datafusion_expr::{ | ||
expr::ScalarFunction, ColumnarValue, ExprSchemable, ScalarUDF, ScalarUDFImpl, | ||
Signature, | ||
}; | ||
|
||
#[derive(Debug)] | ||
struct UDFWithExprReturn { | ||
signature: Signature, | ||
} | ||
|
||
impl UDFWithExprReturn { | ||
fn new() -> Self { | ||
Self { | ||
signature: Signature::any(3, Volatility::Immutable), | ||
} | ||
} | ||
} | ||
|
||
//Implement the ScalarUDFImpl trait for UDFWithExprReturn | ||
impl ScalarUDFImpl for UDFWithExprReturn { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
fn name(&self) -> &str { | ||
"udf_with_expr_return" | ||
} | ||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
fn return_type(&self, _: &[DataType]) -> Result<DataType> { | ||
Ok(DataType::Int32) | ||
} | ||
// An example of how to use the exprs to determine the return type | ||
// If the third argument is '0', return the type of the first argument | ||
// If the third argument is '1', return the type of the second argument | ||
fn return_type_from_exprs( | ||
&self, | ||
arg_exprs: &[Expr], | ||
schema: &DFSchema, | ||
) -> Result<DataType> { | ||
if arg_exprs.len() != 3 { | ||
return internal_err!("The size of the args must be 3."); | ||
} | ||
let take_idx = match arg_exprs.get(2).unwrap() { | ||
Expr::Literal(ScalarValue::Int64(Some(idx))) if (idx == &0 || idx == &1) => { | ||
*idx as usize | ||
} | ||
_ => unreachable!(), | ||
}; | ||
arg_exprs.get(take_idx).unwrap().get_type(schema) | ||
} | ||
// The actual implementation would add one to the argument | ||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
unimplemented!() | ||
} | ||
} | ||
|
||
#[derive(Debug)] | ||
struct UDFDefault { | ||
signature: Signature, | ||
} | ||
|
||
impl UDFDefault { | ||
fn new() -> Self { | ||
Self { | ||
signature: Signature::any(3, Volatility::Immutable), | ||
} | ||
} | ||
} | ||
|
||
// Implement the ScalarUDFImpl trait for UDFDefault | ||
// This is the same as UDFWithExprReturn, except without return_type_from_exprs | ||
impl ScalarUDFImpl for UDFDefault { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
fn name(&self) -> &str { | ||
"udf_default" | ||
} | ||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
fn return_type(&self, _: &[DataType]) -> Result<DataType> { | ||
Ok(DataType::Boolean) | ||
} | ||
// The actual implementation would add one to the argument | ||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
unimplemented!() | ||
} | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<()> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this example is missing actually using the function in a query / dataframe. As @Weijun-H pointed out the logic added to What I think the example needs to do is someething like
So for example, a good example function might be a function that takes a string argument Then for example run queries like select my_cast(c1, 'i32'), arrow_typeof(my_cast(c1, 'i32')); -- returns value and DataType::Int32
select my_cast(c1, 'i64'), arrow_typeof(my_cast(c1, 'i64')); -- returns value and DataType::Int64 Does that make sense? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I realized that it missing using the function. |
||
// Create a new ScalarUDF from the implementation | ||
let udf_with_expr_return = ScalarUDF::from(UDFWithExprReturn::new()); | ||
|
||
// Call 'return_type' to get the return type of the function | ||
let ret = udf_with_expr_return.return_type(&[DataType::Int32])?; | ||
assert_eq!(ret, DataType::Int32); | ||
|
||
let schema = Schema::new(vec![ | ||
Field::new("a", DataType::Float32, false), | ||
Field::new("b", DataType::Float64, false), | ||
]) | ||
.to_dfschema()?; | ||
|
||
// Set the third argument to 0 to return the type of the first argument | ||
let expr0 = udf_with_expr_return.call(vec![col("a"), col("b"), lit(0_i64)]); | ||
let args = match expr0 { | ||
Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, | ||
_ => panic!("Expected ScalarFunction"), | ||
}; | ||
let ret = udf_with_expr_return.return_type_from_exprs(&args, &schema)?; | ||
// The return type should be the same as the first argument | ||
assert_eq!(ret, DataType::Float32); | ||
|
||
// Set the third argument to 1 to return the type of the second argument | ||
let expr1 = udf_with_expr_return.call(vec![col("a"), col("b"), lit(1_i64)]); | ||
let args1 = match expr1 { | ||
Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, | ||
_ => panic!("Expected ScalarFunction"), | ||
}; | ||
let ret = udf_with_expr_return.return_type_from_exprs(&args1, &schema)?; | ||
// The return type should be the same as the second argument | ||
assert_eq!(ret, DataType::Float64); | ||
|
||
// Create a new ScalarUDF from the implementation | ||
let udf_default = ScalarUDF::from(UDFDefault::new()); | ||
// Call 'return_type' to get the return type of the function | ||
let ret = udf_default.return_type(&[DataType::Int32])?; | ||
assert_eq!(ret, DataType::Boolean); | ||
|
||
// Set the third argument to 0 to return the type of the first argument | ||
let expr2 = udf_default.call(vec![col("a"), col("b"), lit(0_i64)]); | ||
let args = match expr2 { | ||
Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, | ||
_ => panic!("Expected ScalarFunction"), | ||
}; | ||
let ret = udf_default.return_type_from_exprs(&args, &schema)?; | ||
assert_eq!(ret, DataType::Boolean); | ||
|
||
Ok(()) | ||
} |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,12 +17,13 @@ | |||||||||||||||||||||
|
||||||||||||||||||||||
//! [`ScalarUDF`]: Scalar User Defined Functions | ||||||||||||||||||||||
|
||||||||||||||||||||||
use crate::ExprSchemable; | ||||||||||||||||||||||
use crate::{ | ||||||||||||||||||||||
ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, | ||||||||||||||||||||||
ScalarFunctionImplementation, Signature, | ||||||||||||||||||||||
}; | ||||||||||||||||||||||
use arrow::datatypes::DataType; | ||||||||||||||||||||||
use datafusion_common::Result; | ||||||||||||||||||||||
use datafusion_common::{DFSchema, Result}; | ||||||||||||||||||||||
use std::any::Any; | ||||||||||||||||||||||
use std::fmt; | ||||||||||||||||||||||
use std::fmt::Debug; | ||||||||||||||||||||||
|
@@ -152,6 +153,17 @@ impl ScalarUDF { | |||||||||||||||||||||
self.inner.return_type(args) | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
/// The datatype this function returns given the input argument input types. | ||||||||||||||||||||||
/// This function is used when the input arguments are [`Expr`]s. | ||||||||||||||||||||||
/// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. | ||||||||||||||||||||||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
pub fn return_type_from_exprs( | ||||||||||||||||||||||
&self, | ||||||||||||||||||||||
args: &[Expr], | ||||||||||||||||||||||
schema: &DFSchema, | ||||||||||||||||||||||
) -> Result<DataType> { | ||||||||||||||||||||||
self.inner.return_type_from_exprs(args, schema) | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
/// Invoke the function on `args`, returning the appropriate result. | ||||||||||||||||||||||
/// | ||||||||||||||||||||||
/// See [`ScalarUDFImpl::invoke`] for more details. | ||||||||||||||||||||||
|
@@ -249,6 +261,22 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |||||||||||||||||||||
/// the arguments | ||||||||||||||||||||||
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>; | ||||||||||||||||||||||
|
||||||||||||||||||||||
/// What [`DataType`] will be returned by this function, given the types of | ||||||||||||||||||||||
/// the expr arguments | ||||||||||||||||||||||
fn return_type_from_exprs( | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we need to use the we could change the trait impl to something like this pub trait ScalarUDFImpl: Debug + Send + Sync {
/// What [`DataType`] will be returned by this function, given the types of
/// the expr arguments
fn return_type_from_exprs(
&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
) -> Option<Result<DataType>> {
// The default implementation returns None
// so that people don't have to implement `return_type_from_exprs` if they dont want to
None
}
} then change the impl ScalarUDF
/// The datatype this function returns given the input argument input types.
/// This function is used when the input arguments are [`Expr`]s.
/// See [`ScalarUDFImpl::return_type_from_exprs`] for more details.
pub fn return_type_from_exprs<S: ExprSchema>(
&self,
args: &[Expr],
schema: &S,
) -> Result<DataType> {
// If the implementation provides a return_type_from_exprs, use it
if let Some(return_type) = self.inner.return_type_from_exprs(args, schema) {
return_type
// Otherwise, use the return_type function
} else {
let arg_types = args
.iter()
.map(|arg| arg.get_type(schema))
.collect::<Result<Vec<_>>>()?;
self.return_type(&arg_types)
}
}
} this way we don't need to constrain the ScalarFunctionDefinition::UDF(fun) => {
Ok(fun.return_type_from_exprs(&args, schema)?)
} and it still makes It does make it very slightly less ergonomic as end users now need to wrap their body in an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This works well on my side. Thanks! Another question for me is, how can user implement
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm yeah @yyy1000 you still run into the same error then. I'm wondering if it'd be easiest to just change the type signature on pub trait ExprSchemable<S: ExprSchema> {
/// given a schema, return the type of the expr
fn get_type(&self, schema: &S) -> Result<DataType>;
/// given a schema, return the nullability of the expr
fn nullable(&self, input_schema: &S) -> Result<bool>;
/// given a schema, return the expr's optional metadata
fn metadata(&self, schema: &S) -> Result<HashMap<String, String>>;
/// convert to a field with respect to a schema
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
/// cast to a type with respect to a schema
fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
}
impl ExprSchemable<DFSchema> for Expr {
//...
} then the trait can just go back to the original implementation you had using fn return_type_from_exprs(
&self,
arg_exprs: &[Expr],
schema: &DFSchema,
) -> Result<DataType> {
let arg_types = arg_exprs
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
self.return_type(&arg_types)
} I tried this locally and was able to get things to compile locally, and was able to implement a udf using the trait. It does make it a little less flexible as it's expecting a I think the only other approach would be to make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your help! @universalmind303 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update is I changed the signature to take |
||||||||||||||||||||||
&self, | ||||||||||||||||||||||
arg_exprs: &[Expr], | ||||||||||||||||||||||
schema: &DFSchema, | ||||||||||||||||||||||
) -> Result<DataType> { | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
@Weijun-H If I'd like to make a change like this, there's an error 🥲
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am unfamiliar with this part. It seems you need to extend There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Maybe you can use a trait object like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alamb Sorry for the delay. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cool -- thanks -- I'll try and find some time to play around with it but it may be a while There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @yyy1000 the following changes worked for me. If anyone else have a better solution please suggest it. I am still a newbie in rust but trying to help to learn. expr_schema.rs ![]() ![]() return_types_udf.rs ![]() There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Those look like great changes to me @brayanjuls -- 👌 What do you think @yyy1000 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for help! @alamb @brayanjuls |
||||||||||||||||||||||
// provide default implementation that calls `self.return_type()` | ||||||||||||||||||||||
// so that people don't have to implement `return_type_from_exprs` if they dont want to | ||||||||||||||||||||||
let arg_types = arg_exprs | ||||||||||||||||||||||
.iter() | ||||||||||||||||||||||
.map(|e| e.get_type(schema)) | ||||||||||||||||||||||
.collect::<Result<Vec<_>>>()?; | ||||||||||||||||||||||
self.return_type(&arg_types) | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
/// Invoke the function on `args`, returning the appropriate result | ||||||||||||||||||||||
/// | ||||||||||||||||||||||
/// The function will be invoked passed with the slice of [`ColumnarValue`] | ||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is pretty confusing I think -- as it seems inconsistent with the return_type_from_exprs