diff --git a/spark/src/main/resources/error/delta-error-classes.json b/spark/src/main/resources/error/delta-error-classes.json
index ab83161db6b..313ab4d9569 100644
--- a/spark/src/main/resources/error/delta-error-classes.json
+++ b/spark/src/main/resources/error/delta-error-classes.json
@@ -272,6 +272,14 @@
     ],
     "sqlState" : "0A000"
   },
+  "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE" : {
+    "message" : [
+      "Failed to write a value of <sourceType> type into the <targetType> type column <columnName> due to an overflow.",
+      "Use `try_cast` on the input value to tolerate overflow and return NULL instead.",
+      "If necessary, set <storeAssignmentPolicyFlag> to \"LEGACY\" to bypass this error or set <updateAndMergeCastingFollowsAnsiEnabledFlag> to true to revert to the old behaviour and follow <ansiEnabledFlag> in UPDATE and MERGE."
+    ],
+    "sqlState" : "22003"
+  },
   "DELTA_CDC_NOT_ALLOWED_IN_THIS_VERSION" : {
     "message" : [
       "Configuration delta.enableChangeDataFeed cannot be set. Change data feed from Delta is not yet available."
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala
index 342dbf9b45a..f032aedff2e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
@@ -118,7 +119,8 @@ trait DocsPath {
  */
 trait DeltaErrorsBase
     extends DocsPath
-    with DeltaLogging {
+    with DeltaLogging
+    with QueryErrorsBase {
 
   def baseDocsPath(spark: SparkSession): String =
     baseDocsPath(spark.sparkContext.getConf)
@@ -618,6 +620,22 @@ trait DeltaErrorsBase
     )
   }
 
+  def castingCauseOverflowErrorInTableWrite(
+      from: DataType,
+      to: DataType,
+      columnName: String): ArithmeticException = {
+    new DeltaArithmeticException(
+      errorClass = "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE",
+      messageParameters = Map(
+        "sourceType" -> toSQLType(from),
+        "targetType" -> toSQLType(to),
+        "columnName" -> toSQLId(columnName),
+        "storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key,
+        "updateAndMergeCastingFollowsAnsiEnabledFlag" ->
+          DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key,
+        "ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key))
+  }
+
   def notADeltaTable(table: String): Throwable = {
     new DeltaAnalysisException(errorClass = "DELTA_NOT_A_DELTA_TABLE",
       messageParameters = Array(table))
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala
index 761ff228c81..5e06631c34a 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala
@@ -81,3 +81,11 @@ class DeltaParseException(
     ParserUtils.position(ctx.getStop)
   ) with DeltaThrowable
 
+class DeltaArithmeticException(
+    errorClass: String,
+    messageParameters: Map[String, String]) extends ArithmeticException with DeltaThrowable {
+  override def getErrorClass: String = errorClass
+
+  override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava
+}
+
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala b/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala
index e1718b6739c..501967b08a2 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala
@@ -21,10 +21,13 @@ import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.util.AnalysisHelper
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 /**
@@ -405,7 +408,109 @@ trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper {
     }
   }
 
+  /**
+   * Replaces 'CastSupport.cast'. Selects a cast based on 'spark.sql.storeAssignmentPolicy' if
+   * 'spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag' is false, and based on
+   * 'spark.sql.ansi.enabled' otherwise.
+   */
   private def cast(child: Expression, dataType: DataType, columnName: String): Expression = {
-    Cast(child, dataType, Option(conf.sessionLocalTimeZone))
+    if (conf.getConf(DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG)) {
+      return Cast(child, dataType, Option(conf.sessionLocalTimeZone))
+    }
+
+    conf.storeAssignmentPolicy match {
+      case SQLConf.StoreAssignmentPolicy.LEGACY =>
+        Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = false)
+      case SQLConf.StoreAssignmentPolicy.ANSI =>
+        val cast = Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = true)
+        if (canCauseCastOverflow(cast)) {
+          CheckOverflowInTableWrite(cast, columnName)
+        } else {
+          cast
+        }
+      case SQLConf.StoreAssignmentPolicy.STRICT =>
+        UpCast(child, dataType)
+    }
+  }
+
+  private def containsIntegralOrDecimalType(dt: DataType): Boolean = dt match {
+    case _: IntegralType | _: DecimalType => true
+    case a: ArrayType => containsIntegralOrDecimalType(a.elementType)
+    case m: MapType =>
+      containsIntegralOrDecimalType(m.keyType) || containsIntegralOrDecimalType(m.valueType)
+    case s: StructType =>
+      s.fields.exists(sf => containsIntegralOrDecimalType(sf.dataType))
+    case _ => false
+  }
+
+  private def canCauseCastOverflow(cast: Cast): Boolean = {
+    containsIntegralOrDecimalType(cast.dataType) &&
+      !Cast.canUpCast(cast.child.dataType, cast.dataType)
+  }
+}
+
+case class CheckOverflowInTableWrite(child: Expression, columnName: String)
+  extends UnaryExpression {
+  override protected def withNewChildInternal(newChild: Expression): Expression = {
+    copy(child = newChild)
   }
+
+  private def getCast: Option[Cast] = child match {
+    case c: Cast => Some(c)
+    case ExpressionProxy(c: Cast, _, _) => Some(c)
+    case _ => None
+  }
+
+  override def eval(input: InternalRow): Any = try {
+    child.eval(input)
+  } catch {
+    case e: ArithmeticException =>
+      getCast match {
+        case Some(cast) =>
+          throw DeltaErrors.castingCauseOverflowErrorInTableWrite(
+            cast.child.dataType,
+            cast.dataType,
+            columnName)
+        case None => throw e
+      }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    getCast match {
+      case Some(child) => doGenCodeWithBetterErrorMsg(ctx, ev, child)
+      case None => child.genCode(ctx)
+    }
+  }
+
+  def doGenCodeWithBetterErrorMsg(ctx: CodegenContext, ev: ExprCode, child: Cast): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val exceptionClass = classOf[ArithmeticException].getCanonicalName
+    assert(child.isInstanceOf[Cast])
+    val cast = child.asInstanceOf[Cast]
+    val fromDt =
+      ctx.addReferenceObj("from", cast.child.dataType, cast.child.dataType.getClass.getName)
+    val toDt = ctx.addReferenceObj("to", child.dataType, child.dataType.getClass.getName)
+    val col = ctx.addReferenceObj("colName", columnName, "java.lang.String")
+    // scalastyle:off line.size.limit
+    ev.copy(code =
+      code"""
+        boolean ${ev.isNull} = true;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        try {
+          ${childGen.code}
+          ${ev.isNull} = ${childGen.isNull};
+          ${ev.value} = ${childGen.value};
+        } catch ($exceptionClass e) {
+          throw org.apache.spark.sql.delta.DeltaErrors
+            .castingCauseOverflowErrorInTableWrite($fromDt, $toDt, $col);
+        }"""
+    )
+    // scalastyle:on line.size.limit
+  }
+
+  override def dataType: DataType = child.dataType
+
+  override def sql: String = child.sql
+
+  override def toString: String = child.toString
 }
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
index 84c3211ef65..bd3c42fbc06 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
@@ -23,7 +23,6 @@ import org.apache.spark.internal.config.ConfigBuilder
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
 
 /**
  * [[SQLConf]] entries for Delta features.
@@ -1254,6 +1253,15 @@ trait DeltaSQLConfBase
       .intConf
       .createWithDefault(100 * 1000)
 
+  val UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG =
+    buildConf("updateAndMergeCastingFollowsAnsiEnabledFlag")
+      .internal()
+      .doc("""If false, casting behaviour in implicit casts in UPDATE and MERGE follows
+             |'spark.sql.storeAssignmentPolicy'. If true, these casts follow 'ansi.enabled'. This
+             |was the default before Delta 3.5.""".stripMargin)
+      .booleanConf
+      .createWithDefault(false)
+
 }
 
 object DeltaSQLConf extends DeltaSQLConfBase
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala
index 83594e18513..f0821527747 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala
@@ -52,6 +52,7 @@ import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 import org.apache.spark.sql.types.{CalendarIntervalType, DataTypes, DateType, IntegerType, StringType, StructField, StructType, TimestampNTZType}
@@ -60,7 +61,8 @@ trait DeltaErrorsSuiteBase extends QueryTest
   with SharedSparkSession
   with GivenWhenThen
   with DeltaSQLCommandTest
-  with SQLTestUtils {
+  with SQLTestUtils
+  with QueryErrorsBase {
 
   val MAX_URL_ACCESS_RETRIES = 3
   val path = "/sample/path"
@@ -288,6 +290,24 @@ trait DeltaErrorsSuiteBase
       assert(
         e.getMessage == s"$table is a view. Writes to a view are not supported.")
     }
+    {
+      val sourceType = IntegerType
+      val targetType = DateType
+      val columnName = "column_name"
+      val e = intercept[DeltaArithmeticException] {
+        throw DeltaErrors.castingCauseOverflowErrorInTableWrite(sourceType, targetType, columnName)
+      }
+      assert(e.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE")
+      assert(e.getSqlState == "22003")
+      assert(e.getMessageParameters.get("sourceType") == toSQLType(sourceType))
+      assert(e.getMessageParameters.get("targetType") == toSQLType(targetType))
+      assert(e.getMessageParameters.get("columnName") == toSQLId(columnName))
+      assert(e.getMessageParameters.get("storeAssignmentPolicyFlag")
+        == SQLConf.STORE_ASSIGNMENT_POLICY.key)
+      assert(e.getMessageParameters.get("updateAndMergeCastingFollowsAnsiEnabledFlag")
+        == DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key)
+      assert(e.getMessageParameters.get("ansiEnabledFlag") == SQLConf.ANSI_ENABLED.key)
+    }
     {
       val e = intercept[DeltaAnalysisException] {
         throw DeltaErrors.invalidColumnName(name = "col-1")
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/ImplicitDMLCastingSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/ImplicitDMLCastingSuite.scala
new file mode 100644
index 00000000000..30c67d97b1c
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/ImplicitDMLCastingSuite.scala
@@ -0,0 +1,284 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta
+
+import scala.annotation.tailrec
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.delta.sources.DeltaSQLConf
+import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
+
+import org.apache.spark.{SparkConf, SparkThrowable}
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Tests for casts that are implicitly added in DML commands modifying Delta tables.
+ * These casts are added to convert values to the schema of a table.
+ * INSERT operations are excluded as they are covered by InsertSuite and InsertSuiteEdge.
+ */
+class ImplicitDMLCastingSuite extends QueryTest
+  with DeltaSQLCommandTest {
+
+  private case class TestConfiguration(
+      sourceType: String,
+      sourceTypeInErrorMessage: String,
+      targetType: String,
+      targetTypeInErrorMessage: String,
+      validValue: String,
+      overflowValue: String)
+
+  private case class SqlConfiguration(
+      followAnsiEnabled: Boolean,
+      ansiEnabled: Boolean,
+      storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value) {
+
+    def withSqlSettings(f: => Unit): Unit =
+      withSQLConf(
+        DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key
+          -> followAnsiEnabled.toString,
+        SQLConf.STORE_ASSIGNMENT_POLICY.key -> storeAssignmentPolicy.toString,
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString)(f)
+
+    override def toString: String =
+      s"followAnsiEnabled: $followAnsiEnabled, ansiEnabled: $ansiEnabled," +
+        s" storeAssignmentPolicy: $storeAssignmentPolicy"
+  }
+
+  private val testConfigurations = Seq(
+    TestConfiguration(sourceType = "INT", sourceTypeInErrorMessage = "INT",
+      targetType = "TINYINT", targetTypeInErrorMessage = "TINYINT",
+      validValue = "1", overflowValue = Int.MaxValue.toString),
+    TestConfiguration(sourceType = "INT", sourceTypeInErrorMessage = "INT",
+      targetType = "SMALLINT", targetTypeInErrorMessage = "SMALLINT",
+      validValue = "1", overflowValue = Int.MaxValue.toString),
+    TestConfiguration(sourceType = "BIGINT", sourceTypeInErrorMessage = "BIGINT",
+      targetType = "INT", targetTypeInErrorMessage = "INT",
+      validValue = "1", overflowValue = Long.MaxValue.toString),
+    TestConfiguration(sourceType = "DOUBLE", sourceTypeInErrorMessage = "DOUBLE",
+      targetType = "BIGINT", targetTypeInErrorMessage = "BIGINT",
+      validValue = "1", overflowValue = "12345678901234567890D"),
+    TestConfiguration(sourceType = "BIGINT", sourceTypeInErrorMessage = "BIGINT",
+      targetType = "DECIMAL(7,2)", targetTypeInErrorMessage = "DECIMAL(7,2)",
+      validValue = "1", overflowValue = Long.MaxValue.toString),
+    TestConfiguration(sourceType = "Struct<value: BIGINT>", sourceTypeInErrorMessage = "BIGINT",
+      targetType = "Struct<value: INT>", targetTypeInErrorMessage = "INT",
+      validValue = "named_struct('value', 1)",
+      overflowValue = s"named_struct('value', ${Long.MaxValue.toString})"),
+    TestConfiguration(sourceType = "ARRAY<BIGINT>", sourceTypeInErrorMessage = "ARRAY<BIGINT>",
+      targetType = "ARRAY<INT>", targetTypeInErrorMessage = "ARRAY<INT>",
+      validValue = "ARRAY(1)", overflowValue = s"ARRAY(${Long.MaxValue.toString})")
+  )
+
+  @tailrec
+  private def arithmeticCause(exception: Throwable): Option[ArithmeticException] = {
+    exception match {
+      case arithmeticException: ArithmeticException => Some(arithmeticException)
+      case _ if exception.getCause != null => arithmeticCause(exception.getCause)
+      case _ => None
+    }
+  }
+
+  /**
+   * Validate that a custom error is thrown in case ansi.enabled is false, or a different
+   * overflow error in case ansi.enabled is true.
+   */
+  private def validateException(
+      exception: Throwable, sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = {
+    arithmeticCause(exception) match {
+      case Some(exception: DeltaArithmeticException) =>
+        assert(exception.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE")
+        assert(exception.getMessageParameters ==
+          Map("sourceType" -> ("\"" + testConfig.sourceTypeInErrorMessage + "\""),
+            "targetType" -> ("\"" + testConfig.targetTypeInErrorMessage + "\""),
+            "columnName" -> "`value`",
+            "storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key,
+            "updateAndMergeCastingFollowsAnsiEnabledFlag" ->
+              DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key,
+            "ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key).asJava)
+      case Some(exception: SparkThrowable) if sqlConfig.ansiEnabled =>
+        // With ANSI enabled the overflows are caught before the write operation.
+        assert(Seq("CAST_OVERFLOW", "NUMERIC_VALUE_OUT_OF_RANGE")
+          .contains(exception.getErrorClass))
+      case None => assert(false, "No arithmetic exception thrown.")
+      case Some(exception) =>
+        assert(false, s"Unexpected exception type: $exception")
+    }
+  }
+
+  Seq(true, false).foreach { followAnsiEnabled =>
+    Seq(true, false).foreach { ansiEnabled =>
+      Seq(SQLConf.StoreAssignmentPolicy.LEGACY, SQLConf.StoreAssignmentPolicy.ANSI)
+        .foreach { storeAssignmentPolicy =>
+          val sqlConfiguration =
+            SqlConfiguration(followAnsiEnabled, ansiEnabled, storeAssignmentPolicy)
+          testConfigurations.foreach { testConfiguration =>
+            updateTest(sqlConfiguration, testConfiguration)
+            mergeTests(sqlConfiguration, testConfiguration)
+            streamingMergeTest(sqlConfiguration, testConfiguration)
+          }
+        }
+    }
+  }
+
+  /** Test an UPDATE that requires casting the update value that is part of the SET clause. */
+  private def updateTest(
+      sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = {
+    val testName = s"UPDATE overflow targetType: ${testConfig.targetType} $sqlConfig"
+    test(testName) {
+      sqlConfig.withSqlSettings {
+        val tableName = "overflowTable"
+        withTable(tableName) {
+          sql(s"""CREATE TABLE $tableName USING DELTA
+                 |AS SELECT cast(${testConfig.validValue} AS ${testConfig.targetType}) AS value
+                 |""".stripMargin)
+          val updateCommand = s"UPDATE $tableName SET value = ${testConfig.overflowValue}"
+
+          val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) ||
+            (!sqlConfig.followAnsiEnabled &&
+              sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY)
+
+          if (legacyCasts) {
+            sql(updateCommand)
+          } else {
+            val exception = intercept[Throwable] {
+              sql(updateCommand)
+            }
+
+            validateException(exception, sqlConfig, testConfig)
+          }
+        }
+      }
+    }
+  }
+
+
+  /** Tests for MERGE with overflows caused by the different conditions. */
+  private def mergeTests(
+      sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = {
+    mergeTest(matchedCondition = s"WHEN MATCHED THEN UPDATE SET t.value = s.value",
+      sqlConfig, testConfig)
+
+    mergeTest(matchedCondition = s"WHEN NOT MATCHED THEN INSERT *", sqlConfig, testConfig)
+
+    mergeTest(matchedCondition =
+      s"WHEN NOT MATCHED BY SOURCE THEN UPDATE SET t.value = ${testConfig.overflowValue}",
+      sqlConfig, testConfig)
+  }
+
+  private def mergeTest(
+      matchedCondition: String,
+      sqlConfig: SqlConfiguration,
+      testConfig: TestConfiguration
+  ): Unit = {
+    val testName =
+      s"MERGE overflow in $matchedCondition targetType: ${testConfig.targetType} $sqlConfig"
+    test(testName) {
+      sqlConfig.withSqlSettings {
+        val targetTableName = "target_table"
+        val sourceViewName = "source_vice"
+        withTable(targetTableName) {
+          withTempView(sourceViewName) {
+            val numRows = 10
+            sql(s"""CREATE TABLE $targetTableName USING DELTA
+                   |AS SELECT col as key,
+                   |  cast(${testConfig.validValue} AS ${testConfig.targetType}) AS value
+                   |FROM explode(sequence(0, $numRows))""".stripMargin)
+            // The view maps the key space such that we get matched, not matched by source, and
+            // not matched by target rows.
+            sql(s"""CREATE TEMPORARY VIEW $sourceViewName
+                   |AS SELECT key + ($numRows / 2) AS key,
+                   |  cast(${testConfig.overflowValue} AS ${testConfig.sourceType}) AS value
+                   |FROM $targetTableName""".stripMargin)
+            val mergeCommand = s"""MERGE INTO $targetTableName t
+                                  |USING $sourceViewName s
+                                  |ON s.key = t.key
+                                  |$matchedCondition
+                                  |""".stripMargin
+            val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) ||
+              (!sqlConfig.followAnsiEnabled &&
+                sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY)
+
+            if (legacyCasts) {
+              sql(mergeCommand)
+            } else {
+              val exception = intercept[Throwable] {
+                sql(mergeCommand)
+              }
+
+              validateException(exception, sqlConfig, testConfig)
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /** A merge that is executed for each batch of a stream and has to cast values before insert. */
+  private def streamingMergeTest(
+      sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = {
+    val testName = s"Streaming MERGE overflow targetType: ${testConfig.targetType} $sqlConfig"
+    test(testName) {
+      sqlConfig.withSqlSettings {
+        val targetTableName = "target_table"
+        val sourceTableName = "source_table"
+        withTable(sourceTableName, targetTableName) {
+          sql(s"CREATE TABLE $targetTableName (key INT, value ${testConfig.targetType})" +
+            " USING DELTA")
+          sql(s"CREATE TABLE $sourceTableName (key INT, value ${testConfig.sourceType})" +
+            " USING DELTA")
+
+          def upsertToDelta(microBatchOutputDF: DataFrame, batchId: Long): Unit = {
+            microBatchOutputDF.createOrReplaceTempView("micro_batch_output")
+
+            microBatchOutputDF.sparkSession.sql(s"""MERGE INTO $targetTableName t
+              |USING micro_batch_output s
+              |ON s.key = t.key
+              |WHEN NOT MATCHED THEN INSERT *
+              |""".stripMargin)
+          }
+
+          val sourceStream = spark.readStream.table(sourceTableName)
+          val streamWriter =
+            sourceStream
+              .writeStream
+              .format("delta")
+              .foreachBatch(upsertToDelta _)
+              .outputMode("update")
+              .start()
+
+          sql(s"INSERT INTO $sourceTableName(key, value) VALUES(0, ${testConfig.overflowValue})")
+
+          val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) ||
+            (!sqlConfig.followAnsiEnabled &&
+              sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY)
+
+          if (legacyCasts) {
+            streamWriter.processAllAvailable()
+          } else {
+            val exception = intercept[Throwable] {
+              streamWriter.processAllAvailable()
+            }
+
+            validateException(exception, sqlConfig, testConfig)
+          }
+        }
+      }
+    }
+  }
+}
+
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
index 313baa24920..6920b00c094 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
@@ -3871,7 +3871,7 @@ abstract class MergeIntoSuiteBase
     ((0, 0) +: (1, 10) +: (2, 2) +: (3, 30) +: (5, null) +: Nil)
       .asInstanceOf[List[(Integer, Integer)]].toDF("key", "value"),
     // Disable ANSI as this test needs to cast string "notANumber" to int
-    confs = Seq(SQLConf.ANSI_ENABLED.key -> "false")
+    confs = Seq(SQLConf.STORE_ASSIGNMENT_POLICY.key -> "LEGACY")
   )
 
   // This is kinda bug-for-bug compatibility. It doesn't really make sense that infinity is casted
@@ -3887,7 +3887,7 @@ abstract class MergeIntoSuiteBase
     ((0, 0) +: (1, 10) +: (2, 2) +: (3, 30) +: (5, Int.MaxValue) +: Nil)
       .asInstanceOf[List[(Integer, Integer)]].toDF("key", "value"),
     // Disable ANSI as this test needs to cast Double.PositiveInfinity to int
-    confs = Seq(SQLConf.ANSI_ENABLED.key -> "false")
+    confs = Seq(SQLConf.STORE_ASSIGNMENT_POLICY.key -> "LEGACY")
   )
 
   testEvolution("extra nested column in source - insert")(