[Spark] Support RESTORE for clustered table (#3194)
## Description
Support RESTORE for clustered tables by adding a new clustering domain metadata
action that overwrites the existing one so that clustering columns are correctly
restored.
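
A minimal sketch of the intended behavior, using a hypothetical table `t` and example columns (mirroring the new unit tests):

```scala
// Sketch of the behavior this change enables (table/column names are illustrative).
spark.sql("CREATE TABLE t (a INT, b STRING) USING delta CLUSTER BY (a)") // version 0: clustered by (a)
spark.sql("ALTER TABLE t CLUSTER BY (b)")                                // version 1: clustered by (b)
spark.sql("RESTORE TABLE t TO VERSION AS OF 0")
// The RESTORE commit writes a fresh clustering domain metadata action instead of
// keeping the current one, so the table is clustered by (a) again.
```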

## How was this patch tested?
New unit tests.
zedtang authored Jun 7, 2024
1 parent 87549c5 commit ef1def9
Showing 5 changed files with 151 additions and 19 deletions.
@@ -16,20 +16,26 @@

package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.skipping.clustering.ClusteredTableUtils
import org.apache.spark.sql.delta.actions.{Action, DomainMetadata, Protocol}
import org.apache.spark.sql.delta.clustering.ClusteringMetadataDomain
import org.apache.spark.sql.delta.metering.DeltaLogging

object DomainMetadataUtils extends DeltaLogging {
/**
* Domain metadata utility functions.
*/
trait DomainMetadataUtilsBase extends DeltaLogging {
// List of metadata domains that will be removed for the REPLACE TABLE operation.
private val METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE: Set[String] = Set(
)
protected val METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE: Set[String] = Set(
ClusteringMetadataDomain.domainName)

// List of metadata domains that will be copied from the table we are restoring to.
private val METADATA_DOMAIN_TO_COPY_FOR_RESTORE_TABLE =
METADATA_DOMAINS_TO_REMOVE_FOR_REPLACE_TABLE
// Note that ClusteringMetadataDomain is recreated in handleDomainMetadataForRestoreTable
// instead of being blindly copied over.
protected val METADATA_DOMAIN_TO_COPY_FOR_RESTORE_TABLE: Set[String] = Set.empty

// List of metadata domains that will be copied from the table on a CLONE operation.
private val METADATA_DOMAIN_TO_COPY_FOR_CLONE_TABLE: Set[String] = Set(
protected val METADATA_DOMAIN_TO_COPY_FOR_CLONE_TABLE: Set[String] = Set(
ClusteringMetadataDomain.domainName)

/**
@@ -91,15 +97,44 @@ object DomainMetadataUtils extends DeltaLogging

/**
* Generates a new sequence of DomainMetadata actions to commit for RESTORE TABLE.
* - Source (table to restore to) domains will be copied if they appear in the pre-defined
* - Domains in the toSnapshot will be copied if they appear in the pre-defined
* "copy" list (e.g., table features require some specific domains to be copied).
* - All other domains not in the list are "retained".
* - All other domains not in the list are dropped from the "toSnapshot".
*
* For the clustering metadata domain, it overwrites the existing domain metadata in the
* fromSnapshot with the following clustering columns:
* 1. If toSnapshot is not a clustered table or is missing domain metadata, use empty clustering
* columns.
* 2. If toSnapshot is a clustered table, use the clustering columns from toSnapshot.
*
* @param toSnapshot The snapshot being restored to, which is referred to as the "source" table.
* @param fromSnapshot The snapshot being restored from, which is the current state.
*/
def handleDomainMetadataForRestoreTable(
sourceDomainMetadatas: Seq[DomainMetadata]): Seq[DomainMetadata] = {
sourceDomainMetadatas.filter { m =>
toSnapshot: Snapshot,
fromSnapshot: Snapshot): Seq[DomainMetadata] = {
val filteredDomainMetadata = toSnapshot.domainMetadata.filter { m =>
METADATA_DOMAIN_TO_COPY_FOR_RESTORE_TABLE.contains(m.domain)
}
val clusteringColumnsToRestore = ClusteredTableUtils.getClusteringColumnsOptional(toSnapshot)

val isRestoringToClusteredTable =
ClusteredTableUtils.isSupported(toSnapshot.protocol) && clusteringColumnsToRestore.nonEmpty
val clusteringColumns = if (isRestoringToClusteredTable) {
// We overwrite the clustering columns in the fromSnapshot with the clustering columns
// in the toSnapshot.
clusteringColumnsToRestore.get
} else {
// toSnapshot is not a clustered table or missing domain metadata, so we write domain
// metadata with empty clustering columns.
Seq.empty
}

val matchingMetadataDomain =
ClusteredTableUtils.getMatchingMetadataDomain(
clusteringColumns,
fromSnapshot.domainMetadata)
filteredDomainMetadata ++ matchingMetadataDomain.clusteringDomainOpt
}

/**
@@ -112,3 +147,5 @@
}
}
}

object DomainMetadataUtils extends DomainMetadataUtilsBase
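
For illustration only (not part of the diff): when restoring to a snapshot clustered by a single column `a`, the clustering domain action emitted by handleDomainMetadataForRestoreTable would look roughly like the sketch below; the exact JSON layout of the configuration is an assumption here.

```scala
import org.apache.spark.sql.delta.actions.DomainMetadata

// Rough sketch of the action written on RESTORE; the domain name corresponds to
// ClusteringMetadataDomain.domainName, and the configuration JSON shape is assumed
// for illustration.
val restoredClusteringDomain = DomainMetadata(
  domain = "delta.clustering",
  configuration = """{"clusteringColumns":[["a"]]}""",
  removed = false)
```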
@@ -21,7 +21,7 @@ import java.sql.Timestamp
import scala.collection.JavaConverters._
import scala.util.{Success, Try}

import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, Snapshot}
import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, DomainMetadataUtils, Snapshot}
import org.apache.spark.sql.delta.actions.{AddFile, DeletionVectorDescriptor, RemoveFile}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.sources.DeltaSQLConf
@@ -205,9 +205,12 @@ case class RestoreTableCommand(sourceTable: DeltaTableV2)
sourceProtocol.merge(targetProtocol)
}

val actions = addActions ++ removeActions ++
DomainMetadataUtils.handleDomainMetadataForRestoreTable(snapshotToRestore, latestSnapshot)

txn.commitLarge(
spark,
addActions ++ removeActions,
actions,
Some(newProtocol),
DeltaOperations.Restore(version, timestamp),
Map.empty,
@@ -31,6 +31,10 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}

case class MatchingMetadataDomain(
clusteringDomainOpt: Option[DomainMetadata]
)

/**
* Clustered table utility functions.
*/
@@ -146,14 +150,33 @@ trait ClusteredTableUtilsBase extends DeltaLogging {
txn.protocol, txn.metadata, clusterBy)
val clusteringColumns =
clusterBy.columnNames.map(_.toString).map(ClusteringColumn(txn.metadata.schema, _))
Some(createDomainMetadata(clusteringColumns)).toSeq
Seq(createDomainMetadata(clusteringColumns))
}.getOrElse {
if (txn.snapshot.domainMetadata.exists(_.domain == ClusteringMetadataDomain.domainName)) {
Some(createDomainMetadata(Seq.empty)).toSeq
getMatchingMetadataDomain(
clusteringColumns = Seq.empty,
txn.snapshot.domainMetadata).clusteringDomainOpt.toSeq
}
}

/**
* Returns a [[MatchingMetadataDomain]] wrapping the [[DomainMetadata]] action that updates the
* existing clustering domain metadata with the given clustering columns, if such a domain exists.
*
* This is mainly used for REPLACE TABLE and RESTORE TABLE.
*/
def getMatchingMetadataDomain(
clusteringColumns: Seq[ClusteringColumn],
existingDomainMetadata: Seq[DomainMetadata]): MatchingMetadataDomain = {
val clusteringMetadataDomainOpt =
if (existingDomainMetadata.exists(_.domain == ClusteringMetadataDomain.domainName)) {
Some(ClusteringMetadataDomain.fromClusteringColumns(clusteringColumns).toDomainMetadata)
} else {
None.toSeq
None
}
}

MatchingMetadataDomain(
clusteringMetadataDomainOpt
)
}

/**
@@ -152,7 +152,8 @@ trait ClusteredTableTestUtilsBase extends SparkFunSuite with SharedSparkSession
} else {
assertClusterByNotExist()
}
case "WRITE" =>
case "WRITE" | "RESTORE" =>
// These are known operations from our tests that don't have clusterBy.
doAssert(!lastOperationParameters.contains(CLUSTERING_PARAMETER_KEY))
case _ =>
// Other operations are not tested yet. If the test fails here, please check the expected
@@ -21,6 +21,7 @@ import java.io.File
import com.databricks.spark.util.{Log4jUsageLogger, MetricDefinitions}
import org.apache.spark.sql.delta.skipping.ClusteredTableTestUtils
import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaColumnMappingEnableIdMode, DeltaColumnMappingEnableNameMode, DeltaConfigs, DeltaExcludedBySparkVersionTestMixinShims, DeltaLog, DeltaUnsupportedOperationException}
import org.apache.spark.sql.delta.clustering.ClusteringMetadataDomain
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.stats.SkippingEligibleDataType
import org.apache.spark.sql.delta.test.{DeltaColumnMappingSelectedTestMixin, DeltaSQLCommandTest}
@@ -480,7 +481,8 @@ trait ClusteredTableDDLWithColumnMapping
"validate dropping clustering column is not allowed: single clustering column",
"validate dropping clustering column is not allowed: multiple clustering columns",
"validate dropping clustering column is not allowed: clustering column + " +
"non-clustering column"
"non-clustering column",
"validate RESTORE on clustered table"
)

test("validate dropping clustering column is not allowed: single clustering column") {
@@ -825,6 +827,72 @@ trait ClusteredTableDDLSuiteBase
}
}

test("validate RESTORE on clustered table") {
val tableIdentifier = TableIdentifier(testTable)
// Scenario 1: restore clustered table to unclustered version.
withTable(testTable) {
sql(s"CREATE TABLE $testTable (a INT, b STRING) USING delta")
val (_, startingSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
assert(!ClusteredTableUtils.isSupported(startingSnapshot.protocol))

sql(s"ALTER TABLE $testTable CLUSTER BY (a)")
verifyClusteringColumns(tableIdentifier, "a")

sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
val (_, currentSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
verifyClusteringColumns(tableIdentifier, "")
}

// Scenario 2: restore clustered table to previous clustering columns.
withClusteredTable(testTable, "a INT, b STRING", "a") {
verifyClusteringColumns(tableIdentifier, "a")

sql(s"ALTER TABLE $testTable CLUSTER BY (b)")
verifyClusteringColumns(tableIdentifier, "b")

sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
verifyClusteringColumns(tableIdentifier, "a")
}

// Scenario 3: restore from empty clustering columns back to non-empty clustering columns.
withClusteredTable(testTable, "a int", "a") {
verifyClusteringColumns(tableIdentifier, "a")

sql(s"ALTER TABLE $testTable CLUSTER BY NONE")
verifyClusteringColumns(tableIdentifier, "")

sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
verifyClusteringColumns(tableIdentifier, "a")
}

// Scenario 4: restore to start version.
withClusteredTable(testTable, "a int", "a") {
verifyClusteringColumns(tableIdentifier, "a")

sql(s"INSERT INTO $testTable VALUES (1)")

sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0")
verifyClusteringColumns(tableIdentifier, "a")
}

// Scenario 5: restore unclustered table to unclustered table.
withTable(testTable) {
sql(s"CREATE TABLE $testTable (a INT) USING delta")
val (_, startingSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
assert(!ClusteredTableUtils.isSupported(startingSnapshot.protocol))
assert(!startingSnapshot.domainMetadata.exists(_.domain ==
ClusteringMetadataDomain.domainName))

sql(s"INSERT INTO $testTable VALUES (1)")

sql(s"RESTORE TABLE $testTable TO VERSION AS OF 0").collect
val (_, currentSnapshot) = DeltaLog.forTableWithSnapshot(spark, tableIdentifier)
assert(!ClusteredTableUtils.isSupported(currentSnapshot.protocol))
assert(!currentSnapshot.domainMetadata.exists(_.domain ==
ClusteringMetadataDomain.domainName))
}
}

testSparkMasterOnly("Variant is not supported") {
val e = intercept[DeltaAnalysisException] {
createOrReplaceClusteredTable("CREATE", testTable, "id long, v variant", "v")