diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DomainMetadataUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DomainMetadataUtils.scala
index 94df58755fb..04dd9a0b96c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/DomainMetadataUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/DomainMetadataUtils.scala
@@ -16,20 +16,26 @@
 
 package org.apache.spark.sql.delta
 
+import org.apache.spark.sql.delta.skipping.clustering.ClusteredTableUtils
 import org.apache.spark.sql.delta.actions.{Action, DomainMetadata, Protocol}
 import org.apache.spark.sql.delta.clustering.ClusteringMetadataDomain
 import org.apache.spark.sql.delta.metering.DeltaLogging
 
-object DomainMetadataUtils extends DeltaLogging {
+/**
+ * Domain metadata utility functions.
+ */
+trait DomainMetadataUtilsBase extends DeltaLogging {
   // List of metadata domains that will be removed for the REPLACE TABLE operation.
-  private val METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE: Set[String] = Set(
-  )
+  protected val METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE: Set[String] = Set(
+    ClusteringMetadataDomain.domainName)
+
   // List of metadata domains that will be copied from the table we are restoring to.
-  private val METADATA_DOMAIN_TO_COPY_FOR_RESTORE_TABLE =
-    METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE
+  // Note that the ClusteringMetadataDomain is recreated in handleDomainMetadataForRestoreTable
+  // instead of being blindly copied over.
+  protected val METADATA_DOMAIN_TO_COPY_FOR_RESTORE_TABLE: Set[String] = Set.empty
 
   // List of metadata domains that will be copied from the table on a CLONE operation.
-  private val METADATA_DOMAIN_TO_COPY_FOR_CLONE_TABLE: Set[String] = Set(
+  protected val METADATA_DOMAIN_TO_COPY_FOR_CLONE_TABLE: Set[String] = Set(
     ClusteringMetadataDomain.domainName)
 
   /**
@@ -91,15 +97,44 @@ object DomainMetadataUtils extends DeltaLogging {
 
   /**
    * Generates a new sequence of DomainMetadata to commits for RESTORE TABLE.
-   * - Source (table to restore to) domains will be copied if they appear in the pre-defined
+   * - Domains in the toSnapshot will be copied if they appear in the pre-defined
    *   "copy" list (e.g., table features require some specific domains to be copied).
-   * - All other domains not in the list are "retained".
+   * - All other domains not in the list are dropped from the "toSnapshot".
+   *
+   * For the clustering metadata domain, it overwrites the existing domain metadata in the
+   * fromSnapshot with the following clustering columns:
+   * 1. If toSnapshot is not a clustered table or is missing domain metadata, use empty
+   *    clustering columns.
+   * 2. If toSnapshot is a clustered table, use the clustering columns from toSnapshot.
+   *
+   * @param toSnapshot The snapshot being restored to, which is referred to as the "source" table.
+   * @param fromSnapshot The snapshot being restored from, which is the current state.
    */
   def handleDomainMetadataForRestoreTable(
-      sourceDomainMetadatas: Seq[DomainMetadata]): Seq[DomainMetadata] = {
-    sourceDomainMetadatas.filter { m =>
+      toSnapshot: Snapshot,
+      fromSnapshot: Snapshot): Seq[DomainMetadata] = {
+    val filteredDomainMetadata = toSnapshot.domainMetadata.filter { m =>
       METADATA_DOMAIN_TO_COPY_FOR_RESTORE_TABLE.contains(m.domain)
     }
+    val clusteringColumnsToRestore = ClusteredTableUtils.getClusteringColumnsOptional(toSnapshot)
+
+    val isRestoringToClusteredTable =
+      ClusteredTableUtils.isSupported(toSnapshot.protocol) && clusteringColumnsToRestore.nonEmpty
+    val clusteringColumns = if (isRestoringToClusteredTable) {
+      // We overwrite the clustering columns in the fromSnapshot with the clustering columns
+      // in the toSnapshot.
+      clusteringColumnsToRestore.get
+    } else {
+      // toSnapshot is not a clustered table or is missing domain metadata, so we write domain
+      // metadata with empty clustering columns.
+      Seq.empty
+    }
+
+    val matchingMetadataDomain =
+      ClusteredTableUtils.getMatchingMetadataDomain(
+        clusteringColumns,
+        fromSnapshot.domainMetadata)
+    filteredDomainMetadata ++ matchingMetadataDomain.clusteringDomainOpt
   }
 
   /**
@@ -112,3 +147,5 @@ object DomainMetadataUtils extends DeltaLogging {
     }
   }
 }
+
+object DomainMetadataUtils extends DomainMetadataUtilsBase
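The new RESTORE rule has two parts: copy only allow-listed domains from the version being restored to, then pick the clustering columns to write based on whether that version is a clustered table. The column-selection half can be sketched in isolation as below; `TableState` and `columnsToRestore` are simplified stand-ins for `Snapshot` and the private logic inside `handleDomainMetadataForRestoreTable`, not real delta-spark APIs.

```scala
// Simplified stand-in for the pieces of Snapshot the RESTORE logic consults.
case class TableState(
    clusteringSupported: Boolean,          // ClusteredTableUtils.isSupported(protocol)
    clusteringColumns: Option[Seq[String]] // getClusteringColumnsOptional(snapshot)
)

// Which clustering columns should the RESTORE commit write?
def columnsToRestore(toSnapshot: TableState): Seq[String] = {
  val isRestoringToClusteredTable =
    toSnapshot.clusteringSupported && toSnapshot.clusteringColumns.nonEmpty
  if (isRestoringToClusteredTable) {
    toSnapshot.clusteringColumns.get // take the restored version's columns
  } else {
    Seq.empty // restored version is unclustered: clear the columns
  }
}

assert(columnsToRestore(TableState(true, Some(Seq("a")))) == Seq("a"))
assert(columnsToRestore(TableState(false, None)).isEmpty)
```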
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala
index ddb466f5bd4..0271c7150fd 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala
@@ -21,7 +21,7 @@ import java.sql.Timestamp
 import scala.collection.JavaConverters._
 import scala.util.{Success, Try}
 
-import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, Snapshot}
+import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, DomainMetadataUtils, Snapshot}
 import org.apache.spark.sql.delta.actions.{AddFile, DeletionVectorDescriptor, RemoveFile}
 import org.apache.spark.sql.delta.catalog.DeltaTableV2
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
@@ -205,9 +205,12 @@ case class RestoreTableCommand(sourceTable: DeltaTableV2)
           sourceProtocol.merge(targetProtocol)
         }
 
+        val actions = addActions ++ removeActions ++
+          DomainMetadataUtils.handleDomainMetadataForRestoreTable(snapshotToRestore, latestSnapshot)
+
         txn.commitLarge(
           spark,
-          addActions ++ removeActions,
+          actions,
           Some(newProtocol),
           DeltaOperations.Restore(version, timestamp),
           Map.empty,
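For reviewers who want to see the effect end to end, a hypothetical spark-shell session against a throwaway table `t` (the `CLUSTER BY` syntax matches the tests below; depending on the Delta version this may require the clustered-table preview flag, and whether `DESCRIBE DETAIL` exposes a `clusteringColumns` field also varies, so treat the last line as illustrative):

```scala
spark.sql("CREATE TABLE t (a INT, b STRING) USING delta CLUSTER BY (a)") // version 0
spark.sql("ALTER TABLE t CLUSTER BY (b)")                                // version 1
// With this patch, the RESTORE commit also rewrites the clustering domain
// metadata, so the table's clustering columns revert to (a).
spark.sql("RESTORE TABLE t TO VERSION AS OF 0")
spark.sql("DESCRIBE DETAIL t").select("clusteringColumns").show()
```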
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableUtils.scala
index fe3e8b12633..baf255dd5db 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableUtils.scala
@@ -31,6 +31,10 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{StructField, StructType}
 
+case class MatchingMetadataDomain(
+    clusteringDomainOpt: Option[DomainMetadata]
+)
+
 /**
  * Clustered table utility functions.
  */
@@ -146,14 +150,33 @@ trait ClusteredTableUtilsBase extends DeltaLogging {
         txn.protocol, txn.metadata, clusterBy)
       val clusteringColumns =
         clusterBy.columnNames.map(_.toString).map(ClusteringColumn(txn.metadata.schema, _))
-      Some(createDomainMetadata(clusteringColumns)).toSeq
+      Seq(createDomainMetadata(clusteringColumns))
     }.getOrElse {
-      if (txn.snapshot.domainMetadata.exists(_.domain == ClusteringMetadataDomain.domainName)) {
-        Some(createDomainMetadata(Seq.empty)).toSeq
+      getMatchingMetadataDomain(
+        clusteringColumns = Seq.empty,
+        txn.snapshot.domainMetadata).clusteringDomainOpt.toSeq
+    }
+  }
+
+  /**
+   * Returns a [[MatchingMetadataDomain]] holding the [[DomainMetadata]] action (if any) that
+   * updates the existing domain metadata with the given clustering columns.
+   *
+   * This is mainly used for REPLACE TABLE and RESTORE TABLE.
+   */
+  def getMatchingMetadataDomain(
+      clusteringColumns: Seq[ClusteringColumn],
+      existingDomainMetadata: Seq[DomainMetadata]): MatchingMetadataDomain = {
+    val clusteringMetadataDomainOpt =
+      if (existingDomainMetadata.exists(_.domain == ClusteringMetadataDomain.domainName)) {
+        Some(ClusteringMetadataDomain.fromClusteringColumns(clusteringColumns).toDomainMetadata)
       } else {
-        None.toSeq
+        None
       }
-    }
+
+    MatchingMetadataDomain(
+      clusteringMetadataDomainOpt
+    )
   }
 
   /**
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/ClusteredTableTestUtils.scala b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/ClusteredTableTestUtils.scala
index 1a8a7274914..4e9bebbd020 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/ClusteredTableTestUtils.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/ClusteredTableTestUtils.scala
@@ -152,7 +152,8 @@ trait ClusteredTableTestUtilsBase extends SparkFunSuite with SharedSparkSession
         } else {
           assertClusterByNotExist()
         }
-      case "WRITE" =>
+      case "WRITE" | "RESTORE" =>
+        // These are known operations from our tests that don't have clusterBy.
         doAssert(!lastOperationParameters.contains(CLUSTERING_PARAMETER_KEY))
       case _ =>
         // Other operations are not tested yet. If the test fails here, please check the expected
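The contract `getMatchingMetadataDomain` encodes, and the reason REPLACE and RESTORE can share it: a clustering domain action is emitted only when the existing snapshot already carries one, and that action overwrites the columns (possibly with an empty list). A self-contained sketch with simplified types follows; the real domain name and JSON configuration are owned by `ClusteringMetadataDomain` and may serialize differently.

```scala
case class DomainMetadata(domain: String, configuration: String, removed: Boolean)
case class MatchingMetadataDomain(clusteringDomainOpt: Option[DomainMetadata])

val ClusteringDomainName = "delta.clustering" // assumed stand-in for ClusteringMetadataDomain.domainName

def getMatchingMetadataDomain(
    clusteringColumns: Seq[String],
    existingDomainMetadata: Seq[DomainMetadata]): MatchingMetadataDomain = {
  // Emit an action only if a clustering domain already exists on the table.
  val opt = existingDomainMetadata.find(_.domain == ClusteringDomainName).map { _ =>
    val cols = clusteringColumns.map("\"" + _ + "\"").mkString("[", ",", "]")
    DomainMetadata(ClusteringDomainName, s"""{"clusteringColumns":$cols}""", removed = false)
  }
  MatchingMetadataDomain(opt)
}

// Unclustered table (no existing domain): nothing is written.
assert(getMatchingMetadataDomain(Seq("a"), Seq.empty).clusteringDomainOpt.isEmpty)
// Existing domain: overwritten even with empty columns (CLUSTER BY NONE, RESTORE).
val existing = Seq(DomainMetadata(ClusteringDomainName, """{"clusteringColumns":["a"]}""", false))
assert(getMatchingMetadataDomain(Nil, existing).clusteringDomainOpt
  .exists(_.configuration == """{"clusteringColumns":[]}"""))
```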
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala
index 86ae5f3b451..8fea8b49734 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala
@@ -21,6 +21,7 @@ import java.io.File
 import com.databricks.spark.util.{Log4jUsageLogger, MetricDefinitions}
 import org.apache.spark.sql.delta.skipping.ClusteredTableTestUtils
 import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaColumnMappingEnableIdMode, DeltaColumnMappingEnableNameMode, DeltaConfigs, DeltaExcludedBySparkVersionTestMixinShims, DeltaLog, DeltaUnsupportedOperationException}
+import org.apache.spark.sql.delta.clustering.ClusteringMetadataDomain
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.stats.SkippingEligibleDataType
 import org.apache.spark.sql.delta.test.{DeltaColumnMappingSelectedTestMixin, DeltaSQLCommandTest}
@@ -480,7 +481,8 @@ trait ClusteredTableDDLWithColumnMapping
     "validate dropping clustering column is not allowed: single clustering column",
     "validate dropping clustering column is not allowed: multiple clustering columns",
     "validate dropping clustering column is not allowed: clustering column + " +
-      "non-clustering column"
+      "non-clustering column",
+    "validate RESTORE on clustered table"
   )
 
   test("validate dropping clustering column is not allowed: single clustering column") {
@@ -825,6 +827,72 @@ trait ClusteredTableDDLSuiteBase
     }
   }
 
+  test("validate RESTORE on clustered table") {
+    val tableIdentifier = TableIdentifier(testTable)
+    // Scenario 1: restore a clustered table to an unclustered version.
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (a INT, b STRING) USING delta")
+      val (_, startingSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
+      assert(!ClusteredTableUtils.isSupported(startingSnapshot.protocol))
+
+      sql(s"ALTER TABLE $testTable CLUSTER BY (a)")
+      verifyClusteringColumns(tableIdentifier, "a")
+
+      sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
+      val (_, currentSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
+      verifyClusteringColumns(tableIdentifier, "")
+    }
+
+    // Scenario 2: restore a clustered table to its previous clustering columns.
+    withClusteredTable(testTable, "a INT, b STRING", "a") {
+      verifyClusteringColumns(tableIdentifier, "a")
+
+      sql(s"ALTER TABLE $testTable CLUSTER BY (b)")
+      verifyClusteringColumns(tableIdentifier, "b")
+
+      sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
+      verifyClusteringColumns(tableIdentifier, "a")
+    }
+
+    // Scenario 3: restore from empty clustering columns back to non-empty clustering columns.
+    withClusteredTable(testTable, "a int", "a") {
+      verifyClusteringColumns(tableIdentifier, "a")
+
+      sql(s"ALTER TABLE $testTable CLUSTER BY NONE")
+      verifyClusteringColumns(tableIdentifier, "")
+
+      sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
+      verifyClusteringColumns(tableIdentifier, "a")
+    }
+
+    // Scenario 4: restore to the starting version.
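Context for grouping RESTORE with WRITE in the harness above: a RESTORE commit rewrites the clustering domain metadata action, but its operation parameters in the Delta history carry no clusterBy entry (only CREATE/REPLACE-style commits do). A rough standalone check along the same lines, using the public history API and assuming an active SparkSession `spark`; the "clusterBy" key mirrors what CLUSTERING_PARAMETER_KEY points at and is an assumption here:

```scala
import io.delta.tables.DeltaTable

// Look at the most recent commit of some table `t` (assumed to exist).
val row = DeltaTable.forName(spark, "t").history(1)
  .select("operation", "operationParameters").head()
val params = row.getMap[String, String](1)
// RESTORE, like WRITE, must not record a clusterBy operation parameter.
assert(row.getString(0) != "RESTORE" || !params.contains("clusterBy"))
```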
+    withClusteredTable(testTable, "a int", "a") {
+      verifyClusteringColumns(tableIdentifier, "a")
+
+      sql(s"INSERT INTO $testTable VALUES (1)")
+
+      sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
+      verifyClusteringColumns(tableIdentifier, "a")
+    }
+
+    // Scenario 5: restore an unclustered table to an unclustered version.
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (a INT) USING delta")
+      val (_, startingSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
+      assert(!ClusteredTableUtils.isSupported(startingSnapshot.protocol))
+      assert(!startingSnapshot.domainMetadata.exists(_.domain ==
+        ClusteringMetadataDomain.domainName))
+
+      sql(s"INSERT INTO $testTable VALUES (1)")
+
+      sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0").collect()
+      val (_, currentSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
+      assert(!ClusteredTableUtils.isSupported(currentSnapshot.protocol))
+      assert(!currentSnapshot.domainMetadata.exists(_.domain ==
+        ClusteringMetadataDomain.domainName))
+    }
+  }
+
   testSparkMasterOnly("Variant is not supported") {
     val e = intercept[DeltaAnalysisException] {
       createOrReplaceClusteredTable("CREATE", testTable, "id long, v variant", "v")
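The five scenarios reduce to one invariant: after RESTORE, a clustering domain action exists iff the pre-restore snapshot had one, with columns taken from the restored version (or empty). When debugging a failure, the domain can be inspected directly with the same internal delta-spark APIs the suite uses (assumes an active SparkSession `spark` and a table `t`; the exact JSON in `configuration` is whatever ClusteringMetadataDomain serializes):

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.delta.clustering.ClusteringMetadataDomain

// Grab the latest snapshot and print the clustering domain, if present.
val (_, snapshot) = DeltaLog.forTableWithSnapshot(spark, TableIdentifier("t"))
snapshot.domainMetadata
  .filter(_.domain == ClusteringMetadataDomain.domainName)
  .foreach(d => println(s"clustering domain after RESTORE: ${d.configuration}"))
```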