Skip to content

Commit

Permalink
[SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR refactors `TryCast` to use `RuntimeReplaceable`, so that we don't need `CastBase` anymore. The unit tests are also simplified, because for a `RuntimeReplaceable` expression we only need to check the analysis behavior, not the execution.

### Why are the changes needed?

code cleanup

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

existing tests

Closes #36703 from cloud-fan/cast.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
cloud-fan committed Jun 8, 2022
1 parent f2f73ed commit 46175d1
Show file tree
Hide file tree
Showing 13 changed files with 347 additions and 451 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -425,29 +425,54 @@ object Cast {
}
}

abstract class CastBase extends UnaryExpression
with TimeZoneAwareExpression
with NullIntolerant
with SupportQueryContext {
/**
* Cast the child expression to the target data type.
*
* When cast from/to timezone related types, we need timeZoneId, which will be resolved with
* session local timezone by an analyzer [[ResolveTimeZone]].
*/
@ExpressionDescription(
usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
examples = """
Examples:
> SELECT _FUNC_('10' as int);
10
""",
since = "1.0.0",
group = "conversion_funcs")
case class Cast(
child: Expression,
dataType: DataType,
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends UnaryExpression
with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {

def child: Expression
def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)

def dataType: DataType
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

/**
* Returns true iff we can cast `from` type to `to` type.
*/
def canCast(from: DataType, to: DataType): Boolean
override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)

/**
* Returns the error message if casting from one type to another one is invalid.
*/
def typeCheckFailureMessage: String
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)

override def toString: String = s"cast($child as ${dataType.simpleString})"
private def typeCheckFailureMessage: String = if (ansiEnabled) {
if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
Cast.typeCheckFailureMessage(child.dataType, dataType,
Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
} else {
Cast.typeCheckFailureMessage(child.dataType, dataType,
Some(SQLConf.ANSI_ENABLED.key -> "false"))
}
} else {
s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
}

