Skip to content

Handle type coercion in signature for ApproxPercentileCont #12274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2433,26 +2433,26 @@ mod tests {
"| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| | | | | | | | 1 | -85 |",
"| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |",
"| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |",
"| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |",
"| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |",
"| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |",
"| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |",
"| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |",
"| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |",
"| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |",
"| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |",
"| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |",
"| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |",
"| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |",
"| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |",
"| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |",
"| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |",
"| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |",
"| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |",
"| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |",
"| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |",
"| -85 | -101 | 14 | -12.0 | -101 | 83 | -101 | 4 | -54 |",
"| -85 | -101 | 17 | -25.0 | -101 | 83 | -101 | 5 | -31 |",
"| -85 | -12 | 10 | -32.75 | -12 | 83 | -85 | 3 | 13 |",
"| -85 | -25 | 3 | -56.0 | -25 | -25 | -85 | 1 | -5 |",
"| -85 | -31 | 18 | -29.75 | -31 | 83 | -101 | 5 | 36 |",
"| -85 | -38 | 16 | -25.0 | -38 | 83 | -101 | 4 | 65 |",
"| -85 | -43 | 7 | -43.0 | -43 | 83 | -85 | 2 | 45 |",
"| -85 | -48 | 6 | -35.75 | -48 | 83 | -85 | 2 | -43 |",
"| -85 | -5 | 4 | -37.75 | -5 | -5 | -85 | 1 | 83 |",
"| -85 | -54 | 15 | -17.0 | -54 | 83 | -101 | 4 | -38 |",
"| -85 | -56 | 2 | -70.5 | -56 | -56 | -85 | 1 | -25 |",
"| -85 | -72 | 9 | -43.0 | -72 | 83 | -85 | 3 | -12 |",
"| -85 | -85 | 1 | -85.0 | -85 | -85 | -85 | 1 | -56 |",
"| -85 | 13 | 11 | -17.0 | 13 | 83 | -85 | 3 | 14 |",
"| -85 | 13 | 11 | -25.0 | 13 | 83 | -85 | 3 | 13 |",
"| -85 | 14 | 12 | -12.0 | 14 | 83 | -85 | 3 | 17 |",
"| -85 | 17 | 13 | -11.25 | 17 | 83 | -85 | 4 | -101 |",
"| -85 | 45 | 8 | -34.5 | 45 | 83 | -85 | 3 | -72 |",
"| -85 | 65 | 17 | -17.0 | 65 | 83 | -101 | 5 | -101 |",
"| -85 | 83 | 5 | -25.0 | 83 | 83 | -85 | 2 | -48 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
],
&df
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ async fn test_fn_approx_median() -> Result<()> {
"+-----------------------+",
"| approx_median(test.b) |",
"+-----------------------+",
"| 10 |",
"| 10.0 |",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a change in behavior: with this PR, median now always returns a float, whereas before it returned the same type as its input.

This comment was marked as outdated.

Copy link
Contributor Author

@jayzhan211 jayzhan211 Sep 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I changed the result to f64 now.

I think it is fine to return f64 for the median value. I checked the behavior of DuckDB: it returns double for integer input, and decimal for decimal input. Since we don't support decimal input for approx_median, there is no regression. We could support the decimal case later on.

"+-----------------------+",
];

Expand All @@ -366,7 +366,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
"+---------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5)) |",
"+---------------------------------------------+",
"| 10 |",
"| 10.0 |",
"+---------------------------------------------+",
];

Expand All @@ -387,7 +387,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
"+--------------------------------------+",
"| approx_percentile_cont(test.b,arg_2) |",
"+--------------------------------------+",
"| 10 |",
"| 10.0 |",
"+--------------------------------------+",
];
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
Expand All @@ -400,7 +400,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
"+------------------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) |",
"+------------------------------------------------------+",
"| 30 |",
"| 30.25 |",
"+------------------------------------------------------+",
];

Expand Down
15 changes: 7 additions & 8 deletions datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ use std::fmt::Debug;
use arrow::{datatypes::DataType, datatypes::Field};
use arrow_schema::DataType::{Float64, UInt64};

use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};

