diff --git a/core/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala b/core/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
index a5d04155c9e..3fd14ca9876 100644
--- a/core/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
+++ b/core/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
@@ -94,14 +94,13 @@ case class PreprocessTableMerge(override val conf: SQLConf)
 
     val processedMatched = matched.map {
       case m: DeltaMergeIntoUpdateClause =>
-        // Get any new columns which are in the insert clause, but not the target output or this
-        // update clause.
+        // Get any new columns which are in the update/insert clauses, but not the target output
         val existingColumns = m.resolvedActions.map(_.targetColNameParts.head) ++
           target.output.map(_.name)
-        val newColumns = notMatched.toSeq.flatMap {
-          _.resolvedActions.filterNot { insertAct =>
+        val newColumns = (matched ++ notMatched).toSeq.flatMap {
+          _.resolvedActions.filterNot { action =>
             existingColumns.exists { colName =>
-              conf.resolver(insertAct.targetColNameParts.head, colName)
+              conf.resolver(action.targetColNameParts.head, colName)
             }
           }
         }
@@ -118,7 +117,7 @@ case class PreprocessTableMerge(override val conf: SQLConf)
           builder.result()
         }
 
-        val newColsFromInsert = distinctBy(newColumns)(_.targetColNameParts).map { action =>
+        val newColumnsDistinct = distinctBy(newColumns)(_.targetColNameParts).map { action =>
           AttributeReference(action.targetColNameParts.head, action.dataType)()
         }
 
@@ -127,8 +126,8 @@ case class PreprocessTableMerge(override val conf: SQLConf)
           UpdateOperation(a.targetColNameParts, a.expr)
         }
 
-        // And construct operations for columns that the insert clause will add.
-        val newOpsFromInsert = newColsFromInsert.map { col =>
+        // And construct operations for columns that the insert/update clauses will add.
+        val newUpdateOps = newColumnsDistinct.map { col =>
           UpdateOperation(Seq(col.name), Literal(null, col.dataType))
         }
 
@@ -147,7 +146,7 @@ case class PreprocessTableMerge(override val conf: SQLConf)
         // that nested fields can be updated (only for existing columns).
         val alignedExprs = generateUpdateExpressions(
           finalSchemaExprs,
-          existingUpdateOps ++ newOpsFromInsert,
+          existingUpdateOps ++ newUpdateOps,
           conf.resolver,
           allowStructEvolution = migrateSchema,
           generatedColumns = generatedColumns)
diff --git a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
index b51bc981127..2075a8714e3 100644
--- a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
+++ b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
@@ -2569,6 +2569,19 @@ abstract class MergeIntoSuiteBase
     expectedWithoutEvolution = ((0, 0) +: (3, 30) +: (1, 1) +: Nil).toDF("key", "value")
   )
 
+  testEvolution("new column with update set and update *")(
+    targetData = Seq((0, 0), (1, 10), (2, 20)).toDF("key", "value"),
+    sourceData = Seq((1, 1, "extra1"), (2, 2, "extra2")).toDF("key", "value", "extra"),
+    clauses = update(condition = "s.key < 2", set = "value = s.value") :: update("*") :: Nil,
+    expected =
+      ((0, 0, null) +:
+        (1, 1, null) +: // updated by first clause
+        (2, 2, "extra2") +: // updated by second clause
+        Nil
+      ).toDF("key", "value", "extra"),
+    expectedWithoutEvolution = ((0, 0) +: (1, 1) +: (2, 2) +: Nil).toDF("key", "value")
+  )
+
   testEvolution("update * with column not in source")(
     targetData = Seq((0, 0, 0), (1, 10, 10), (3, 30, 30)).toDF("key", "value", "extra"),
     sourceData = Seq((1, 1), (2, 2)).toDF("key", "value"),
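For context, here is a minimal end-to-end sketch of the user-facing scenario the new `testEvolution` case exercises, written against Delta Lake's Scala `DeltaTable` API. The session setup, table path, and aliases are illustrative assumptions, not part of the patch; the merge itself mirrors the test's clauses, a conditional `UPDATE SET` followed by `UPDATE *`, with merge-time schema evolution enabled.

```scala
import io.delta.tables.DeltaTable
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
  .config("spark.sql.catalog.spark_catalog",
    "org.apache.spark.sql.delta.catalog.DeltaCatalog")
  // Merge-time schema evolution must be enabled for UPDATE * to add columns.
  .config("spark.databricks.delta.schema.autoMerge.enabled", "true")
  .getOrCreate()
import spark.implicits._

// Target table has (key, value); the source carries a new "extra" column.
// The path is an arbitrary local example.
Seq((0, 0), (1, 10), (2, 20)).toDF("key", "value")
  .write.format("delta").save("/tmp/merge_target")
val source = Seq((1, 1, "extra1"), (2, 2, "extra2")).toDF("key", "value", "extra")

DeltaTable.forPath(spark, "/tmp/merge_target").as("t")
  .merge(source.as("s"), "t.key = s.key")
  // First clause matches key 1 and only sets "value".
  .whenMatched("s.key < 2").updateExpr(Map("value" -> "s.value"))
  // Second clause picks up the remaining match (key 2) and, via UPDATE *,
  // brings the new "extra" column into the target schema.
  .whenMatched().updateAll()
  .execute()

// Expected rows: (0, 0, null), (1, 1, null), (2, 2, "extra2").
spark.read.format("delta").load("/tmp/merge_target").show()
```

The key change above is that `PreprocessTableMerge` now collects candidate new columns from `matched ++ notMatched` rather than from insert clauses alone, so an update clause that never mentions the evolved column (the first `whenMatched` here) is still padded with a `null` for it, keeping its output aligned with the final evolved schema.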