diff --git a/core/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala b/core/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala index 018341b5a90..60445145830 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala @@ -304,6 +304,30 @@ trait Checkpoints extends DeltaLogging { */ def checkpoint(): Unit = checkpoint(snapshot) + /** + * Catch non-fatal exceptions related to checkpointing, since the checkpoint is written + * after the commit has completed. From the perspective of the user, the commit has + * completed successfully. However, throw if this is in a testing environment - + * that way any breaking changes can be caught in unit tests. + */ + protected def withCheckpointExceptionHandling( + deltaLog: DeltaLog, opType: String)(thunk: => Unit): Unit = { + try { + thunk + } catch { + case NonFatal(e) => + recordDeltaEvent( + deltaLog, + opType, + data = Map("exception" -> e.getMessage(), "stackTrace" -> e.getStackTrace()) + ) + logWarning(s"Error when writing checkpoint-related files", e) + val throwError = Utils.isTesting || + spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_CHECKPOINT_THROW_EXCEPTION_WHEN_FAILED) + if (throwError) throw e + } + } + /** * Creates a checkpoint using snapshotToCheckpoint. By default it uses the current log version. * Note that this function captures and logs all exceptions, since the checkpoint shouldn't fail @@ -311,48 +335,30 @@ trait Checkpoints extends DeltaLogging { */ def checkpoint(snapshotToCheckpoint: Snapshot): Unit = recordDeltaOperation( this, "delta.checkpoint") { - try { + withCheckpointExceptionHandling(snapshotToCheckpoint.deltaLog, "delta.checkpoint.sync.error") { if (snapshotToCheckpoint.version < 0) { throw DeltaErrors.checkpointNonExistTable(dataPath) } checkpointAndCleanUpDeltaLog(snapshotToCheckpoint) - } catch { - // Catch all non-fatal exceptions, since the checkpoint is written after the commit - // has completed. From the perspective of the user, the commit completed successfully. - // However, throw if this is in a testing environment - that way any breaking changes - // can be caught in unit tests. - case NonFatal(e) => - recordDeltaEvent( - snapshotToCheckpoint.deltaLog, - "delta.checkpoint.sync.error", - data = Map( - "exception" -> e.getMessage(), - "stackTrace" -> e.getStackTrace() - ) - ) - logWarning(s"Error when writing checkpoint synchronously", e) - val throwError = Utils.isTesting || - spark.sessionState.conf.getConf( - DeltaSQLConf.DELTA_CHECKPOINT_THROW_EXCEPTION_WHEN_FAILED) - if (throwError) { - throw e - } } } protected def checkpointAndCleanUpDeltaLog( snapshotToCheckpoint: Snapshot): Unit = { val checkpointMetaData = writeCheckpointFiles(snapshotToCheckpoint) - writeLastCheckpointFile(checkpointMetaData, CheckpointMetaData.checksumEnabled(spark)) + writeLastCheckpointFile( + snapshotToCheckpoint.deltaLog, checkpointMetaData, CheckpointMetaData.checksumEnabled(spark)) doLogCleanup(snapshotToCheckpoint) } - protected def writeLastCheckpointFile( + protected[delta] def writeLastCheckpointFile( + deltaLog: DeltaLog, checkpointMetaData: CheckpointMetaData, addChecksum: Boolean): Unit = { - val json = CheckpointMetaData.serializeToJson(checkpointMetaData, addChecksum) - store.write( - LAST_CHECKPOINT, Iterator(json), overwrite = true, newDeltaHadoopConf()) + withCheckpointExceptionHandling(deltaLog, "delta.lastCheckpoint.write.error") { + val json = CheckpointMetaData.serializeToJson(checkpointMetaData, addChecksum) + store.write(LAST_CHECKPOINT, Iterator(json), overwrite = true, newDeltaHadoopConf()) + } } protected def writeCheckpointFiles(