Expand Down Expand Up @@ -63,7 +62,10 @@ impl ApproxMedian {
/// Create a new APPROX_MEDIAN aggregate function
pub fn new() -> Self {
Self {
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
signature: Signature::coercible(
vec![DataType::Float64],
Volatility::Immutable,
),
}
}
}
Expand Down Expand Up @@ -97,11 +99,8 @@ impl AggregateUDFImpl for ApproxMedian {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("ApproxMedian requires numeric input types");
}
Ok(arg_types[0].clone())
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Expand Down
173 changes: 36 additions & 137 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,22 @@ use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use arrow::array::{Array, RecordBatch};
use arrow::array::{Array, AsArray, RecordBatch};
use arrow::compute::{filter, is_not_null};
use arrow::{
array::{
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
},
datatypes::DataType,
};
use arrow::datatypes::Float64Type;
use arrow::{array::ArrayRef, datatypes::DataType};
use arrow_schema::{Field, Schema};

use datafusion_common::{
downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
DataFusionError, Result, ScalarValue,
exec_err, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, Result,
ScalarValue,
};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature,
Volatility,
};
use datafusion_functions_aggregate_common::tdigest::{
TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, Volatility,
};
use datafusion_functions_aggregate_common::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
Expand Down Expand Up @@ -84,21 +75,8 @@ impl Default for ApproxPercentileCont {
impl ApproxPercentileCont {
/// Create a new [`ApproxPercentileCont`] aggregate function.
pub fn new() -> Self {
let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
// Accept any numeric value paired with a float64 percentile
for num in NUMERICS {
variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
// Additionally accept an integer number of centroids for T-Digest
for int in INTEGERS {
variants.push(TypeSignature::Exact(vec![
num.clone(),
DataType::Float64,
int.clone(),
]))
}
}
Self {
signature: Signature::one_of(variants, Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
}
}

Expand Down Expand Up @@ -156,15 +134,12 @@ fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
let percentile = match get_scalar_value(expr)
.map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
ScalarValue::Float32(Some(value)) => {
value as f64
}
ScalarValue::Float64(Some(value)) => {
value
}
sv => {
return not_impl_err!(
"Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
return internal_err!(
"Percentile value for 'APPROX_PERCENTILE_CONT' should be coerced to f64 (got data type {})",
sv.data_type()
)
}
Expand All @@ -182,17 +157,10 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
let max_size = match get_scalar_value(expr)
.map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
ScalarValue::UInt8(Some(q)) => q as usize,
ScalarValue::UInt16(Some(q)) => q as usize,
ScalarValue::UInt32(Some(q)) => q as usize,
ScalarValue::UInt64(Some(q)) => q as usize,
ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
sv => {
return not_impl_err!(
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' should be coerced to u64 literal (got data type {}).",
sv.data_type()
)
},
Expand Down Expand Up @@ -257,16 +225,26 @@ impl AggregateUDFImpl for ApproxPercentileCont {
Ok(Box::new(self.create_accumulator(acc_args)?))
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("approx_percentile_cont requires numeric input types");
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() == 3 {
// Since float is coercible to u64 in `can_cast_types`, we check whether it is integer
if !arg_types[2].is_integer() {
return exec_err!(
"3rd argument should be integer but got {}",
arg_types[2]
);
}
}
if arg_types.len() == 3 && !arg_types[2].is_integer() {
return plan_err!(
"approx_percentile_cont requires integer max_size input types"
);

if arg_types.len() == 2 {
Ok(vec![DataType::Float64; 2])
} else {
Ok(vec![DataType::Float64, DataType::Float64, DataType::UInt64])
}
Ok(arg_types[0].clone())
}
}

Expand Down Expand Up @@ -306,91 +284,8 @@ impl ApproxPercentileAccumulator {

// public for approx_percentile_cont_with_weight
pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
match values.data_type() {
DataType::Float64 => {
let array = downcast_value!(values, Float64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Float32 => {
let array = downcast_value!(values, Float32Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int64 => {
let array = downcast_value!(values, Int64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int32 => {
let array = downcast_value!(values, Int32Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int16 => {
let array = downcast_value!(values, Int16Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int8 => {
let array = downcast_value!(values, Int8Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt64 => {
let array = downcast_value!(values, UInt64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt32 => {
let array = downcast_value!(values, UInt32Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt16 => {
let array = downcast_value!(values, UInt16Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt8 => {
let array = downcast_value!(values, UInt8Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
e => internal_err!(
"APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
),
}
let array = values.as_primitive::<Float64Type>();
Ok(array.values().as_ref().to_vec())
}
}

Expand All @@ -406,7 +301,11 @@ impl Accumulator for ApproxPercentileAccumulator {
values = filter(&values, &is_not_null(&values)?)?;
}
let sorted_values = &arrow::compute::sort(&values, None)?;
let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
let sorted_values = sorted_values
.as_primitive::<Float64Type>()
.values()
.as_ref()
.to_vec();
self.digest = self.digest.merge_sorted_f64(&sorted_values);
Ok(())
}
Expand Down
Loading