override def checkInputDataTypes(): TypeCheckResult = {
if (canCast(child.dataType, dataType)) {
if (ansiEnabled && Cast.canAnsiCast(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else if (!ansiEnabled && Cast.canCast(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
Expand All @@ -456,8 +481,6 @@ abstract class CastBase extends UnaryExpression

override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)

protected def ansiEnabled: Boolean

override def initQueryContext(): String = if (ansiEnabled) {
origin.context
} else {
Expand All @@ -470,7 +493,7 @@ abstract class CastBase extends UnaryExpression
childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)

override lazy val preCanonicalized: Expression = {
val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[CastBase]
val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[Cast]
if (timeZoneId.isDefined && !needsTimeZone) {
basic.withTimeZone(null)
} else {
Expand Down Expand Up @@ -2246,6 +2269,8 @@ abstract class CastBase extends UnaryExpression
"""
}

override def toString: String = s"cast($child as ${dataType.simpleString})"

override def sql: String = dataType match {
// HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
// type of casting can only be introduced by the analyzer, and can be omitted when converting
Expand All @@ -2255,57 +2280,6 @@ abstract class CastBase extends UnaryExpression
}
}

/**
* Cast the child expression to the target data type.
*
* When cast from/to timezone related types, we need timeZoneId, which will be resolved with
* session local timezone by an analyzer [[ResolveTimeZone]].
*/
@ExpressionDescription(
usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
examples = """
Examples:
> SELECT _FUNC_('10' as int);
10
""",
since = "1.0.0",
group = "conversion_funcs")
case class Cast(
child: Expression,
dataType: DataType,
timeZoneId: Option[String] = None,
override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends CastBase {

def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)

override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
Cast.canAnsiCast(from, to)
} else {
Cast.canCast(from, to)
}

override def typeCheckFailureMessage: String = if (ansiEnabled) {
if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
Cast.typeCheckFailureMessage(child.dataType, dataType,
Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
} else {
Cast.typeCheckFailureMessage(child.dataType, dataType,
Some(SQLConf.ANSI_ENABLED.key -> "false"))
}
} else {
s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
}

override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
}

/**
* Cast the child expression to the target data type, but will throw error if the cast might
* truncate, e.g. long -> int, timestamp -> date.
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -53,6 +56,87 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran
copy(child = newChild)
}

/**
 * A special version of [[Cast]] with ansi mode on. It performs the same operation (i.e. converts a
 * value of one data type into another data type), but returns a NULL value instead of raising an
 * error when the conversion cannot be performed.
 *
 * Its `replacement` is an ANSI-enabled [[Cast]] to `toType` wrapped in [[TryEval]].
 *
 * @param child the expression whose value is cast
 * @param toType the target data type
 * @param timeZoneId the time zone id passed on to the underlying [[Cast]]; it must be defined for
 *                   this expression to be resolved when
 *                   `Cast.needsTimeZone(child.dataType, toType)` is true
 */
@ExpressionDescription(
  usage = """
_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
true, except it returns NULL instead of raising an error. Note that the behavior of this
expression doesn't depend on configuration `spark.sql.ansi.enabled`.
""",
  examples = """
Examples:
> SELECT _FUNC_('10' as int);
10
> SELECT _FUNC_(1234567890123L as int);
null
""",
  since = "3.2.0",
  group = "conversion_funcs")
case class TryCast(child: Expression, toType: DataType, timeZoneId: Option[String] = None)
  extends UnaryExpression with RuntimeReplaceable with TimeZoneAwareExpression {

  /** Returns a copy of this expression with the given time zone id set. */
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)

  // When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
  // Otherwise behave like Expression.resolved.
  override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess &&
    (!Cast.needsTimeZone(child.dataType, toType) || timeZoneId.isDefined)

  // The ANSI cast is wrapped in TryEval; ansiEnabled is passed explicitly as true here, so it
  // does not fall back to a session-config default.
  override lazy val replacement = {
    TryEval(Cast(child, toType, timeZoneId = timeZoneId, ansiEnabled = true))
  }

  // If the target data type is a complex type which can't have Null values, we should guarantee
  // that the casting between the element types won't produce Null results.
  //
  // Pattern variables are deliberately named so that none of them shadows the `toType` field.
  private def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
    case (ArrayType(fromElement, fromNullable), ArrayType(toElement, toNullable)) =>
      canCast(fromElement, toElement) &&
        resolvableNullability(fromNullable || forceNullable(fromElement, toElement), toNullable)

    case (MapType(fromKey, fromValue, fromNullable), MapType(toKey, toValue, toNullable)) =>
      canCast(fromKey, toKey) &&
        (!forceNullable(fromKey, toKey)) &&
        canCast(fromValue, toValue) &&
        resolvableNullability(fromNullable || forceNullable(fromValue, toValue), toNullable)

    case (StructType(fromFields), StructType(toFields)) =>
      fromFields.length == toFields.length &&
        fromFields.zip(toFields).forall {
          case (fromField, toField) =>
            canCast(fromField.dataType, toField.dataType) &&
              resolvableNullability(
                fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
                toField.nullable)
        }

    case _ =>
      Cast.canAnsiCast(from, to)
  }

  /**
   * Succeeds when the child's type can be cast to `toType` according to `canCast`; otherwise
   * fails with the message built by `Cast.typeCheckFailureMessage` for the same pair of types.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    // Validate against `toType`, the same target type that the failure message reports.
    if (canCast(child.dataType, toType)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        Cast.typeCheckFailureMessage(child.dataType, toType, None))
    }
  }

  override def toString: String = s"try_cast($child as ${toType.simpleString})"

  override def sql: String = s"TRY_CAST(${child.sql} AS ${toType.sql})"

  override protected def withNewChildInternal(newChild: Expression): Expression =
    this.copy(child = newChild)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns the sum of `expr1`and `expr2` and the result is null on overflow. " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
true
case _: CastBase => true
case _: Cast => true
case _: TryEval => true
case _: GetDateField | _: LastDay => true
case _: ExtractIntervalPart[_] => true
case _: ArraySetLike => true
Expand Down
Loading

0 comments on commit 46175d1

Please # to comment.