diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DomainMetadataUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DomainMetadataUtils.scala
index 9d1c4ecad64..a5236f40c3f 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/DomainMetadataUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/DomainMetadataUtils.scala
@@ -21,16 +21,21 @@ import org.apache.spark.sql.delta.actions.{Action, DomainMetadata, Protocol}
 import org.apache.spark.sql.delta.clustering.ClusteringMetadataDomain
 import org.apache.spark.sql.delta.metering.DeltaLogging
 
-object DomainMetadataUtils extends DeltaLogging {
+/**
+ * Domain metadata utility functions.
+ */
+trait DomainMetadataUtilsBase extends DeltaLogging {
   // List of metadata domains that will be removed for the REPLACE TABLE operation.
-  private val METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE: Set[String] = Set(
-  )
+  protected val METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE: Set[String] = Set(
+    ClusteringMetadataDomain.domainName)
+
   // List of metadata domains that will be copied from the table we are restoring to.
-  private val METADATA_DOMAIN_TO_COPY_FOR_RESTORE_TABLE =
-    METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE
+  // Note that ClusteringMetadataDomain is recreated in handleDomainMetadataForRestoreTable
+  // instead of being blindly copied over.
+  protected val METADATA_DOMAIN_TO_COPY_FOR_RESTORE_TABLE: Set[String] = Set.empty
 
   // List of metadata domains that will be copied from the table on a CLONE operation.
-  private val METADATA_DOMAIN_TO_COPY_FOR_CLONE_TABLE: Set[String] = Set(
+  protected val METADATA_DOMAIN_TO_COPY_FOR_CLONE_TABLE: Set[String] = Set(
     ClusteringMetadataDomain.domainName)
 
   /**
@@ -96,7 +101,7 @@ object DomainMetadataUtils extends DeltaLogging {
    * "copy" list (e.g., table features require some specific domains to be copied).
    * - All other domains not in the list are dropped from the "toSnapshot".
    *
-   * For clustering metadata domains, it overwrites the existing domain metadata in the
+   * For the clustering metadata domain, it overwrites the existing domain metadata in the
    * fromSnapshot with the following clustering columns.
    * 1. If toSnapshot is not a clustered table or missing domain metadata, use empty clustering
    *    columns.
@@ -125,7 +130,11 @@ object DomainMetadataUtils extends DeltaLogging {
       Seq.empty
     }
 
-    filteredDomainMetadata ++ Seq(ClusteredTableUtils.createDomainMetadata(clusteringColumns))
+    val matchingMetadataDomain =
+      ClusteredTableUtils.getMatchingMetadataDomain(
+        clusteringColumns,
+        fromSnapshot.domainMetadata)
+    filteredDomainMetadata ++ matchingMetadataDomain.clusteringDomain
   }
 
   /**
@@ -138,3 +147,5 @@ object DomainMetadataUtils extends DeltaLogging {
     }
   }
 }
+
+object DomainMetadataUtils extends DomainMetadataUtilsBase
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableUtils.scala
index fe3e8b12633..33e68ab5d4a 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableUtils.scala
@@ -31,6 +31,10 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{StructField, StructType}
 
+case class MatchingMetadataDomain(
+    clusteringDomain: Option[DomainMetadata]
+)
+
 /**
  * Clustered table utility functions.
  */
@@ -146,14 +150,33 @@ trait ClusteredTableUtilsBase extends DeltaLogging {
         txn.protocol, txn.metadata, clusterBy)
       val clusteringColumns =
         clusterBy.columnNames.map(_.toString).map(ClusteringColumn(txn.metadata.schema, _))
-      Some(createDomainMetadata(clusteringColumns)).toSeq
+      Seq(createDomainMetadata(clusteringColumns))
     }.getOrElse {
-      if (txn.snapshot.domainMetadata.exists(_.domain == ClusteringMetadataDomain.domainName)) {
-        Some(createDomainMetadata(Seq.empty)).toSeq
+      getMatchingMetadataDomain(
+        Seq.empty,
+        txn.snapshot.domainMetadata).clusteringDomain.toSeq
+    }
+  }
+
+  /**
+   * Returns a [[MatchingMetadataDomain]] holding the [[DomainMetadata]] action that updates
+   * the existing domain metadata with the given clustering columns, if one is needed.
+   *
+   * This is mainly used for REPLACE TABLE and RESTORE TABLE.
+   */
+  def getMatchingMetadataDomain(
+      clusteringColumns: Seq[ClusteringColumn],
+      existingDomainMetadata: Seq[DomainMetadata]): MatchingMetadataDomain = {
+    val clusteringMetadataDomainOpt =
+      if (existingDomainMetadata.exists(_.domain == ClusteringMetadataDomain.domainName)) {
+        Some(ClusteringMetadataDomain.fromClusteringColumns(clusteringColumns).toDomainMetadata)
       } else {
-        None.toSeq
+        None
       }
-    }
+
+    MatchingMetadataDomain(
+      clusteringMetadataDomainOpt
+    )
   }
 
   /**
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala
index 36a3f0b548e..8fea8b49734 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala
@@ -21,6 +21,7 @@ import java.io.File
 import com.databricks.spark.util.{Log4jUsageLogger, MetricDefinitions}
 import org.apache.spark.sql.delta.skipping.ClusteredTableTestUtils
 import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaColumnMappingEnableIdMode, DeltaColumnMappingEnableNameMode, DeltaConfigs, DeltaExcludedBySparkVersionTestMixinShims, DeltaLog, DeltaUnsupportedOperationException}
+import org.apache.spark.sql.delta.clustering.ClusteringMetadataDomain
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.stats.SkippingEligibleDataType
 import org.apache.spark.sql.delta.test.{DeltaColumnMappingSelectedTestMixin, DeltaSQLCommandTest}
@@ -864,13 +865,32 @@ trait ClusteredTableDDLSuiteBase
       verifyClusteringColumns(tableIdentifier, "a")
     }
 
-    // Scenario 4: restore to latest version.
+    // Scenario 4: restore to start version.
     withClusteredTable(testTable, "a int", "a") {
       verifyClusteringColumns(tableIdentifier, "a")
 
+      sql(s"INSERT INTO $testTable VALUES (1)")
+      sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
       verifyClusteringColumns(tableIdentifier, "a")
     }
+
+    // Scenario 5: restore an unclustered table to an unclustered version.
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (a INT) USING delta")
+      val (_, startingSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
+      assert(!ClusteredTableUtils.isSupported(startingSnapshot.protocol))
+      assert(!startingSnapshot.domainMetadata.exists(_.domain ==
+        ClusteringMetadataDomain.domainName))
+
+      sql(s"INSERT INTO $testTable VALUES (1)")
+
+      sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0").collect()
+      val (_, currentSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
+      assert(!ClusteredTableUtils.isSupported(currentSnapshot.protocol))
+      assert(!currentSnapshot.domainMetadata.exists(_.domain ==
+        ClusteringMetadataDomain.domainName))
+    }
   }
 
   testSparkMasterOnly("Variant is not supported") {
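
Reviewer note (not part of the patch): the behavioral change is that REPLACE TABLE and RESTORE TABLE no longer unconditionally write a clustering DomainMetadata action; getMatchingMetadataDomain emits one only when the existing snapshot already carries the clustering domain, which is what the new Scenario 5 test verifies. The sketch below restates that rule in isolation; DomainMetadata here is a simplified stand-in case class and "delta.clustering" a placeholder domain name, not the real Delta types.

// Standalone sketch of the getMatchingMetadataDomain rule, with stand-in types.
object MatchingMetadataDomainSketch {
  // Simplified stand-in for org.apache.spark.sql.delta.actions.DomainMetadata.
  case class DomainMetadata(domain: String, configuration: String)

  // Placeholder for ClusteringMetadataDomain.domainName.
  val clusteringDomainName = "delta.clustering"

  // Emit a clustering domain action only when the existing snapshot already
  // carries one; otherwise the table stays unclustered (no action written).
  def getMatchingMetadataDomain(
      clusteringColumns: Seq[String],
      existingDomainMetadata: Seq[DomainMetadata]): Option[DomainMetadata] = {
    if (existingDomainMetadata.exists(_.domain == clusteringDomainName)) {
      Some(DomainMetadata(clusteringDomainName, clusteringColumns.mkString(",")))
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    // Unclustered table restored to an unclustered version: nothing is emitted
    // (the case Scenario 5 in ClusteredTableDDLSuite exercises).
    assert(getMatchingMetadataDomain(Seq.empty, Seq.empty).isEmpty)

    // Clustered table: the existing domain is overwritten with the new columns.
    val existing = Seq(DomainMetadata(clusteringDomainName, "a"))
    assert(getMatchingMetadataDomain(Seq("b"), existing)
      .contains(DomainMetadata(clusteringDomainName, "b")))
  }
}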