From 46175d1362062035fb93f87f25d61a9b711359ab Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 8 Jun 2022 09:50:08 +0800
Subject: [PATCH] [SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable

### What changes were proposed in this pull request?

This PR refactors `TryCast` to use `RuntimeReplaceable`, so that we no longer need `CastBase`.
The unit tests are also simplified: for a `RuntimeReplaceable` expression we only need to check
the analysis behavior, not the execution.

### Why are the changes needed?

code cleanup

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

existing tests

Closes #36703 from cloud-fan/cast.

Authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
---
 .../spark/sql/catalyst/expressions/Cast.scala | 114 ++++-----
 .../sql/catalyst/expressions/TryCast.scala    | 122 ----------
 .../sql/catalyst/expressions/TryEval.scala    |  86 ++++++-
 .../sql/catalyst/optimizer/expressions.scala  |   3 +-
 .../sql/catalyst/parser/AstBuilder.scala      |  10 +-
 .../plans/logical/basicLogicalOperators.scala |   2 +-
 .../spark/sql/catalyst/util/package.scala     |   2 +-
 .../catalyst/expressions/CastSuiteBase.scala  | 217 +++++++++---------
 ...Suite.scala => CastWithAnsiOffSuite.scala} |  20 +-
 ...teBase.scala => CastWithAnsiOnSuite.scala} | 149 ++++--------
 .../catalyst/expressions/TryCastSuite.scala   |  67 ++++--
 .../optimizer/ConstantFoldingSuite.scala      |   2 +-
 .../spark/sql/hive/client/HiveShim.scala      |   4 +-
 13 files changed, 347 insertions(+), 451 deletions(-)
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
 rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/{CastSuite.scala => CastWithAnsiOffSuite.scala} (98%)
 rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/{AnsiCastSuiteBase.scala => CastWithAnsiOnSuite.scala} (85%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 6ed25f5e45eb8..497261be2e446 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -425,29 +425,54 @@ object Cast {
   }
 }
 
-abstract class CastBase extends UnaryExpression
-  with TimeZoneAwareExpression
-  with NullIntolerant
-  with SupportQueryContext {
+/**
+ * Cast the child expression to the target data type.
+ *
+ * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
+ * session local timezone by an analyzer [[ResolveTimeZone]].
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('10' as int);
+       10
+  """,
+  since = "1.0.0",
+  group = "conversion_funcs")
+case class Cast(
+    child: Expression,
+    dataType: DataType,
+    timeZoneId: Option[String] = None,
+    ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends UnaryExpression
+  with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {
 
-  def child: Expression
+  def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
+    this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
 
-  def dataType: DataType
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
 
-  /**
-   * Returns true iff we can cast `from` type to `to` type.
-   */
-  def canCast(from: DataType, to: DataType): Boolean
+  override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
 
-  /**
-   * Returns the error message if casting from one type to another one is invalid.
-   */
-  def typeCheckFailureMessage: String
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
 
-  override def toString: String = s"cast($child as ${dataType.simpleString})"
+  private def typeCheckFailureMessage: String = if (ansiEnabled) {
+    if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
+      Cast.typeCheckFailureMessage(child.dataType, dataType,
+        Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
+    } else {
+      Cast.typeCheckFailureMessage(child.dataType, dataType,
+        Some(SQLConf.ANSI_ENABLED.key -> "false"))
+    }
+  } else {
+    s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
+  }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (canCast(child.dataType, dataType)) {
+    if (ansiEnabled && Cast.canAnsiCast(child.dataType, dataType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else if (!ansiEnabled && Cast.canCast(child.dataType, dataType)) {
       TypeCheckResult.TypeCheckSuccess
     } else {
       TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
@@ -456,8 +481,6 @@ abstract class CastBase extends UnaryExpression
 
   override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)
 
-  protected def ansiEnabled: Boolean
-
   override def initQueryContext(): String = if (ansiEnabled) {
     origin.context
   } else {
@@ -470,7 +493,7 @@ abstract class CastBase extends UnaryExpression
     childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)
 
   override lazy val preCanonicalized: Expression = {
-    val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[CastBase]
+    val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[Cast]
     if (timeZoneId.isDefined && !needsTimeZone) {
       basic.withTimeZone(null)
     } else {
@@ -2246,6 +2269,8 @@ abstract class CastBase extends UnaryExpression
     """
   }
 
+  override def toString: String = s"cast($child as ${dataType.simpleString})"
+
   override def sql: String = dataType match {
     // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
     // type of casting can only be introduced by the analyzer, and can be omitted when converting
@@ -2255,57 +2280,6 @@ abstract class CastBase extends UnaryExpression
   }
 }
 
-/**
- * Cast the child expression to the target data type.
- *
- * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
- * session local timezone by an analyzer [[ResolveTimeZone]].
- */
-@ExpressionDescription(
-  usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_('10' as int);
-       10
-  """,
-  since = "1.0.0",
-  group = "conversion_funcs")
-case class Cast(
-    child: Expression,
-    dataType: DataType,
-    timeZoneId: Option[String] = None,
-    override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
-  extends CastBase {
-
-  def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
-    this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
-
-  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
-    copy(timeZoneId = Option(timeZoneId))
-
-  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
-
-  override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
-    Cast.canAnsiCast(from, to)
-  } else {
-    Cast.canCast(from, to)
-  }
-
-  override def typeCheckFailureMessage: String = if (ansiEnabled) {
-    if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
-      Cast.typeCheckFailureMessage(child.dataType, dataType,
-        Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
-    } else {
-      Cast.typeCheckFailureMessage(child.dataType, dataType,
-        Some(SQLConf.ANSI_ENABLED.key -> "false"))
-    }
-  } else {
-    s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
-  }
-
-  override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
-}
-
 /**
  * Cast the child expression to the target data type, but will throw error if the cast might
  * truncate, e.g. long -> int, timestamp -> date.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
deleted file mode 100644
index 9ac6329f28170..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
-
-/**
- * A special version of [[AnsiCast]]. It performs the same operation (i.e. converts a value of
- * one data type into another data type), but returns a NULL value instead of raising an error
- * when the conversion can not be performed.
- *
- * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
- * session local timezone by an analyzer [[ResolveTimeZone]].
- */
-@ExpressionDescription(
-  usage = """
-    _FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
-      This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
-      true, except it returns NULL instead of raising an error. Note that the behavior of this
-      expression doesn't depend on configuration `spark.sql.ansi.enabled`.
-  """,
-  examples = """
-    Examples:
-      > SELECT _FUNC_('10' as int);
-       10
-      > SELECT _FUNC_(1234567890123L as int);
-       null
-  """,
-  since = "3.2.0",
-  group = "conversion_funcs")
-case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
-  extends CastBase {
-  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
-    copy(timeZoneId = Option(timeZoneId))
-
-  // Here we force `ansiEnabled` as true so that we can reuse the evaluation code branches which
-  // throw exceptions on conversion failures.
-  override protected val ansiEnabled: Boolean = true
-
-  override def nullable: Boolean = true
-
-  // If the target data type is a complex type which can't have Null values, we should guarantee
-  // that the casting between the element types won't produce Null results.
-  override def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
-    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
-      canCast(fromType, toType) &&
-        resolvableNullability(fn || forceNullable(fromType, toType), tn)
-
-    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
-      canCast(fromKey, toKey) &&
-        (!forceNullable(fromKey, toKey)) &&
-        canCast(fromValue, toValue) &&
-        resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
-
-    case (StructType(fromFields), StructType(toFields)) =>
-      fromFields.length == toFields.length &&
-        fromFields.zip(toFields).forall {
-          case (fromField, toField) =>
-            canCast(fromField.dataType, toField.dataType) &&
-              resolvableNullability(
-                fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
-                toField.nullable)
-        }
-
-    case _ =>
-      Cast.canAnsiCast(from, to)
-  }
-
-  override def cast(from: DataType, to: DataType): Any => Any = (input: Any) =>
-    try {
-      super.cast(from, to)(input)
-    } catch {
-      case _: Exception =>
-        null
-    }
-
-  override def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue,
-    result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = {
-    val javaType = JavaCode.javaType(resultType)
-    code"""
-      boolean $resultIsNull = $inputIsNull;
-      $javaType $result = ${CodeGenerator.defaultValue(resultType)};
-      if (!$inputIsNull) {
-        try {
-          ${cast(input, result, resultIsNull)}
-        } catch (Exception e) {
-          $resultIsNull = true;
-        }
-      }
-    """
-  }
-
-  override def typeCheckFailureMessage: String =
-    Cast.typeCheckFailureMessage(child.dataType, dataType, None)
-
-  override protected def withNewChildInternal(newChild: Expression): TryCast =
-    copy(child = newChild)
-
-  override def toString: String = {
-    s"try_cast($child as ${dataType.simpleString})"
-  }
-
-  override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
index c179c83befb4c..dc5bcae4c08a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
@@ -18,9 +18,12 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
 
 case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -53,6 +56,87 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran
     copy(child = newChild)
 }
 
+/**
+ * A special version of [[Cast]] with ANSI mode on. It performs the same operation (i.e. converts a
+ * value of one data type into another data type), but returns a NULL value instead of raising an
+ * error when the conversion cannot be performed.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
+      This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
+      true, except it returns NULL instead of raising an error. Note that the behavior of this
+      expression doesn't depend on configuration `spark.sql.ansi.enabled`.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('10' as int);
+       10
+      > SELECT _FUNC_(1234567890123L as int);
+       null
+  """,
+  since = "3.2.0",
+  group = "conversion_funcs")
+case class TryCast(child: Expression, toType: DataType, timeZoneId: Option[String] = None)
+  extends UnaryExpression with RuntimeReplaceable with TimeZoneAwareExpression {
+
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+
+  override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
+
+  // When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
+  // Otherwise behave like Expression.resolved.
+  override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess &&
+    (!Cast.needsTimeZone(child.dataType, toType) || timeZoneId.isDefined)
+
+  override lazy val replacement = {
+    TryEval(Cast(child, toType, timeZoneId = timeZoneId, ansiEnabled = true))
+  }
+
+  // If the target data type is a complex type which can't have Null values, we should guarantee
+  // that the casting between the element types won't produce Null results.
+  private def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
+    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+      canCast(fromType, toType) &&
+        resolvableNullability(fn || forceNullable(fromType, toType), tn)
+
+    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+      canCast(fromKey, toKey) &&
+        (!forceNullable(fromKey, toKey)) &&
+        canCast(fromValue, toValue) &&
+        resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+
+    case (StructType(fromFields), StructType(toFields)) =>
+      fromFields.length == toFields.length &&
+        fromFields.zip(toFields).forall {
+          case (fromField, toField) =>
+            canCast(fromField.dataType, toField.dataType) &&
+              resolvableNullability(
+                fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
+                toField.nullable)
+        }
+
+    case _ =>
+      Cast.canAnsiCast(from, to)
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (canCast(child.dataType, dataType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(Cast.typeCheckFailureMessage(child.dataType, toType, None))
+    }
+  }
+
+  override def toString: String = s"try_cast($child as ${dataType.simpleString})"
+
+  override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    this.copy(child = newChild)
+}
+
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2) - Returns the sum of `expr1`and `expr2` and the result is null on overflow. " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 62c328a29a821..3fc23c31ac74d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -631,7 +631,8 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
     case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
     case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
       true
-    case _: CastBase => true
+    case _: Cast => true
+    case _: TryEval => true
     case _: GetDateField | _: LastDay => true
     case _: ExtractIntervalPart[_] => true
     case _: ArraySetLike => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 7ae04010ad259..46847411bf0ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1788,15 +1788,17 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
   override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
     val rawDataType = typedVisit[DataType](ctx.dataType())
     val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
-    val cast = ctx.name.getType match {
+    ctx.name.getType match {
       case SqlBaseParser.CAST =>
-        Cast(expression(ctx.expression), dataType)
+        val cast = Cast(expression(ctx.expression), dataType)
+        cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
+        cast
       case SqlBaseParser.TRY_CAST =>
+        // `TryCast` can only be user-specified and we don't need to set the USER_SPECIFIED_CAST
+        // tag, which is only used by `Cast`
         TryCast(expression(ctx.expression), dataType)
     }
-    cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
-    cast
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 677bdf2733612..11d6829402399 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -590,7 +590,7 @@ case class View(
   // See more details in `SessionCatalog.fromCatalogTable`.
   private def canRemoveProject(p: Project): Boolean = {
     p.output.length == p.child.output.length && p.projectList.zip(p.child.output).forall {
-      case (Alias(cast: CastBase, name), childAttr) =>
+      case (Alias(cast: Cast, name), childAttr) =>
         cast.child match {
           case a: AttributeReference =>
             a.dataType == cast.dataType && a.name == name && childAttr.semanticEquals(a)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index e06072cbed282..f73fc7c681611 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -119,7 +119,7 @@ package object util extends Logging {
       PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType)
     case r: InheritAnalysisRules =>
       PrettyAttribute(r.makeSQLString(r.parameters.map(toPrettySQL)), r.dataType)
-    case c: CastBase if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) =>
+    case c: Cast if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) =>
       PrettyAttribute(usePrettyExpression(c.child).sql, c.dataType)
     case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index ba8ab708046d1..ca492e11226b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -43,11 +43,19 @@ import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * Common test suite for [[Cast]], [[AnsiCast]] and [[TryCast]] expressions.
+ * Common test suite for [[Cast]] with ANSI mode on and off. It only includes test cases that work
+ * for both ANSI on and off.
*/ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { - protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase + protected def ansiEnabled: Boolean + + protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = { + v match { + case lit: Expression => Cast(lit, targetType, timeZoneId, ansiEnabled) + case _ => Cast(Literal(v), targetType, timeZoneId, ansiEnabled) + } + } // expected cannot be null protected def checkCast(v: Any, expected: Any): Unit = { @@ -58,7 +66,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null) } - protected def verifyCastFailure(c: CastBase, optionalExpectedMsg: Option[String] = None): Unit = { + protected def verifyCastFailure(c: Cast, optionalExpectedMsg: Option[String] = None): Unit = { val typeCheckResult = c.checkInputDataTypes() assert(typeCheckResult.isFailure) assert(typeCheckResult.isInstanceOf[TypeCheckFailure]) @@ -66,20 +74,15 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { if (optionalExpectedMsg.isDefined) { assert(message.contains(optionalExpectedMsg.get)) - } else if (setConfigurationHint.nonEmpty) { - assert(message.contains("with ANSI mode on")) - assert(message.contains(setConfigurationHint)) } else { assert("cannot cast [a-zA-Z]+ to [a-zA-Z]+".r.findFirstIn(message).isDefined) + if (ansiEnabled) { + assert(message.contains("with ANSI mode on")) + assert(message.contains(s"set ${SQLConf.ANSI_ENABLED.key} as false")) + } } } - // Whether the test suite is for TryCast. If yes, there is no exceptions and the result is - // always nullable. - protected def isTryCast: Boolean = false - - protected def setConfigurationHint: String = "" - test("null cast") { import DataTypeTestUtils._ @@ -281,8 +284,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } test("cast from string") { - assert(cast("abcdef", StringType).nullable === isTryCast) - assert(cast("abcdef", BinaryType).nullable === isTryCast) + assert(!cast("abcdef", StringType).nullable) + assert(!cast("abcdef", BinaryType).nullable) assert(cast("abcdef", BooleanType).nullable) assert(cast("abcdef", TimestampType).nullable) assert(cast("abcdef", LongType).nullable) @@ -981,14 +984,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { DayTimeIntervalType()), StringType), ansiInterval) } - if (!isTryCast) { - Seq("INTERVAL '-106751991 04:00:54.775809' DAY TO SECOND", - "INTERVAL '106751991 04:00:54.775808' DAY TO SECOND").foreach { interval => - val e = intercept[ArithmeticException] { - cast(Literal.create(interval), DayTimeIntervalType()).eval() - }.getMessage - assert(e.contains("long overflow")) - } + Seq("INTERVAL '-106751991 04:00:54.775809' DAY TO SECOND", + "INTERVAL '106751991 04:00:54.775808' DAY TO SECOND").foreach { interval => + val e = intercept[ArithmeticException] { + cast(Literal.create(interval), DayTimeIntervalType()).eval() + }.getMessage + assert(e.contains("long overflow")) } Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, Long.MinValue + 1, @@ -1027,15 +1028,13 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { YearMonthIntervalType()), StringType), ansiInterval) } - if (!isTryCast) { - Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH") - .foreach { interval => - val e = intercept[IllegalArgumentException] { 
- cast(Literal.create(interval), YearMonthIntervalType()).eval() - }.getMessage - assert(e.contains("Error parsing interval year-month string: integer overflow")) - } - } + Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH") + .foreach { interval => + val e = intercept[IllegalArgumentException] { + cast(Literal.create(interval), YearMonthIntervalType()).eval() + }.getMessage + assert(e.contains("Error parsing interval year-month string: integer overflow")) + } Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue) .foreach { period => @@ -1098,9 +1097,27 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } } - if (!isTryCast) { - Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval => - val dataType = YearMonthIntervalType() + Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval => + val dataType = YearMonthIntervalType() + val e = intercept[IllegalArgumentException] { + cast(Literal.create(interval), dataType).eval() + }.getMessage + assert(e.contains(s"Interval string does not match year-month format of " + + s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + .map(format => s"`$format`").mkString(", ")} " + + s"when cast to ${dataType.typeName}: $interval")) + } + Seq(("1", YearMonthIntervalType(YEAR, MONTH)), + ("1", YearMonthIntervalType(YEAR, MONTH)), + ("1-1", YearMonthIntervalType(YEAR)), + ("1-1", YearMonthIntervalType(MONTH)), + ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(YEAR)), + ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(MONTH)), + ("INTERVAL '1' YEAR", YearMonthIntervalType(YEAR, MONTH)), + ("INTERVAL '1' YEAR", YearMonthIntervalType(MONTH)), + ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR)), + ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR, MONTH))) + .foreach { case (interval, dataType) => val e = intercept[IllegalArgumentException] { cast(Literal.create(interval), dataType).eval() }.getMessage @@ -1109,26 +1126,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { .map(format => s"`$format`").mkString(", ")} " + s"when cast to ${dataType.typeName}: $interval")) } - Seq(("1", YearMonthIntervalType(YEAR, MONTH)), - ("1", YearMonthIntervalType(YEAR, MONTH)), - ("1-1", YearMonthIntervalType(YEAR)), - ("1-1", YearMonthIntervalType(MONTH)), - ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(YEAR)), - ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(MONTH)), - ("INTERVAL '1' YEAR", YearMonthIntervalType(YEAR, MONTH)), - ("INTERVAL '1' YEAR", YearMonthIntervalType(MONTH)), - ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR)), - ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR, MONTH))) - .foreach { case (interval, dataType) => - val e = intercept[IllegalArgumentException] { - cast(Literal.create(interval), dataType).eval() - }.getMessage - assert(e.contains(s"Interval string does not match year-month format of " + - s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${dataType.typeName}: $interval")) - } - } } test("SPARK-35735: Take into account day-time interval fields in cast") { @@ -1218,63 +1215,61 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Literal.create(interval), dataType), dt) } - if (!isTryCast) { - Seq( - ("INTERVAL '1 01:01:01.12345' DAY TO SECOND", DayTimeIntervalType(DAY, HOUR)), - ("INTERVAL '1 
01:01:01.12345' DAY TO HOUR", DayTimeIntervalType(DAY, SECOND)), - ("INTERVAL '1 01:01:01.12345' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)), - ("1 01:01:01.12345", DayTimeIntervalType(DAY, DAY)), - ("1 01:01:01.12345", DayTimeIntervalType(DAY, HOUR)), - ("1 01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)), - - ("INTERVAL '01:01:01.12345' HOUR TO SECOND", DayTimeIntervalType(DAY, HOUR)), - ("INTERVAL '01:01:01.12345' HOUR TO HOUR", DayTimeIntervalType(DAY, SECOND)), - ("INTERVAL '01:01:01.12345' HOUR TO MINUTE", DayTimeIntervalType(DAY, MINUTE)), - ("01:01:01.12345", DayTimeIntervalType(DAY, DAY)), - ("01:01:01.12345", DayTimeIntervalType(HOUR, HOUR)), - ("01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)), - ("INTERVAL '1.23' DAY", DayTimeIntervalType(DAY)), - ("INTERVAL '1.23' HOUR", DayTimeIntervalType(HOUR)), - ("INTERVAL '1.23' MINUTE", DayTimeIntervalType(MINUTE)), - ("INTERVAL '1.23' SECOND", DayTimeIntervalType(MINUTE)), - ("1.23", DayTimeIntervalType(DAY)), - ("1.23", DayTimeIntervalType(HOUR)), - ("1.23", DayTimeIntervalType(MINUTE)), - ("1.23", DayTimeIntervalType(MINUTE))) - .foreach { case (interval, dataType) => - val e = intercept[IllegalArgumentException] { - cast(Literal.create(interval), dataType).eval() - }.getMessage - assert(e.contains(s"Interval string does not match day-time format of " + - s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${dataType.typeName}: $interval, " + - s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + - "to restore the behavior before Spark 3.0.")) - } + Seq( + ("INTERVAL '1 01:01:01.12345' DAY TO SECOND", DayTimeIntervalType(DAY, HOUR)), + ("INTERVAL '1 01:01:01.12345' DAY TO HOUR", DayTimeIntervalType(DAY, SECOND)), + ("INTERVAL '1 01:01:01.12345' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)), + ("1 01:01:01.12345", DayTimeIntervalType(DAY, DAY)), + ("1 01:01:01.12345", DayTimeIntervalType(DAY, HOUR)), + ("1 01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)), + + ("INTERVAL '01:01:01.12345' HOUR TO SECOND", DayTimeIntervalType(DAY, HOUR)), + ("INTERVAL '01:01:01.12345' HOUR TO HOUR", DayTimeIntervalType(DAY, SECOND)), + ("INTERVAL '01:01:01.12345' HOUR TO MINUTE", DayTimeIntervalType(DAY, MINUTE)), + ("01:01:01.12345", DayTimeIntervalType(DAY, DAY)), + ("01:01:01.12345", DayTimeIntervalType(HOUR, HOUR)), + ("01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)), + ("INTERVAL '1.23' DAY", DayTimeIntervalType(DAY)), + ("INTERVAL '1.23' HOUR", DayTimeIntervalType(HOUR)), + ("INTERVAL '1.23' MINUTE", DayTimeIntervalType(MINUTE)), + ("INTERVAL '1.23' SECOND", DayTimeIntervalType(MINUTE)), + ("1.23", DayTimeIntervalType(DAY)), + ("1.23", DayTimeIntervalType(HOUR)), + ("1.23", DayTimeIntervalType(MINUTE)), + ("1.23", DayTimeIntervalType(MINUTE))) + .foreach { case (interval, dataType) => + val e = intercept[IllegalArgumentException] { + cast(Literal.create(interval), dataType).eval() + }.getMessage + assert(e.contains(s"Interval string does not match day-time format of " + + s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + .map(format => s"`$format`").mkString(", ")} " + + s"when cast to ${dataType.typeName}: $interval, " + + s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + + "to restore the behavior before Spark 3.0.")) + } - // Check first field outof bound - Seq(("INTERVAL '1067519911' DAY", DayTimeIntervalType(DAY)), - ("INTERVAL '10675199111 04' DAY TO HOUR", DayTimeIntervalType(DAY, HOUR)), 
- ("INTERVAL '1067519911 04:00' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)), - ("INTERVAL '1067519911 04:00:54.775807' DAY TO SECOND", DayTimeIntervalType()), - ("INTERVAL '25620477881' HOUR", DayTimeIntervalType(HOUR)), - ("INTERVAL '25620477881:00' HOUR TO MINUTE", DayTimeIntervalType(HOUR, MINUTE)), - ("INTERVAL '25620477881:00:54.775807' HOUR TO SECOND", DayTimeIntervalType(HOUR, SECOND)), - ("INTERVAL '1537228672801' MINUTE", DayTimeIntervalType(MINUTE)), - ("INTERVAL '1537228672801:54.7757' MINUTE TO SECOND", DayTimeIntervalType(MINUTE, SECOND)), - ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND))) - .foreach { case (interval, dataType) => - val e = intercept[IllegalArgumentException] { - cast(Literal.create(interval), dataType).eval() - }.getMessage - assert(e.contains(s"Interval string does not match day-time format of " + - s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${dataType.typeName}: $interval, " + - s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + - "to restore the behavior before Spark 3.0.")) - } - } + // Check first field outof bound + Seq(("INTERVAL '1067519911' DAY", DayTimeIntervalType(DAY)), + ("INTERVAL '10675199111 04' DAY TO HOUR", DayTimeIntervalType(DAY, HOUR)), + ("INTERVAL '1067519911 04:00' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)), + ("INTERVAL '1067519911 04:00:54.775807' DAY TO SECOND", DayTimeIntervalType()), + ("INTERVAL '25620477881' HOUR", DayTimeIntervalType(HOUR)), + ("INTERVAL '25620477881:00' HOUR TO MINUTE", DayTimeIntervalType(HOUR, MINUTE)), + ("INTERVAL '25620477881:00:54.775807' HOUR TO SECOND", DayTimeIntervalType(HOUR, SECOND)), + ("INTERVAL '1537228672801' MINUTE", DayTimeIntervalType(MINUTE)), + ("INTERVAL '1537228672801:54.7757' MINUTE TO SECOND", DayTimeIntervalType(MINUTE, SECOND)), + ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND))) + .foreach { case (interval, dataType) => + val e = intercept[IllegalArgumentException] { + cast(Literal.create(interval), dataType).eval() + }.getMessage + assert(e.contains(s"Interval string does not match day-time format of " + + s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + .map(format => s"`$format`").mkString(", ")} " + + s"when cast to ${dataType.typeName}: $interval, " + + s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + + "to restore the behavior before Spark 3.0.")) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala similarity index 98% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala index 630c45adba1b3..4e4bc096deac5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala @@ -35,26 +35,10 @@ import org.apache.spark.unsafe.types.UTF8String /** * Test suite for data type casting expression [[Cast]] with ANSI mode disabled. - * Note: for new test cases that work for [[Cast]], [[AnsiCast]] and [[TryCast]], please add them - * in `CastSuiteBase` instead of this file to ensure the test coverage. 
*/ -class CastSuite extends CastSuiteBase { - override def beforeAll(): Unit = { - super.beforeAll() - SQLConf.get.setConf(SQLConf.ANSI_ENABLED, false) - } - - override def afterAll(): Unit = { - super.afterAll() - SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED) - } +class CastWithAnsiOffSuite extends CastSuiteBase { - override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = { - v match { - case lit: Expression => Cast(lit, targetType, timeZoneId) - case _ => Cast(Literal(v), targetType, timeZoneId) - } - } + override def ansiEnabled: Boolean = false test("null cast #2") { import DataTypeTestUtils._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala similarity index 85% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala index 84f0d5c59aaea..f2cfc52998462 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala @@ -27,19 +27,15 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} import org.apache.spark.sql.errors.QueryErrorsBase -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * Test suite base for - * 1. [[Cast]] with ANSI mode enabled - * 2. [[AnsiCast]] - * 3. [[TryCast]] - * Note: for new test cases that work for [[Cast]], [[AnsiCast]] and [[TryCast]], please add them - * in `CastSuiteBase` instead of this file to ensure the test coverage. + * Test suite for data type casting expression [[Cast]] with ANSI mode enabled. 
*/ -abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase { +class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase { + + override def ansiEnabled: Boolean = true private def testIntMaxAndMin(dt: DataType): Unit = { assert(Seq(IntegerType, ShortType, ByteType).contains(dt)) @@ -339,25 +335,21 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase { { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(ret.resolved == !isTryCast) - if (!isTryCast) { - checkExceptionInExpression[SparkRuntimeException]( - ret, """cannot be cast to "BOOLEAN"""") - } + assert(ret.resolved) + checkExceptionInExpression[SparkRuntimeException]( + ret, """cannot be cast to "BOOLEAN"""") } } test("cast from array III") { - if (!isTryCast) { - val from: DataType = ArrayType(DoubleType, containsNull = false) - val array = Literal.create(Seq(1.0, 2.0), from) - val to: DataType = ArrayType(IntegerType, containsNull = false) - val answer = Literal.create(Seq(1, 2), to).value - checkEvaluation(cast(array, to), answer) + val from: DataType = ArrayType(DoubleType, containsNull = false) + val array = Literal.create(Seq(1.0, 2.0), from) + val to: DataType = ArrayType(IntegerType, containsNull = false) + val answer = Literal.create(Seq(1, 2), to).value + checkEvaluation(cast(array, to), answer) - val overflowArray = Literal.create(Seq(Int.MaxValue + 1.0D), from) - checkExceptionInExpression[ArithmeticException](cast(overflowArray, to), "overflow") - } + val overflowArray = Literal.create(Seq(Int.MaxValue + 1.0D), from) + checkExceptionInExpression[ArithmeticException](cast(overflowArray, to), "overflow") } test("cast from map II") { @@ -386,48 +378,40 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase { { val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(ret.resolved == !isTryCast) - if (!isTryCast) { - checkExceptionInExpression[NumberFormatException]( - ret, - castErrMsg("a", IntegerType)) - } + assert(ret.resolved) + checkExceptionInExpression[NumberFormatException]( + ret, + castErrMsg("a", IntegerType)) } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(ret.resolved == !isTryCast) - if (!isTryCast) { - checkExceptionInExpression[SparkRuntimeException]( - ret, - castErrMsg("123", BooleanType)) - } + assert(ret.resolved) + checkExceptionInExpression[SparkRuntimeException]( + ret, + castErrMsg("123", BooleanType)) } { val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(ret.resolved == !isTryCast) - if (!isTryCast) { - checkExceptionInExpression[NumberFormatException]( - ret, - castErrMsg("a", IntegerType)) - } + assert(ret.resolved) + checkExceptionInExpression[NumberFormatException]( + ret, + castErrMsg("a", IntegerType)) } } test("cast from map III") { - if (!isTryCast) { - val from: DataType = MapType(DoubleType, DoubleType, valueContainsNull = false) - val map = Literal.create(Map(1.0 -> 2.0), from) - val to: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false) - val answer = Literal.create(Map(1 -> 2), to).value - checkEvaluation(cast(map, to), answer) - - Seq( - Literal.create(Map((Int.MaxValue + 1.0) -> 2.0), from), - Literal.create(Map(1.0 -> (Int.MinValue - 1.0)), from)).foreach { overflowMap => - checkExceptionInExpression[ArithmeticException](cast(overflowMap, to), "overflow") - } + val from: DataType = MapType(DoubleType, DoubleType, 
valueContainsNull = false) + val map = Literal.create(Map(1.0 -> 2.0), from) + val to: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false) + val answer = Literal.create(Map(1 -> 2), to).value + checkEvaluation(cast(map, to), answer) + + Seq( + Literal.create(Map((Int.MaxValue + 1.0) -> 2.0), from), + Literal.create(Map(1.0 -> (Int.MinValue - 1.0)), from)).foreach { overflowMap => + checkExceptionInExpression[ArithmeticException](cast(overflowMap, to), "overflow") } } @@ -487,26 +471,22 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase { StructField("a", BooleanType, nullable = true), StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) - assert(ret.resolved == !isTryCast) - if (!isTryCast) { - checkExceptionInExpression[SparkRuntimeException]( - ret, - castErrMsg("123", BooleanType)) - } + assert(ret.resolved) + checkExceptionInExpression[SparkRuntimeException]( + ret, + castErrMsg("123", BooleanType)) } } test("cast from struct III") { - if (!isTryCast) { - val from: DataType = StructType(Seq(StructField("a", DoubleType, nullable = false))) - val struct = Literal.create(InternalRow(1.0), from) - val to: DataType = StructType(Seq(StructField("a", IntegerType, nullable = false))) - val answer = Literal.create(InternalRow(1), to).value - checkEvaluation(cast(struct, to), answer) + val from: DataType = StructType(Seq(StructField("a", DoubleType, nullable = false))) + val struct = Literal.create(InternalRow(1.0), from) + val to: DataType = StructType(Seq(StructField("a", IntegerType, nullable = false))) + val answer = Literal.create(InternalRow(1), to).value + checkEvaluation(cast(struct, to), answer) - val overflowStruct = Literal.create(InternalRow(Int.MaxValue + 1.0), from) - checkExceptionInExpression[ArithmeticException](cast(overflowStruct, to), "overflow") - } + val overflowStruct = Literal.create(InternalRow(Int.MaxValue + 1.0), from) + checkExceptionInExpression[ArithmeticException](cast(overflowStruct, to), "overflow") } test("complex casting") { @@ -533,12 +513,10 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase { StructType(Seq( StructField("l", LongType, nullable = true))))))) - assert(ret.resolved === !isTryCast) - if (!isTryCast) { - checkExceptionInExpression[NumberFormatException]( - ret, - castErrMsg("true", IntegerType)) - } + assert(ret.resolved) + checkExceptionInExpression[NumberFormatException]( + ret, + castErrMsg("true", IntegerType)) } test("ANSI mode: cast string to timestamp with parse error") { @@ -599,28 +577,3 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase { } } } - -/** - * Test suite for data type casting expression [[Cast]] with ANSI mode disabled. 
- */ -class CastSuiteWithAnsiModeOn extends AnsiCastSuiteBase { - override def beforeAll(): Unit = { - super.beforeAll() - SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true) - } - - override def afterAll(): Unit = { - super.afterAll() - SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED) - } - - override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = { - v match { - case lit: Expression => Cast(lit, targetType, timeZoneId) - case _ => Cast(Literal(v), targetType, timeZoneId) - } - } - - override def setConfigurationHint: String = - s"set ${SQLConf.ANSI_ENABLED.key} as false" -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala index bb9ab88894741..bb66a9fd24a96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala @@ -17,40 +17,65 @@ package org.apache.spark.sql.catalyst.expressions -import scala.reflect.ClassTag - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +// A test suite to check analysis behaviors of `TryCast`. +class TryCastSuite extends SparkFunSuite { -class TryCastSuite extends AnsiCastSuiteBase { - override protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String]) = { + private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): TryCast = { v match { case lit: Expression => TryCast(lit, targetType, timeZoneId) case _ => TryCast(Literal(v), targetType, timeZoneId) } } - override def isTryCast: Boolean = true - - override protected def setConfigurationHint: String = "" - - override def checkExceptionInExpression[T <: Throwable : ClassTag]( - expression: => Expression, - inputRow: InternalRow, - expectedErrMsg: String): Unit = { - checkEvaluation(expression, null, inputRow) + test("print string") { + assert(TryCast(Literal("1"), IntegerType).toString == "try_cast(1 as int)") + assert(TryCast(Literal("1"), IntegerType).sql == "TRY_CAST('1' AS INT)") } - override def checkCastToBooleanError(l: Literal, to: DataType, tryCastResult: Any): Unit = { - checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value)) + test("nullability") { + assert(cast("abcdef", StringType).nullable) + assert(cast("abcdef", BinaryType).nullable) } - override def checkCastToNumericError(l: Literal, to: DataType, - expectedDataTypeInErrorMsg: DataType, tryCastResult: Any): Unit = { - checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value)) + test("only require timezone for datetime types") { + assert(cast("abc", IntegerType).resolved) + assert(!cast("abc", TimestampType).resolved) + assert(cast("abc", TimestampType, UTC_OPT).resolved) } - test("try_cast: to_string") { - assert(TryCast(Literal("1"), IntegerType).toString == "try_cast(1 as int)") + test("element type nullability") { + val array = Literal.create(Seq("123", "true"), + ArrayType(StringType, containsNull = false)) + // array element can be null after try_cast which violates the target type. 
+    val c1 = cast(array, ArrayType(BooleanType, containsNull = false))
+    assert(!c1.resolved)
+
+    val map = Literal.create(Map("a" -> "123", "b" -> "true"),
+      MapType(StringType, StringType, valueContainsNull = false))
+    // key can be null after try_cast which violates the map key requirement.
+    val c2 = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
+    assert(!c2.resolved)
+    // map value can be null after try_cast which violates the target type.
+    val c3 = cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
+    assert(!c3.resolved)
+
+    val struct = Literal.create(
+      InternalRow(
+        UTF8String.fromString("123"),
+        UTF8String.fromString("true")),
+      new StructType()
+        .add("a", StringType, nullable = true)
+        .add("b", StringType, nullable = true))
+    // struct field `b` can be null after try_cast which violates the target type.
+    val c4 = cast(struct, new StructType()
+      .add("a", BooleanType, nullable = true)
+      .add("b", BooleanType, nullable = false))
+    assert(!c4.resolved)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index dc92cf24ab1e7..e1d1f064e34c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -146,7 +146,7 @@ class ConstantFoldingSuite extends PlanTest {
       testRelation
         .select(
           Cast(Literal("2"), IntegerType) + Literal(3) + $"a" as "c1",
-          Coalesce(Seq(TryCast(Literal("abc"), IntegerType), Literal(3))) as "c2")
+          Coalesce(Seq(TryCast(Literal("abc"), IntegerType).replacement, Literal(3))) as "c2")
 
     val optimized = Optimize.execute(originalQuery.analyze)
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 67bb72c187802..95e5582cb8c03 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -1146,8 +1146,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     // client-side filtering cannot be used with TimeZoneAwareExpression.
     def hasTimeZoneAwareExpression(e: Expression): Boolean = {
       e.exists {
-        case cast: CastBase => cast.needsTimeZone
-        case tz: TimeZoneAwareExpression => !tz.isInstanceOf[CastBase]
+        case cast: Cast => cast.needsTimeZone
+        case tz: TimeZoneAwareExpression => !tz.isInstanceOf[Cast]
         case _ => false
       }
     }
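
For readers unfamiliar with `RuntimeReplaceable`, here is a minimal, self-contained sketch of the pattern this patch adopts. The class and object names (`TryCastToInt`, `TryEvalExpr`, `AnsiCastToInt`, `Replaceable`, `Lit`) are illustrative stand-ins invented for this example, not the real Catalyst expressions: after this change, `TryCast` carries no evaluation logic of its own and is expanded during analysis into `TryEval(Cast(..., ansiEnabled = true))`, so only the replacement is ever executed.

```scala
// Toy model of the RuntimeReplaceable pattern (simplified stand-ins, not Spark code).
sealed trait Expr { def eval(): Any }

case class Lit(value: Any) extends Expr { def eval(): Any = value }

// Stand-in for Cast with ansiEnabled = true: fails loudly on bad input.
case class AnsiCastToInt(child: Expr) extends Expr {
  def eval(): Any = child.eval().toString.trim.toInt // throws on non-numeric input
}

// Stand-in for TryEval: turns any evaluation failure of the child into null.
case class TryEvalExpr(child: Expr) extends Expr {
  def eval(): Any = try child.eval() catch { case _: Exception => null }
}

// Stand-in for RuntimeReplaceable: the expression itself never executes;
// the engine only ever evaluates `replacement`.
trait Replaceable extends Expr {
  def replacement: Expr
  final def eval(): Any = replacement.eval()
}

// After this PR, TryCast is just "TryEval of an ANSI Cast".
case class TryCastToInt(child: Expr) extends Replaceable {
  lazy val replacement: Expr = TryEvalExpr(AnsiCastToInt(child))
}

object TryCastDemo extends App {
  println(TryCastToInt(Lit("10")).eval())   // 10
  println(TryCastToInt(Lit("oops")).eval()) // null
}
```

Because only the replacement executes, tests for the replaced expression can be limited to analysis behavior (resolution, nullability, type checks), which is why `TryCastSuite` above shrinks to those checks.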