-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Refactor UnwrapCastInComparison
to remove Expr
clones
#10115
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)` | ||
|
||
use std::cmp::Ordering; | ||
use std::mem; | ||
use std::sync::Arc; | ||
|
||
use crate::optimizer::ApplyOrder; | ||
|
@@ -32,9 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; | |
use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; | ||
use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; | ||
use datafusion_expr::utils::merge_schema; | ||
use datafusion_expr::{ | ||
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, | ||
}; | ||
use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator}; | ||
|
||
/// [`UnwrapCastInComparison`] attempts to remove casts from | ||
/// comparisons to literals ([`ScalarValue`]s) by applying the casts | ||
|
@@ -140,140 +139,132 @@ struct UnwrapCastExprRewriter { | |
impl TreeNodeRewriter for UnwrapCastExprRewriter { | ||
type Node = Expr; | ||
|
||
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> { | ||
match &expr { | ||
fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> { | ||
match &mut expr { | ||
// For case: | ||
// try_cast/cast(expr as data_type) op literal | ||
// literal op try_cast/cast(expr as data_type) | ||
Expr::BinaryExpr(BinaryExpr { left, op, right }) => { | ||
let left = left.as_ref().clone(); | ||
let right = right.as_ref().clone(); | ||
let left_type = left.get_type(&self.schema)?; | ||
let right_type = right.get_type(&self.schema)?; | ||
// Because the plan has been done the type coercion, the left and right must be equal | ||
if is_support_data_type(&left_type) | ||
&& is_support_data_type(&right_type) | ||
&& is_comparison_op(op) | ||
{ | ||
match (&left, &right) { | ||
( | ||
Expr::Literal(left_lit_value), | ||
Expr::TryCast(TryCast { expr, .. }) | ||
| Expr::Cast(Cast { expr, .. }), | ||
) => { | ||
// if the left_lit_value can be casted to the type of expr | ||
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal | ||
let expr_type = expr.get_type(&self.schema)?; | ||
let casted_scalar_value = | ||
try_cast_literal_to_type(left_lit_value, &expr_type)?; | ||
if let Some(value) = casted_scalar_value { | ||
// unwrap the cast/try_cast for the right expr | ||
return Ok(Transformed::yes(binary_expr( | ||
lit(value), | ||
*op, | ||
expr.as_ref().clone(), | ||
))); | ||
} | ||
} | ||
( | ||
Expr::TryCast(TryCast { expr, .. }) | ||
| Expr::Cast(Cast { expr, .. }), | ||
Expr::Literal(right_lit_value), | ||
) => { | ||
// if the right_lit_value can be casted to the type of expr | ||
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal | ||
let expr_type = expr.get_type(&self.schema)?; | ||
let casted_scalar_value = | ||
try_cast_literal_to_type(right_lit_value, &expr_type)?; | ||
if let Some(value) = casted_scalar_value { | ||
// unwrap the cast/try_cast for the left expr | ||
return Ok(Transformed::yes(binary_expr( | ||
expr.as_ref().clone(), | ||
*op, | ||
lit(value), | ||
))); | ||
} | ||
} | ||
(_, _) => { | ||
// do nothing | ||
} | ||
Expr::BinaryExpr(BinaryExpr { left, op, right }) | ||
if { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Today I learned you can add statements in a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1. I used to wrap another function. |
||
let Ok(left_type) = left.get_type(&self.schema) else { | ||
return Ok(Transformed::no(expr)); | ||
}; | ||
let Ok(right_type) = right.get_type(&self.schema) else { | ||
return Ok(Transformed::no(expr)); | ||
}; | ||
is_support_data_type(&left_type) | ||
&& is_support_data_type(&right_type) | ||
&& is_comparison_op(op) | ||
} => | ||
{ | ||
match (left.as_mut(), right.as_mut()) { | ||
( | ||
Expr::Literal(left_lit_value), | ||
Expr::TryCast(TryCast { | ||
expr: right_expr, .. | ||
}) | ||
| Expr::Cast(Cast { | ||
expr: right_expr, .. | ||
}), | ||
) => { | ||
// if the left_lit_value can be casted to the type of expr | ||
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal | ||
let Ok(expr_type) = right_expr.get_type(&self.schema) else { | ||
return Ok(Transformed::no(expr)); | ||
}; | ||
let Ok(Some(value)) = | ||
try_cast_literal_to_type(left_lit_value, &expr_type) | ||
else { | ||
return Ok(Transformed::no(expr)); | ||
}; | ||
**left = lit(value); | ||
// unwrap the cast/try_cast for the right expr | ||
**right = | ||
mem::replace(right_expr, Expr::Literal(ScalarValue::Null)); | ||
Ok(Transformed::yes(expr)) | ||
} | ||
( | ||
Expr::TryCast(TryCast { | ||
expr: left_expr, .. | ||
}) | ||
| Expr::Cast(Cast { | ||
expr: left_expr, .. | ||
}), | ||
Expr::Literal(right_lit_value), | ||
) => { | ||
// if the right_lit_value can be casted to the type of expr | ||
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal | ||
let Ok(expr_type) = left_expr.get_type(&self.schema) else { | ||
return Ok(Transformed::no(expr)); | ||
}; | ||
let Ok(Some(value)) = | ||
try_cast_literal_to_type(right_lit_value, &expr_type) | ||
else { | ||
return Ok(Transformed::no(expr)); | ||
}; | ||
// unwrap the cast/try_cast for the left expr | ||
**left = | ||
mem::replace(left_expr, Expr::Literal(ScalarValue::Null)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be nice to have a default There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh yeah! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall I open a follow-up PR that adds default for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've opened a minor PR to add default for |
||
**right = lit(value); | ||
Ok(Transformed::yes(expr)) | ||
} | ||
_ => Ok(Transformed::no(expr)), | ||
} | ||
// return the new binary op | ||
Ok(Transformed::yes(binary_expr(left, *op, right))) | ||
} | ||
// For case: | ||
// try_cast/cast(expr as left_type) in (expr1,expr2,expr3) | ||
Expr::InList(InList { | ||
expr: left_expr, | ||
list, | ||
negated, | ||
expr: left, list, .. | ||
}) => { | ||
if let Some( | ||
Expr::TryCast(TryCast { | ||
expr: internal_left_expr, | ||
.. | ||
}) | ||
| Expr::Cast(Cast { | ||
expr: internal_left_expr, | ||
.. | ||
}), | ||
) = Some(left_expr.as_ref()) | ||
{ | ||
let internal_left = internal_left_expr.as_ref().clone(); | ||
let internal_left_type = internal_left.get_type(&self.schema); | ||
if internal_left_type.is_err() { | ||
// error data type | ||
return Ok(Transformed::no(expr)); | ||
} | ||
let internal_left_type = internal_left_type?; | ||
if !is_support_data_type(&internal_left_type) { | ||
// not supported data type | ||
return Ok(Transformed::no(expr)); | ||
} | ||
let right_exprs = list | ||
.iter() | ||
.map(|right| { | ||
let right_type = right.get_type(&self.schema)?; | ||
if !is_support_data_type(&right_type) { | ||
return internal_err!( | ||
"The type of list expr {} not support", | ||
&right_type | ||
); | ||
} | ||
match right { | ||
Expr::Literal(right_lit_value) => { | ||
// if the right_lit_value can be casted to the type of internal_left_expr | ||
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal | ||
let casted_scalar_value = | ||
try_cast_literal_to_type(right_lit_value, &internal_left_type)?; | ||
if let Some(value) = casted_scalar_value { | ||
Ok(lit(value)) | ||
} else { | ||
internal_err!( | ||
"Can't cast the list expr {:?} to type {:?}", | ||
right_lit_value, &internal_left_type | ||
) | ||
} | ||
} | ||
other_expr => internal_err!( | ||
"Only support literal expr to optimize, but the expr is {:?}", | ||
&other_expr | ||
), | ||
} | ||
}) | ||
.collect::<Result<Vec<_>>>(); | ||
match right_exprs { | ||
Ok(right_exprs) => Ok(Transformed::yes(in_list( | ||
internal_left, | ||
right_exprs, | ||
*negated, | ||
))), | ||
Err(_) => Ok(Transformed::no(expr)), | ||
} | ||
} else { | ||
Ok(Transformed::no(expr)) | ||
let (Expr::TryCast(TryCast { | ||
expr: left_expr, .. | ||
}) | ||
| Expr::Cast(Cast { | ||
expr: left_expr, .. | ||
})) = left.as_mut() | ||
else { | ||
return Ok(Transformed::no(expr)); | ||
}; | ||
let Ok(expr_type) = left_expr.get_type(&self.schema) else { | ||
return Ok(Transformed::no(expr)); | ||
}; | ||
if !is_support_data_type(&expr_type) { | ||
return Ok(Transformed::no(expr)); | ||
} | ||
let Ok(right_exprs) = list | ||
.iter() | ||
.map(|right| { | ||
let right_type = right.get_type(&self.schema)?; | ||
if !is_support_data_type(&right_type) { | ||
internal_err!( | ||
"The type of list expr {} is not supported", | ||
&right_type | ||
)?; | ||
} | ||
match right { | ||
Expr::Literal(right_lit_value) => { | ||
// if the right_lit_value can be casted to the type of internal_left_expr | ||
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal | ||
let Ok(Some(value)) = try_cast_literal_to_type(right_lit_value, &expr_type) else { | ||
internal_err!( | ||
"Can't cast the list expr {:?} to type {:?}", | ||
right_lit_value, &expr_type | ||
)? | ||
}; | ||
Ok(lit(value)) | ||
} | ||
other_expr => internal_err!( | ||
"Only support literal expr to optimize, but the expr is {:?}", | ||
&other_expr | ||
), | ||
} | ||
}) | ||
.collect::<Result<Vec<_>>>() else { | ||
return Ok(Transformed::no(expr)) | ||
}; | ||
**left = mem::replace(left_expr, Expr::Literal(ScalarValue::Null)); | ||
*list = right_exprs; | ||
Ok(Transformed::yes(expr)) | ||
} | ||
// TODO: handle other expr type and dfs visit them | ||
_ => Ok(Transformed::no(expr)), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was trying a few different patterns to remove
Expr
clones in this rule. Thematch &mut expr
way seems to be most straightforward asexpr
is still available after pattern matching so we can return it unchanged if needed (return Ok(Transformed::no(expr));
) but we can also mutate some parts of theexpr
if needed.