From 22b55f4ffe95da299756a5dd850598871328e4fb Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 17 Apr 2024 12:25:25 +0200 Subject: [PATCH] Refactor `UnwrapCastInComparison` to remove `Expr` clones --- .../src/unwrap_cast_in_comparison.rs | 243 +++++++++--------- 1 file changed, 117 insertions(+), 126 deletions(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 5ede43a05134..bd14584fd5c1 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -18,6 +18,7 @@ //! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)` use std::cmp::Ordering; +use std::mem; use std::sync::Arc; use crate::optimizer::ApplyOrder; @@ -32,9 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{ - binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, -}; +use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator}; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -140,140 +139,132 @@ struct UnwrapCastExprRewriter { impl TreeNodeRewriter for UnwrapCastExprRewriter { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result> { - match &expr { + fn f_up(&mut self, mut expr: Expr) -> Result> { + match &mut expr { // For case: // try_cast/cast(expr as data_type) op literal // literal op try_cast/cast(expr as data_type) - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let left = left.as_ref().clone(); - let right = right.as_ref().clone(); - let left_type = left.get_type(&self.schema)?; - let right_type = right.get_type(&self.schema)?; - // Because the plan has been done the type coercion, the left and right must be equal - if is_support_data_type(&left_type) - && is_support_data_type(&right_type) - && is_comparison_op(op) - { - match (&left, &right) { - ( - Expr::Literal(left_lit_value), - Expr::TryCast(TryCast { expr, .. }) - | Expr::Cast(Cast { expr, .. }), - ) => { - // if the left_lit_value can be casted to the type of expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let expr_type = expr.get_type(&self.schema)?; - let casted_scalar_value = - try_cast_literal_to_type(left_lit_value, &expr_type)?; - if let Some(value) = casted_scalar_value { - // unwrap the cast/try_cast for the right expr - return Ok(Transformed::yes(binary_expr( - lit(value), - *op, - expr.as_ref().clone(), - ))); - } - } - ( - Expr::TryCast(TryCast { expr, .. }) - | Expr::Cast(Cast { expr, .. }), - Expr::Literal(right_lit_value), - ) => { - // if the right_lit_value can be casted to the type of expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let expr_type = expr.get_type(&self.schema)?; - let casted_scalar_value = - try_cast_literal_to_type(right_lit_value, &expr_type)?; - if let Some(value) = casted_scalar_value { - // unwrap the cast/try_cast for the left expr - return Ok(Transformed::yes(binary_expr( - expr.as_ref().clone(), - *op, - lit(value), - ))); - } - } - (_, _) => { - // do nothing - } + Expr::BinaryExpr(BinaryExpr { left, op, right }) + if { + let Ok(left_type) = left.get_type(&self.schema) else { + return Ok(Transformed::no(expr)); }; + let Ok(right_type) = right.get_type(&self.schema) else { + return Ok(Transformed::no(expr)); + }; + is_support_data_type(&left_type) + && is_support_data_type(&right_type) + && is_comparison_op(op) + } => + { + match (left.as_mut(), right.as_mut()) { + ( + Expr::Literal(left_lit_value), + Expr::TryCast(TryCast { + expr: right_expr, .. + }) + | Expr::Cast(Cast { + expr: right_expr, .. + }), + ) => { + // if the left_lit_value can be casted to the type of expr + // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal + let Ok(expr_type) = right_expr.get_type(&self.schema) else { + return Ok(Transformed::no(expr)); + }; + let Ok(Some(value)) = + try_cast_literal_to_type(left_lit_value, &expr_type) + else { + return Ok(Transformed::no(expr)); + }; + **left = lit(value); + // unwrap the cast/try_cast for the right expr + **right = + mem::replace(right_expr, Expr::Literal(ScalarValue::Null)); + Ok(Transformed::yes(expr)) + } + ( + Expr::TryCast(TryCast { + expr: left_expr, .. + }) + | Expr::Cast(Cast { + expr: left_expr, .. + }), + Expr::Literal(right_lit_value), + ) => { + // if the right_lit_value can be casted to the type of expr + // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal + let Ok(expr_type) = left_expr.get_type(&self.schema) else { + return Ok(Transformed::no(expr)); + }; + let Ok(Some(value)) = + try_cast_literal_to_type(right_lit_value, &expr_type) + else { + return Ok(Transformed::no(expr)); + }; + // unwrap the cast/try_cast for the left expr + **left = + mem::replace(left_expr, Expr::Literal(ScalarValue::Null)); + **right = lit(value); + Ok(Transformed::yes(expr)) + } + _ => Ok(Transformed::no(expr)), } - // return the new binary op - Ok(Transformed::yes(binary_expr(left, *op, right))) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList(InList { - expr: left_expr, - list, - negated, + expr: left, list, .. }) => { - if let Some( - Expr::TryCast(TryCast { - expr: internal_left_expr, - .. - }) - | Expr::Cast(Cast { - expr: internal_left_expr, - .. - }), - ) = Some(left_expr.as_ref()) - { - let internal_left = internal_left_expr.as_ref().clone(); - let internal_left_type = internal_left.get_type(&self.schema); - if internal_left_type.is_err() { - // error data type - return Ok(Transformed::no(expr)); - } - let internal_left_type = internal_left_type?; - if !is_support_data_type(&internal_left_type) { - // not supported data type - return Ok(Transformed::no(expr)); - } - let right_exprs = list - .iter() - .map(|right| { - let right_type = right.get_type(&self.schema)?; - if !is_support_data_type(&right_type) { - return internal_err!( - "The type of list expr {} not support", - &right_type - ); - } - match right { - Expr::Literal(right_lit_value) => { - // if the right_lit_value can be casted to the type of internal_left_expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let casted_scalar_value = - try_cast_literal_to_type(right_lit_value, &internal_left_type)?; - if let Some(value) = casted_scalar_value { - Ok(lit(value)) - } else { - internal_err!( - "Can't cast the list expr {:?} to type {:?}", - right_lit_value, &internal_left_type - ) - } - } - other_expr => internal_err!( - "Only support literal expr to optimize, but the expr is {:?}", - &other_expr - ), - } - }) - .collect::>>(); - match right_exprs { - Ok(right_exprs) => Ok(Transformed::yes(in_list( - internal_left, - right_exprs, - *negated, - ))), - Err(_) => Ok(Transformed::no(expr)), - } - } else { - Ok(Transformed::no(expr)) + let (Expr::TryCast(TryCast { + expr: left_expr, .. + }) + | Expr::Cast(Cast { + expr: left_expr, .. + })) = left.as_mut() + else { + return Ok(Transformed::no(expr)); + }; + let Ok(expr_type) = left_expr.get_type(&self.schema) else { + return Ok(Transformed::no(expr)); + }; + if !is_support_data_type(&expr_type) { + return Ok(Transformed::no(expr)); } + let Ok(right_exprs) = list + .iter() + .map(|right| { + let right_type = right.get_type(&self.schema)?; + if !is_support_data_type(&right_type) { + internal_err!( + "The type of list expr {} is not supported", + &right_type + )?; + } + match right { + Expr::Literal(right_lit_value) => { + // if the right_lit_value can be casted to the type of internal_left_expr + // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal + let Ok(Some(value)) = try_cast_literal_to_type(right_lit_value, &expr_type) else { + internal_err!( + "Can't cast the list expr {:?} to type {:?}", + right_lit_value, &expr_type + )? + }; + Ok(lit(value)) + } + other_expr => internal_err!( + "Only support literal expr to optimize, but the expr is {:?}", + &other_expr + ), + } + }) + .collect::>>() else { + return Ok(Transformed::no(expr)) + }; + **left = mem::replace(left_expr, Expr::Literal(ScalarValue::Null)); + *list = right_exprs; + Ok(Transformed::yes(expr)) } // TODO: handle other expr type and dfs visit them _ => Ok(Transformed::no(expr)),