Skip to content

Refactor UnwrapCastInComparison to remove Expr clones #10115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 117 additions & 126 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)`

use std::cmp::Ordering;
use std::mem;
use std::sync::Arc;

use crate::optimizer::ApplyOrder;
Expand All @@ -32,9 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
};
use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator};

/// [`UnwrapCastInComparison`] attempts to remove casts from
/// comparisons to literals ([`ScalarValue`]s) by applying the casts
Expand Down Expand Up @@ -140,140 +139,132 @@ struct UnwrapCastExprRewriter {
impl TreeNodeRewriter for UnwrapCastExprRewriter {
type Node = Expr;

fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
match &expr {
fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
match &mut expr {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried a few different patterns to remove `Expr` clones in this rule. Matching on `&mut expr` seems to be the most straightforward: `expr` is still available after pattern matching, so we can return it unchanged if needed (`return Ok(Transformed::no(expr));`), but we can also mutate parts of `expr` in place when needed.

// For case:
// try_cast/cast(expr as data_type) op literal
// literal op try_cast/cast(expr as data_type)
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let left = left.as_ref().clone();
let right = right.as_ref().clone();
let left_type = left.get_type(&self.schema)?;
let right_type = right.get_type(&self.schema)?;
// Because the plan has been done the type coercion, the left and right must be equal
if is_support_data_type(&left_type)
&& is_support_data_type(&right_type)
&& is_comparison_op(op)
{
match (&left, &right) {
(
Expr::Literal(left_lit_value),
Expr::TryCast(TryCast { expr, .. })
| Expr::Cast(Cast { expr, .. }),
) => {
// if the left_lit_value can be casted to the type of expr
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
let expr_type = expr.get_type(&self.schema)?;
let casted_scalar_value =
try_cast_literal_to_type(left_lit_value, &expr_type)?;
if let Some(value) = casted_scalar_value {
// unwrap the cast/try_cast for the right expr
return Ok(Transformed::yes(binary_expr(
lit(value),
*op,
expr.as_ref().clone(),
)));
}
}
(
Expr::TryCast(TryCast { expr, .. })
| Expr::Cast(Cast { expr, .. }),
Expr::Literal(right_lit_value),
) => {
// if the right_lit_value can be casted to the type of expr
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
let expr_type = expr.get_type(&self.schema)?;
let casted_scalar_value =
try_cast_literal_to_type(right_lit_value, &expr_type)?;
if let Some(value) = casted_scalar_value {
// unwrap the cast/try_cast for the left expr
return Ok(Transformed::yes(binary_expr(
expr.as_ref().clone(),
*op,
lit(value),
)));
}
}
(_, _) => {
// do nothing
}
Expr::BinaryExpr(BinaryExpr { left, op, right })
if {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Today I learned you can add statements in a {} as part of a match clause. Nice!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. I used to wrap this kind of logic in another function.

let Ok(left_type) = left.get_type(&self.schema) else {
return Ok(Transformed::no(expr));
};
let Ok(right_type) = right.get_type(&self.schema) else {
return Ok(Transformed::no(expr));
};
is_support_data_type(&left_type)
&& is_support_data_type(&right_type)
&& is_comparison_op(op)
} =>
{
match (left.as_mut(), right.as_mut()) {
(
Expr::Literal(left_lit_value),
Expr::TryCast(TryCast {
expr: right_expr, ..
})
| Expr::Cast(Cast {
expr: right_expr, ..
}),
) => {
// if the left_lit_value can be casted to the type of expr
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
let Ok(expr_type) = right_expr.get_type(&self.schema) else {
return Ok(Transformed::no(expr));
};
let Ok(Some(value)) =
try_cast_literal_to_type(left_lit_value, &expr_type)
else {
return Ok(Transformed::no(expr));
};
**left = lit(value);
// unwrap the cast/try_cast for the right expr
**right =
mem::replace(right_expr, Expr::Literal(ScalarValue::Null));
Ok(Transformed::yes(expr))
}
(
Expr::TryCast(TryCast {
expr: left_expr, ..
})
| Expr::Cast(Cast {
expr: left_expr, ..
}),
Expr::Literal(right_lit_value),
) => {
// if the right_lit_value can be casted to the type of expr
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
let Ok(expr_type) = left_expr.get_type(&self.schema) else {
return Ok(Transformed::no(expr));
};
let Ok(Some(value)) =
try_cast_literal_to_type(right_lit_value, &expr_type)
else {
return Ok(Transformed::no(expr));
};
// unwrap the cast/try_cast for the left expr
**left =
mem::replace(left_expr, Expr::Literal(ScalarValue::Null));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to have a default Expr so that we could use mem::take().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall I open a follow-up PR that adds a default for `Expr`?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`mem::take` is a nice way to avoid cloning; I think I used it to avoid `String` cloning the other day. We could have a default `Expr` that panics inside, I'd say...

Copy link
Contributor Author

@peter-toth peter-toth Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've opened a minor PR to add a default for `Expr`: #10127, but as `Expr` is just an enum, I don't know how to panic inside it.

**right = lit(value);
Ok(Transformed::yes(expr))
}
_ => Ok(Transformed::no(expr)),
}
// return the new binary op
Ok(Transformed::yes(binary_expr(left, *op, right)))
}
// For case:
// try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
Expr::InList(InList {
expr: left_expr,
list,
negated,
expr: left, list, ..
}) => {
if let Some(
Expr::TryCast(TryCast {
expr: internal_left_expr,
..
})
| Expr::Cast(Cast {
expr: internal_left_expr,
..
}),
) = Some(left_expr.as_ref())
{
let internal_left = internal_left_expr.as_ref().clone();
let internal_left_type = internal_left.get_type(&self.schema);
if internal_left_type.is_err() {
// error data type
return Ok(Transformed::no(expr));
}
let internal_left_type = internal_left_type?;
if !is_support_data_type(&internal_left_type) {
// not supported data type
return Ok(Transformed::no(expr));
}
let right_exprs = list
.iter()
.map(|right| {
let right_type = right.get_type(&self.schema)?;
if !is_support_data_type(&right_type) {
return internal_err!(
"The type of list expr {} not support",
&right_type
);
}
match right {
Expr::Literal(right_lit_value) => {
// if the right_lit_value can be casted to the type of internal_left_expr
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
let casted_scalar_value =
try_cast_literal_to_type(right_lit_value, &internal_left_type)?;
if let Some(value) = casted_scalar_value {
Ok(lit(value))
} else {
internal_err!(
"Can't cast the list expr {:?} to type {:?}",
right_lit_value, &internal_left_type
)
}
}
other_expr => internal_err!(
"Only support literal expr to optimize, but the expr is {:?}",
&other_expr
),
}
})
.collect::<Result<Vec<_>>>();
match right_exprs {
Ok(right_exprs) => Ok(Transformed::yes(in_list(
internal_left,
right_exprs,
*negated,
))),
Err(_) => Ok(Transformed::no(expr)),
}
} else {
Ok(Transformed::no(expr))
let (Expr::TryCast(TryCast {
expr: left_expr, ..
})
| Expr::Cast(Cast {
expr: left_expr, ..
})) = left.as_mut()
else {
return Ok(Transformed::no(expr));
};
let Ok(expr_type) = left_expr.get_type(&self.schema) else {
return Ok(Transformed::no(expr));
};
if !is_support_data_type(&expr_type) {
return Ok(Transformed::no(expr));
}
let Ok(right_exprs) = list
.iter()
.map(|right| {
let right_type = right.get_type(&self.schema)?;
if !is_support_data_type(&right_type) {
internal_err!(
"The type of list expr {} is not supported",
&right_type
)?;
}
match right {
Expr::Literal(right_lit_value) => {
// if the right_lit_value can be casted to the type of internal_left_expr
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
let Ok(Some(value)) = try_cast_literal_to_type(right_lit_value, &expr_type) else {
internal_err!(
"Can't cast the list expr {:?} to type {:?}",
right_lit_value, &expr_type
)?
};
Ok(lit(value))
}
other_expr => internal_err!(
"Only support literal expr to optimize, but the expr is {:?}",
&other_expr
),
}
})
.collect::<Result<Vec<_>>>() else {
return Ok(Transformed::no(expr))
};
**left = mem::replace(left_expr, Expr::Literal(ScalarValue::Null));
*list = right_exprs;
Ok(Transformed::yes(expr))
}
// TODO: handle other expr type and dfs visit them
_ => Ok(Transformed::no(expr)),
Expand Down
Loading