From 6d78d43470fb14bba264c5107d1f07b3beaacec4 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Wed, 26 Jul 2023 17:30:06 +0200 Subject: [PATCH] Use storeAssignmentPolicy for casts in DML commands Follow spark.sql.storeAssignmentPolicy instead of spark.sql.ansi.enabled for casting behaviour in UPDATE and MERGE. This will by default error out at runtime when an overflow happens. Closes https://github.com/delta-io/delta/pull/1938 GitOrigin-RevId: c960a0521df27daa6ee231e0a1022d8756496785 --- .../resources/error/delta-error-classes.json | 8 + .../apache/spark/sql/delta/DeltaErrors.scala | 20 +- .../sql/delta/DeltaSharedExceptions.scala | 8 + .../sql/delta/UpdateExpressionsSupport.scala | 109 ++++++- .../sql/delta/sources/DeltaSQLConf.scala | 10 +- .../spark/sql/delta/DeltaErrorsSuite.scala | 22 +- .../sql/delta/ImplicitDMLCastingSuite.scala | 284 ++++++++++++++++++ .../spark/sql/delta/MergeIntoSuiteBase.scala | 4 +- 8 files changed, 458 insertions(+), 7 deletions(-) create mode 100644 spark/src/test/scala/org/apache/spark/sql/delta/ImplicitDMLCastingSuite.scala diff --git a/spark/src/main/resources/error/delta-error-classes.json b/spark/src/main/resources/error/delta-error-classes.json index ab83161db6b..313ab4d9569 100644 --- a/spark/src/main/resources/error/delta-error-classes.json +++ b/spark/src/main/resources/error/delta-error-classes.json @@ -272,6 +272,14 @@ ], "sqlState" : "0A000" }, + "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE" : { + "message" : [ + "Failed to write a value of <sourceType> type into the <targetType> type column <columnName> due to an overflow.", + "Use `try_cast` on the input value to tolerate overflow and return NULL instead.", + "If necessary, set <storeAssignmentPolicyFlag> to \"LEGACY\" to bypass this error or set <updateAndMergeCastingFollowsAnsiEnabledFlag> to true to revert to the old behaviour and follow <ansiEnabledFlag> in UPDATE and MERGE." + ], + "sqlState" : "22003" + }, "DELTA_CDC_NOT_ALLOWED_IN_THIS_VERSION" : { "message" : [ "Configuration delta.enableChangeDataFeed cannot be set. Change data feed from Delta is not yet available."
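For context, a minimal sketch of the user-visible behaviour this error class introduces (hypothetical table name t, not part of the patch; assumes a Delta-enabled SparkSession with the default spark.sql.storeAssignmentPolicy=ANSI):

    // An implicit narrowing cast in UPDATE now fails at runtime instead of
    // silently wrapping the value, regardless of spark.sql.ansi.enabled.
    spark.sql("CREATE TABLE t (value TINYINT) USING DELTA")
    spark.sql("INSERT INTO t VALUES (1)")
    // Throws an ArithmeticException with error class
    // DELTA_CAST_OVERFLOW_IN_TABLE_WRITE under the default ANSI policy:
    spark.sql(s"UPDATE t SET value = ${Int.MaxValue}")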
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala index 342dbf9b45a..f032aedff2e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructField, StructType} @@ -118,7 +119,8 @@ trait DocsPath { */ trait DeltaErrorsBase extends DocsPath - with DeltaLogging { + with DeltaLogging + with QueryErrorsBase { def baseDocsPath(spark: SparkSession): String = baseDocsPath(spark.sparkContext.getConf) @@ -618,6 +620,22 @@ trait DeltaErrorsBase ) } + def castingCauseOverflowErrorInTableWrite( + from: DataType, + to: DataType, + columnName: String): ArithmeticException = { + new DeltaArithmeticException( + errorClass = "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE", + messageParameters = Map( + "sourceType" -> toSQLType(from), + "targetType" -> toSQLType(to), + "columnName" -> toSQLId(columnName), + "storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key, + "updateAndMergeCastingFollowsAnsiEnabledFlag" -> + DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key, + "ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key)) + } + def notADeltaTable(table: String): Throwable = { new DeltaAnalysisException(errorClass = "DELTA_NOT_A_DELTA_TABLE", messageParameters = Array(table)) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala index 761ff228c81..5e06631c34a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala @@ -81,3 +81,11 @@ class DeltaParseException( ParserUtils.position(ctx.getStop) ) with DeltaThrowable +class DeltaArithmeticException( + errorClass: String, + messageParameters: Map[String, String]) extends ArithmeticException with DeltaThrowable { + override def getErrorClass: String = errorClass + + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava +} + diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala b/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala index e1718b6739c..501967b08a2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala @@ -21,10 +21,13 @@ import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.delta.util.AnalysisHelper import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -405,7 +408,109 @@ trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper { } } + /** + * Replaces 'CastSupport.cast'. Selects a cast based on 'spark.sql.storeAssignmentPolicy' if + * 'spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag' is false, and based on + * 'spark.sql.ansi.enabled' otherwise. + */ private def cast(child: Expression, dataType: DataType, columnName: String): Expression = { - Cast(child, dataType, Option(conf.sessionLocalTimeZone)) + if (conf.getConf(DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG)) { + return Cast(child, dataType, Option(conf.sessionLocalTimeZone)) + } + + conf.storeAssignmentPolicy match { + case SQLConf.StoreAssignmentPolicy.LEGACY => + Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = false) + case SQLConf.StoreAssignmentPolicy.ANSI => + val cast = Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = true) + if (canCauseCastOverflow(cast)) { + CheckOverflowInTableWrite(cast, columnName) + } else { + cast + } + case SQLConf.StoreAssignmentPolicy.STRICT => + UpCast(child, dataType) + } + } + + private def containsIntegralOrDecimalType(dt: DataType): Boolean = dt match { + case _: IntegralType | _: DecimalType => true + case a: ArrayType => containsIntegralOrDecimalType(a.elementType) + case m: MapType => + containsIntegralOrDecimalType(m.keyType) || containsIntegralOrDecimalType(m.valueType) + case s: StructType => + s.fields.exists(sf => containsIntegralOrDecimalType(sf.dataType)) + case _ => false + } + + private def canCauseCastOverflow(cast: Cast): Boolean = { + containsIntegralOrDecimalType(cast.dataType) && + !Cast.canUpCast(cast.child.dataType, cast.dataType) + } +} + +case class CheckOverflowInTableWrite(child: Expression, columnName: String) + extends UnaryExpression { + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) } + + private def getCast: Option[Cast] = child match { + case c: Cast => Some(c) + case ExpressionProxy(c: Cast, _, _) => Some(c) + case _ => None + } + + override def eval(input: InternalRow): Any = try { + child.eval(input) + } catch { + case e: ArithmeticException => + getCast match { + case Some(cast) => + throw DeltaErrors.castingCauseOverflowErrorInTableWrite( + cast.child.dataType, + cast.dataType, + columnName) + case None => throw e + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + getCast match { + case Some(cast) => doGenCodeWithBetterErrorMsg(ctx, ev, cast) + case None => child.genCode(ctx) + } + } + + def doGenCodeWithBetterErrorMsg(ctx: CodegenContext, ev: ExprCode, cast: Cast): ExprCode = { + val childGen = cast.genCode(ctx) + val exceptionClass = classOf[ArithmeticException].getCanonicalName + val fromDt = + ctx.addReferenceObj("from", cast.child.dataType, cast.child.dataType.getClass.getName) + val toDt = ctx.addReferenceObj("to", cast.dataType, cast.dataType.getClass.getName) + val col = ctx.addReferenceObj("colName", columnName, "java.lang.String") + // scalastyle:off line.size.limit + ev.copy(code = + code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + try { + ${childGen.code} + ${ev.isNull} = 
${childGen.isNull}; + ${ev.value} = ${childGen.value}; + } catch ($exceptionClass e) { + throw org.apache.spark.sql.delta.DeltaErrors + .castingCauseOverflowErrorInTableWrite($fromDt, $toDt, $col); + }""" + ) + // scalastyle:on line.size.limit + } + + override def dataType: DataType = child.dataType + + override def sql: String = child.sql + + override def toString: String = child.toString } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala index 84c3211ef65..bd3c42fbc06 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala @@ -23,7 +23,6 @@ import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils /** * [[SQLConf]] entries for Delta features. @@ -1254,6 +1253,15 @@ .intConf .createWithDefault(100 * 1000) + val UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG = + buildConf("updateAndMergeCastingFollowsAnsiEnabledFlag") + .internal() + .doc("""If false, casting behaviour in implicit casts in UPDATE and MERGE follows + |'spark.sql.storeAssignmentPolicy'. If true, these casts follow 'spark.sql.ansi.enabled'. + |This was the default before Delta 3.5.""".stripMargin) + .booleanConf + .createWithDefault(false) + } object DeltaSQLConf extends DeltaSQLConfBase diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala index 83594e18513..f0821527747 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala @@ -52,6 +52,7 @@ import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types.{CalendarIntervalType, DataTypes, DateType, IntegerType, StringType, StructField, StructType, TimestampNTZType} @@ -60,7 +61,8 @@ trait DeltaErrorsSuiteBase extends QueryTest with SharedSparkSession with GivenWhenThen with DeltaSQLCommandTest - with SQLTestUtils { + with SQLTestUtils + with QueryErrorsBase { val MAX_URL_ACCESS_RETRIES = 3 val path = "/sample/path" @@ -288,6 +290,24 @@ trait DeltaErrorsSuiteBase assert( e.getMessage == s"$table is a view. 
Writes to a view are not supported.") } + { + val sourceType = IntegerType + val targetType = DateType + val columnName = "column_name" + val e = intercept[DeltaArithmeticException] { + throw DeltaErrors.castingCauseOverflowErrorInTableWrite(sourceType, targetType, columnName) + } + assert(e.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE") + assert(e.getSqlState == "22003") + assert(e.getMessageParameters.get("sourceType") == toSQLType(sourceType)) + assert(e.getMessageParameters.get("targetType") == toSQLType(targetType)) + assert(e.getMessageParameters.get("columnName") == toSQLId(columnName)) + assert(e.getMessageParameters.get("storeAssignmentPolicyFlag") + == SQLConf.STORE_ASSIGNMENT_POLICY.key) + assert(e.getMessageParameters.get("updateAndMergeCastingFollowsAnsiEnabledFlag") + == DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key) + assert(e.getMessageParameters.get("ansiEnabledFlag") == SQLConf.ANSI_ENABLED.key) + } { val e = intercept[DeltaAnalysisException] { throw DeltaErrors.invalidColumnName(name = "col-1") diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/ImplicitDMLCastingSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/ImplicitDMLCastingSuite.scala new file mode 100644 index 00000000000..30c67d97b1c --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/ImplicitDMLCastingSuite.scala @@ -0,0 +1,284 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.test.DeltaSQLCommandTest + +import org.apache.spark.{SparkConf, SparkThrowable} +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.internal.SQLConf + +/** + * Tests for casts that are implicitly added in DML commands modifying Delta tables. + * These casts are added to convert values to the schema of a table. + * INSERT operations are excluded as they are covered by InsertSuite and InsertSuiteEdge. 
+ */ +class ImplicitDMLCastingSuite extends QueryTest + with DeltaSQLCommandTest { + + private case class TestConfiguration( + sourceType: String, + sourceTypeInErrorMessage: String, + targetType: String, + targetTypeInErrorMessage: String, + validValue: String, + overflowValue: String) + + private case class SqlConfiguration( + followAnsiEnabled: Boolean, + ansiEnabled: Boolean, + storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value) { + + def withSqlSettings(f: => Unit): Unit = + withSQLConf( + DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key + -> followAnsiEnabled.toString, + SQLConf.STORE_ASSIGNMENT_POLICY.key -> storeAssignmentPolicy.toString, + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString)(f) + + override def toString: String = + s"followAnsiEnabled: $followAnsiEnabled, ansiEnabled: $ansiEnabled," + + s" storeAssignmentPolicy: $storeAssignmentPolicy" + } + + private val testConfigurations = Seq( + TestConfiguration(sourceType = "INT", sourceTypeInErrorMessage = "INT", + targetType = "TINYINT", targetTypeInErrorMessage = "TINYINT", + validValue = "1", overflowValue = Int.MaxValue.toString), + TestConfiguration(sourceType = "INT", sourceTypeInErrorMessage = "INT", + targetType = "SMALLINT", targetTypeInErrorMessage = "SMALLINT", + validValue = "1", overflowValue = Int.MaxValue.toString), + TestConfiguration(sourceType = "BIGINT", sourceTypeInErrorMessage = "BIGINT", + targetType = "INT", targetTypeInErrorMessage = "INT", + validValue = "1", overflowValue = Long.MaxValue.toString), + TestConfiguration(sourceType = "DOUBLE", sourceTypeInErrorMessage = "DOUBLE", + targetType = "BIGINT", targetTypeInErrorMessage = "BIGINT", + validValue = "1", overflowValue = "12345678901234567890D"), + TestConfiguration(sourceType = "BIGINT", sourceTypeInErrorMessage = "BIGINT", + targetType = "DECIMAL(7,2)", targetTypeInErrorMessage = "DECIMAL(7,2)", + validValue = "1", overflowValue = Long.MaxValue.toString), + TestConfiguration(sourceType = "STRUCT<value: BIGINT>", sourceTypeInErrorMessage = "BIGINT", + targetType = "STRUCT<value: INT>", targetTypeInErrorMessage = "INT", + validValue = "named_struct('value', 1)", + overflowValue = s"named_struct('value', ${Long.MaxValue.toString})"), + TestConfiguration(sourceType = "ARRAY<BIGINT>", sourceTypeInErrorMessage = "ARRAY<BIGINT>", + targetType = "ARRAY<INT>", targetTypeInErrorMessage = "ARRAY<INT>", + validValue = "ARRAY(1)", overflowValue = s"ARRAY(${Long.MaxValue.toString})") + ) + + @tailrec + private def arithmeticCause(exception: Throwable): Option[ArithmeticException] = { + exception match { + case arithmeticException: ArithmeticException => Some(arithmeticException) + case _ if exception.getCause != null => arithmeticCause(exception.getCause) + case _ => None + } + } + + /** + * Validate that a custom error is thrown in case ansi.enabled is false, or a different + * overflow error in case ansi.enabled is true. 
+ */ + private def validateException( + exception: Throwable, sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = { + arithmeticCause(exception) match { + case Some(exception: DeltaArithmeticException) => + assert(exception.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE") + assert(exception.getMessageParameters == + Map("sourceType" -> ("\"" + testConfig.sourceTypeInErrorMessage + "\""), + "targetType" -> ("\"" + testConfig.targetTypeInErrorMessage + "\""), + "columnName" -> "`value`", + "storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key, + "updateAndMergeCastingFollowsAnsiEnabledFlag" -> + DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key, + "ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key).asJava) + case Some(exception: SparkThrowable) if sqlConfig.ansiEnabled => + // With ANSI enabled, the overflows are caught before the write operation. + assert(Seq("CAST_OVERFLOW", "NUMERIC_VALUE_OUT_OF_RANGE") + .contains(exception.getErrorClass)) + case None => assert(false, "No arithmetic exception thrown.") + case Some(exception) => + assert(false, s"Unexpected exception type: $exception") + } + } + + Seq(true, false).foreach { followAnsiEnabled => + Seq(true, false).foreach { ansiEnabled => + Seq(SQLConf.StoreAssignmentPolicy.LEGACY, SQLConf.StoreAssignmentPolicy.ANSI) + .foreach { storeAssignmentPolicy => + val sqlConfiguration = + SqlConfiguration(followAnsiEnabled, ansiEnabled, storeAssignmentPolicy) + testConfigurations.foreach { testConfiguration => + updateTest(sqlConfiguration, testConfiguration) + mergeTests(sqlConfiguration, testConfiguration) + streamingMergeTest(sqlConfiguration, testConfiguration) + } + } + } + } + + /** Test an UPDATE that requires casting the update value that is part of the SET clause. */ + private def updateTest( + sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = { + val testName = s"UPDATE overflow targetType: ${testConfig.targetType} $sqlConfig" + test(testName) { + sqlConfig.withSqlSettings { + val tableName = "overflowTable" + withTable(tableName) { + sql(s"""CREATE TABLE $tableName USING DELTA + |AS SELECT cast(${testConfig.validValue} AS ${testConfig.targetType}) AS value + |""".stripMargin) + val updateCommand = s"UPDATE $tableName SET value = ${testConfig.overflowValue}" + + val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) || + (!sqlConfig.followAnsiEnabled && + sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY) + + if (legacyCasts) { + sql(updateCommand) + } else { + val exception = intercept[Throwable] { + sql(updateCommand) + } + + validateException(exception, sqlConfig, testConfig) + } + } + } + } + } + + + /** Tests for MERGE with overflows caused by the different conditions. 
*/ + private def mergeTests( + sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = { + mergeTest(matchedCondition = "WHEN MATCHED THEN UPDATE SET t.value = s.value", + sqlConfig, testConfig) + + mergeTest(matchedCondition = "WHEN NOT MATCHED THEN INSERT *", sqlConfig, testConfig) + + mergeTest(matchedCondition = + s"WHEN NOT MATCHED BY SOURCE THEN UPDATE SET t.value = ${testConfig.overflowValue}", + sqlConfig, testConfig) + } + + private def mergeTest( + matchedCondition: String, + sqlConfig: SqlConfiguration, + testConfig: TestConfiguration + ): Unit = { + val testName = + s"MERGE overflow in $matchedCondition targetType: ${testConfig.targetType} $sqlConfig" + test(testName) { + sqlConfig.withSqlSettings { + val targetTableName = "target_table" + val sourceViewName = "source_view" + withTable(targetTableName) { + withTempView(sourceViewName) { + val numRows = 10 + sql(s"""CREATE TABLE $targetTableName USING DELTA + |AS SELECT col as key, + | cast(${testConfig.validValue} AS ${testConfig.targetType}) AS value + |FROM explode(sequence(0, $numRows))""".stripMargin) + // The view maps the key space such that we get matched, not matched by source, and + // not matched by target rows. + sql(s"""CREATE TEMPORARY VIEW $sourceViewName + |AS SELECT key + ($numRows / 2) AS key, + | cast(${testConfig.overflowValue} AS ${testConfig.sourceType}) AS value + |FROM $targetTableName""".stripMargin) + val mergeCommand = s"""MERGE INTO $targetTableName t + |USING $sourceViewName s + |ON s.key = t.key + |$matchedCondition + |""".stripMargin + val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) || + (!sqlConfig.followAnsiEnabled && + sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY) + + if (legacyCasts) { + sql(mergeCommand) + } else { + val exception = intercept[Throwable] { + sql(mergeCommand) + } + + validateException(exception, sqlConfig, testConfig) + } + } + } + } + } + } + + /** A merge that is executed for each batch of a stream and has to cast values before inserting them. 
*/ + private def streamingMergeTest( + sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = { + val testName = s"Streaming MERGE overflow targetType: ${testConfig.targetType} $sqlConfig" + test(testName) { + sqlConfig.withSqlSettings { + val targetTableName = "target_table" + val sourceTableName = "source_table" + withTable(sourceTableName, targetTableName) { + sql(s"CREATE TABLE $targetTableName (key INT, value ${testConfig.targetType})" + + " USING DELTA") + sql(s"CREATE TABLE $sourceTableName (key INT, value ${testConfig.sourceType})" + + " USING DELTA") + + def upsertToDelta(microBatchOutputDF: DataFrame, batchId: Long): Unit = { + microBatchOutputDF.createOrReplaceTempView("micro_batch_output") + + microBatchOutputDF.sparkSession.sql(s"""MERGE INTO $targetTableName t + |USING micro_batch_output s + |ON s.key = t.key + |WHEN NOT MATCHED THEN INSERT * + |""".stripMargin) + } + + val sourceStream = spark.readStream.table(sourceTableName) + val streamWriter = + sourceStream + .writeStream + .format("delta") + .foreachBatch(upsertToDelta _) + .outputMode("update") + .start() + + sql(s"INSERT INTO $sourceTableName(key, value) VALUES(0, ${testConfig.overflowValue})") + + val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) || + (!sqlConfig.followAnsiEnabled && + sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY) + + if (legacyCasts) { + streamWriter.processAllAvailable() + } else { + val exception = intercept[Throwable] { + streamWriter.processAllAvailable() + } + + validateException(exception, sqlConfig, testConfig) + } + } + } + } + } +} + diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala index 313baa24920..6920b00c094 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala @@ -3871,7 +3871,7 @@ abstract class MergeIntoSuiteBase ((0, 0) +: (1, 10) +: (2, 2) +: (3, 30) +: (5, null) +: Nil) .asInstanceOf[List[(Integer, Integer)]].toDF("key", "value"), - // Disable ANSI as this test needs to cast string "notANumber" to int - confs = Seq(SQLConf.ANSI_ENABLED.key -> "false") + // Use the LEGACY store assignment policy as this test needs to cast string "notANumber" to int + confs = Seq(SQLConf.STORE_ASSIGNMENT_POLICY.key -> "LEGACY") ) // This is kinda bug-for-bug compatibility. It doesn't really make sense that infinity is casted @@ -3887,7 +3887,7 @@ abstract class MergeIntoSuiteBase ((0, 0) +: (1, 10) +: (2, 2) +: (3, 30) +: (5, Int.MaxValue) +: Nil) .asInstanceOf[List[(Integer, Integer)]].toDF("key", "value"), - // Disable ANSI as this test needs to cast Double.PositiveInfinity to int - confs = Seq(SQLConf.ANSI_ENABLED.key -> "false") + // Use the LEGACY store assignment policy as this test needs to cast Double.PositiveInfinity to int + confs = Seq(SQLConf.STORE_ASSIGNMENT_POLICY.key -> "LEGACY") ) testEvolution("extra nested column in source - insert")(
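As a closing note, the escape hatches named in the new error message, sketched in Scala under the same assumptions as the example above (table t is hypothetical; the Delta flag is the DeltaSQLConf entry added in this patch, with the usual spark.databricks.delta prefix):

    // 1. Per-statement: try_cast tolerates the overflow and writes NULL instead.
    spark.sql(s"UPDATE t SET value = try_cast(${Int.MaxValue} AS TINYINT)")

    // 2. Session-wide: the LEGACY store assignment policy restores the old cast semantics.
    spark.conf.set("spark.sql.storeAssignmentPolicy", "LEGACY")

    // 3. Delta-specific: follow spark.sql.ansi.enabled again, as UPDATE and MERGE
    //    did before this patch.
    spark.conf.set(
      "spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag", "true")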