
Commit 8b419e7

[Spark] Add Managed Commit support in getSnapshotAt API (delta-io#2879)

## Description

Add Managed Commit support in the deltaLog.getSnapshotAt() API.

## How was this patch tested?

UTs

## Does this PR introduce _any_ user-facing changes?

No
1 parent ad9f67a commit 8b419e7
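
For orientation, here is the reworked API surface, condensed from the SnapshotManagement.scala diff below. The signatures are taken from this commit; the comments are editorial:

```scala
// Public overload: signature unchanged, now delegates to the private overload.
def getSnapshotAt(
    version: Long,
    lastCheckpointHint: Option[CheckpointInstance] = None): Snapshot = {
  getSnapshotAt(version, lastCheckpointHint, lastCheckpointProvider = None)
}

// Private overload: accepts both hints. Callers that already hold a
// CheckpointProvider (e.g. the post-commit CheckpointHook) pass it here so the
// log segment can be built without extra I/O.
private[delta] def getSnapshotAt(
    version: Long,
    lastCheckpointHint: Option[CheckpointInstance],
    lastCheckpointProvider: Option[CheckpointProvider]): Snapshot
```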

File tree

4 files changed (+208 lines, -90 lines)


spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala

Lines changed: 50 additions & 48 deletions
```diff
@@ -463,8 +463,7 @@ trait SnapshotManagement { self: DeltaLog =>
       throw new IllegalStateException(s"Could not find any delta files for version $newVersion")
     }
     if (versionToLoad.exists(_ != newVersion)) {
-      throw new IllegalStateException(
-        s"Trying to load a non-existent version ${versionToLoad.get}")
+      throwNonExistentVersionError(versionToLoad.get)
     }
     val lastCommitTimestamp = deltas.last.getModificationTime

@@ -558,6 +557,11 @@ trait SnapshotManagement { self: DeltaLog =>
       deltasAfterCheckpoint
     }

+  def throwNonExistentVersionError(versionToLoad: Long): Unit = {
+    throw new IllegalStateException(
+      s"Trying to load a non-existent version $versionToLoad")
+  }
+
   /**
    * Load the Snapshot for this Delta table at initialization. This method uses the `lastCheckpoint`
    * file as a hint on where to start listing the transaction log directory. If the _delta_log
@@ -1143,66 +1147,64 @@ trait SnapshotManagement { self: DeltaLog =>
     }
   }

+  /** Get the snapshot at `version`. */
+  def getSnapshotAt(
+      version: Long,
+      lastCheckpointHint: Option[CheckpointInstance] = None): Snapshot = {
+    getSnapshotAt(version, lastCheckpointHint, lastCheckpointProvider = None)
+  }
+
   /**
-   * Get the snapshot at `version` using the given `lastCheckpointProvider` hint
+   * Get the snapshot at `version` using the given `lastCheckpointProvider` or `lastCheckpointHint`
    * as the listing hint.
    */
   private[delta] def getSnapshotAt(
       version: Long,
-      lastCheckpointProvider: CheckpointProvider): Snapshot = {
+      lastCheckpointHint: Option[CheckpointInstance],
+      lastCheckpointProvider: Option[CheckpointProvider]): Snapshot = {
+
     // See if the version currently cached on the cluster satisfies the requirement
-    val current = unsafeVolatileSnapshot
-    if (current.version == version) {
-      return current
+    val currentSnapshot = unsafeVolatileSnapshot
+    val upperBoundSnapshot = if (currentSnapshot.version >= version) {
+      // The current snapshot is already newer than what we are looking for, so it
+      // can be used as an upper bound.
+      currentSnapshot
+    } else {
+      val latestSnapshot = update()
+      if (latestSnapshot.version < version) {
+        throwNonExistentVersionError(version)
+      }
+      latestSnapshot
     }
-    if (lastCheckpointProvider.version > version) {
-      // If the provided lastCheckpointProvider's version is greater than the snapshot
-      // that we are trying to create, we can't use the provider.
-      // Fall back to the other overload.
-      return getSnapshotAt(version)
+    if (upperBoundSnapshot.version == version) {
+      return upperBoundSnapshot
     }
-    val segment = createLogSegment(
-      versionToLoad = Some(version),
-      oldCheckpointProviderOpt = Some(lastCheckpointProvider)
-    ).getOrElse {
-      // We can't return InitialSnapshot because our caller asked for a specific snapshot version.
-      throw DeltaErrors.emptyDirectoryException(logPath.toString)
-    }
-    createSnapshot(
-      initSegment = segment,
-      commitStoreOpt = None,
-      checksumOpt = None)
-  }

-  /** Get the snapshot at `version`. */
-  def getSnapshotAt(
-      version: Long,
-      lastCheckpointHint: Option[CheckpointInstance] = None): Snapshot = {
-    // See if the version currently cached on the cluster satisfies the requirement
-    val current = unsafeVolatileSnapshot
-    if (current.version == version) {
-      return current
+    val (lastCheckpointInfoOpt, lastCheckpointProviderOpt) = lastCheckpointProvider match {
+      // NOTE: We must ignore any hint whose version is higher than the requested version.
+      case Some(checkpointProvider) if checkpointProvider.version <= version =>
+        // Prefer the last checkpoint provider hint, because it doesn't require any I/O to use.
+        None -> Some(checkpointProvider)
+      case _ =>
+        val lastCheckpointInfoForListing = lastCheckpointHint
+          .filter(_.version <= version)
+          .orElse(findLastCompleteCheckpointBefore(version))
+          .map(manuallyLoadCheckpoint)
+        lastCheckpointInfoForListing -> None
     }
-
-    // Do not use the hint if the version we're asking for is smaller than the last checkpoint hint
-    val lastCheckpointInfoHint =
-      lastCheckpointHint
-        .collect { case ci if ci.version <= version => ci }
-        .orElse(findLastCompleteCheckpointBefore(version))
-        .map(manuallyLoadCheckpoint)
-    createLogSegment(
+    val logSegmentOpt = createLogSegment(
       versionToLoad = Some(version),
-      lastCheckpointInfo = lastCheckpointInfoHint,
-      commitStoreOpt = current.commitStoreOpt
-    ).map { segment =>
-      createSnapshot(
-        initSegment = segment,
-        commitStoreOpt = None,
-        checksumOpt = None)
-    }.getOrElse {
+      oldCheckpointProviderOpt = lastCheckpointProviderOpt,
+      commitStoreOpt = upperBoundSnapshot.commitStoreOpt,
+      lastCheckpointInfo = lastCheckpointInfoOpt)
+    val logSegment = logSegmentOpt.getOrElse {
       // We can't return InitialSnapshot because our caller asked for a specific snapshot version.
       throw DeltaErrors.emptyDirectoryException(logPath.toString)
     }
+    createSnapshot(
+      initSegment = logSegment,
+      commitStoreOpt = upperBoundSnapshot.commitStoreOpt,
+      checksumOpt = None)
   }

   // Visible for testing
```
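
The net effect of the rework: getSnapshotAt now establishes an "upper bound" snapshot first (the cached one if it is new enough, otherwise the result of update()), fails fast on non-existent versions, and threads commitStoreOpt through so un-backfilled managed commits are visible when reconstructing older versions. A minimal caller-side sketch of the behavior the UTs below assert (the table path and versions are illustrative):

```scala
// Assuming a table with commits v0..v4, where managed commit was enabled at v3:
val deltaLog = DeltaLog.forTable(spark, tablePath)

// Cached snapshot already at v4: asking for v4 returns it with no update() or listing.
val s4 = deltaLog.getSnapshotAt(4)

// Asking for v3 needs only a listing (3 < cached version), no update(); the cached
// snapshot's commit store is reused so un-backfilled commits are still visible.
val s3 = deltaLog.getSnapshotAt(3)

// Asking for a version beyond the latest triggers update() and then fails with
// IllegalStateException("Trying to load a non-existent version 100").
// val bad = deltaLog.getSnapshotAt(100)
```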

spark/src/main/scala/org/apache/spark/sql/delta/hooks/CheckpointHook.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -36,7 +36,10 @@ object CheckpointHook extends PostCommitHook {
     // Since the postCommitSnapshot isn't guaranteed to match committedVersion, we have to
     // explicitly checkpoint the snapshot at the committedVersion.
     val cp = postCommitSnapshot.checkpointProvider
-    txn.deltaLog.checkpoint(txn.deltaLog.getSnapshotAt(committedVersion, cp)
-    )
+    val snapshotToCheckpoint = txn.deltaLog.getSnapshotAt(
+      committedVersion,
+      lastCheckpointHint = None,
+      lastCheckpointProvider = Some(cp))
+    txn.deltaLog.checkpoint(snapshotToCheckpoint)
   }
 }
```
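
A subtlety worth noting for this call site: the old provider-taking overload handled a too-new provider by bouncing to the hint-taking overload, while the merged method folds both paths into a single `match`. A hedged illustration of the resulting behavior (the provider value and versions here are made up):

```scala
// A provider at version 20 cannot seed a snapshot at version 10: the guard
// `checkpointProvider.version <= version` fails, so getSnapshotAt ignores it and
// derives a listing hint from findLastCompleteCheckpointBefore(10) instead.
val snap = deltaLog.getSnapshotAt(
  version = 10L,
  lastCheckpointHint = None,
  lastCheckpointProvider = Some(tooNewProvider)) // hypothetical CheckpointProvider at v20
assert(snap.version == 10L)
```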

spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala

Lines changed: 140 additions & 40 deletions
```diff
@@ -23,13 +23,15 @@ import org.apache.spark.sql.delta.DeltaConfigs.{MANAGED_COMMIT_OWNER_CONF, MANAG
 import org.apache.spark.sql.delta.DeltaLog
 import org.apache.spark.sql.delta.DeltaTestUtils.createTestAddFile
 import org.apache.spark.sql.delta.InitialSnapshot
+import org.apache.spark.sql.delta.Snapshot
 import org.apache.spark.sql.delta.actions.{Action, Metadata}
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.storage.LogStore
 import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
 import org.apache.spark.sql.delta.test.DeltaSQLTestUtils
 import org.apache.spark.sql.delta.test.DeltaTestImplicits._
 import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
+import org.apache.spark.sql.delta.util.FileNames.{CompactedDeltaFile, DeltaFile}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}

@@ -145,50 +147,41 @@
   }

   // Test commit owner changed on concurrent cluster
-  test("snapshot is updated recursively when FS table is converted to commit owner" +
-    " table on a concurrent cluster") {
+  testWithoutManagedCommits("snapshot is updated recursively when FS table is converted to commit" +
+    " owner table on a concurrent cluster") {
     val commitStore = new TrackingCommitStore(new InMemoryCommitStore(batchSize = 10))
     val builder = TrackingInMemoryCommitStoreBuilder(batchSize = 10, Some(commitStore))
     CommitStoreProvider.registerBuilder(builder)
-    val oldCommitOwnerValue = spark.conf.get(MANAGED_COMMIT_OWNER_NAME.defaultTablePropertyKey)
-    spark.conf.unset(MANAGED_COMMIT_OWNER_NAME.defaultTablePropertyKey)

-    try {
-      withTempDir { tempDir =>
-        val tablePath = tempDir.getAbsolutePath
-        val deltaLog1 = DeltaLog.forTable(spark, tablePath)
-        deltaLog1.startTransaction().commitManually(Metadata())
-        deltaLog1.startTransaction().commitManually(createTestAddFile("f1"))
-        deltaLog1.startTransaction().commitManually()
-        val snapshotV2 = deltaLog1.update()
-        assert(snapshotV2.version === 2)
-        assert(snapshotV2.commitStoreOpt.isEmpty)
-        DeltaLog.clearCache()
+    withTempDir { tempDir =>
+      val tablePath = tempDir.getAbsolutePath
+      val deltaLog1 = DeltaLog.forTable(spark, tablePath)
+      deltaLog1.startTransaction().commitManually(Metadata())
+      deltaLog1.startTransaction().commitManually(createTestAddFile("f1"))
+      deltaLog1.startTransaction().commitManually()
+      val snapshotV2 = deltaLog1.update()
+      assert(snapshotV2.version === 2)
+      assert(snapshotV2.commitStoreOpt.isEmpty)
+      DeltaLog.clearCache()

-        // Add new commit to convert FS table to managed-commit table
-        val deltaLog2 = DeltaLog.forTable(spark, tablePath)
-        val oldMetadata = snapshotV2.metadata
-        val commitOwner = (MANAGED_COMMIT_OWNER_NAME.key -> "tracking-in-memory")
-        val newMetadata = oldMetadata.copy(configuration = oldMetadata.configuration + commitOwner)
-        deltaLog2.startTransaction().commitManually(newMetadata)
-        commitStore.registerTable(deltaLog2.logPath, 3)
-        deltaLog2.startTransaction().commitManually(createTestAddFile("f2"))
-        deltaLog2.startTransaction().commitManually()
-        val snapshotV5 = deltaLog2.unsafeVolatileSnapshot
-        assert(snapshotV5.version === 5)
-        assert(snapshotV5.commitStoreOpt.nonEmpty)
-        // only delta 4/5 will be un-backfilled and should have two dots in filename (x.uuid.json)
-        assert(snapshotV5.logSegment.deltas.count(_.getPath.getName.count(_ == '.') == 2) === 2)
-
-        val usageRecords = Log4jUsageLogger.track {
-          val newSnapshotV5 = deltaLog1.update()
-          assert(newSnapshotV5.version === 5)
-          assert(newSnapshotV5.logSegment.deltas === snapshotV5.logSegment.deltas)
-        }
-        assert(filterUsageRecords(usageRecords, "delta.readChecksum").size === 2)
+      // Add new commit to convert FS table to managed-commit table
+      val deltaLog2 = DeltaLog.forTable(spark, tablePath)
+      enableManagedCommit(deltaLog2, commitOwner = "tracking-in-memory")
+      commitStore.registerTable(deltaLog2.logPath, 3)
+      deltaLog2.startTransaction().commitManually(createTestAddFile("f2"))
+      deltaLog2.startTransaction().commitManually()
+      val snapshotV5 = deltaLog2.unsafeVolatileSnapshot
+      assert(snapshotV5.version === 5)
+      assert(snapshotV5.commitStoreOpt.nonEmpty)
+      // only delta 4/5 will be un-backfilled and should have two dots in filename (x.uuid.json)
+      assert(snapshotV5.logSegment.deltas.count(_.getPath.getName.count(_ == '.') == 2) === 2)
+
+      val usageRecords = Log4jUsageLogger.track {
+        val newSnapshotV5 = deltaLog1.update()
+        assert(newSnapshotV5.version === 5)
+        assert(newSnapshotV5.logSegment.deltas === snapshotV5.logSegment.deltas)
       }
-    } finally {
-      spark.conf.set(MANAGED_COMMIT_OWNER_NAME.defaultTablePropertyKey, oldCommitOwnerValue)
+      assert(filterUsageRecords(usageRecords, "delta.readChecksum").size === 2)
     }
   }

@@ -517,7 +510,7 @@
     }
   }

-  testWithDifferentBackfillInterval("ensure backfills commit files works as expected") { _ =>
+  testWithDifferentBackfillInterval("Snapshot.ensureCommitFilesBackfilled") { _ =>
     withTempDir { tempDir =>
       val tablePath = tempDir.getAbsolutePath

@@ -534,6 +527,113 @@
       val backfilledCommitFiles = (0 to 9).map(
         version => FileNames.unsafeDeltaFile(log.logPath, version))
       assert(commitFiles.toSeq == backfilledCommitFiles)
-      }
+    }
+  }
+
+  testWithoutManagedCommits("DeltaLog.getSnapshotAt") {
+    val commitStore = new TrackingCommitStore(new InMemoryCommitStore(batchSize = 10))
+    val builder = TrackingInMemoryCommitStoreBuilder(batchSize = 10, Some(commitStore))
+    CommitStoreProvider.registerBuilder(builder)
+    def checkGetSnapshotAt(
+        deltaLog: DeltaLog,
+        version: Long,
+        expectedUpdateCount: Int,
+        expectedListingCount: Int): Snapshot = {
+      var snapshot: Snapshot = null
+
+      val usageRecords = Log4jUsageLogger.track {
+        snapshot = deltaLog.getSnapshotAt(version)
+        assert(snapshot.version === version)
+      }
+      assert(filterUsageRecords(usageRecords, "deltaLog.update").size === expectedUpdateCount)
+      // deltaLog.update() will internally do listing
+      assert(filterUsageRecords(usageRecords, "delta.deltaLog.listDeltaAndCheckpointFiles").size
+        === expectedListingCount)
+      val versionsInLogSegment = if (version < 6) {
+        snapshot.logSegment.deltas.map(FileNames.deltaVersion(_))
+      } else {
+        snapshot.logSegment.deltas.flatMap {
+          case DeltaFile(_, deltaVersion) => Seq(deltaVersion)
+          case CompactedDeltaFile(_, startVersion, endVersion) => (startVersion to endVersion)
+        }
+      }
+      assert(versionsInLogSegment === (0L to version))
+      snapshot
+    }
+
+    withTempDir { dir =>
+      val tablePath = dir.getAbsolutePath
+      // Part-1: Validate that the getSnapshotAt API works as expected for non-managed-commit
+      // tables. Commits 0, 1, 2 go to the FS table.
+      Seq(1).toDF.write.format("delta").mode("overwrite").save(tablePath) // v0
+      Seq(1).toDF.write.format("delta").mode("overwrite").save(tablePath) // v1
+      val deltaLog1 = DeltaLog.forTable(spark, tablePath)
+      DeltaLog.clearCache()
+      Seq(1).toDF.write.format("delta").mode("overwrite").save(tablePath) // v2
+      assert(deltaLog1.unsafeVolatileSnapshot.version === 1)
+
+      checkGetSnapshotAt(deltaLog1, version = 1, expectedUpdateCount = 0, expectedListingCount = 0)
+      // deltaLog1 still points to version 1. So, we will do listing to get v0.
+      checkGetSnapshotAt(deltaLog1, version = 0, expectedUpdateCount = 0, expectedListingCount = 1)
+      // deltaLog1 still points to version 1 although we are asking for v2. So we do a
+      // deltaLog.update() - the update will internally do listing. Since the updated snapshot is
+      // the same as what we want, we won't create another snapshot or do another listing.
+      checkGetSnapshotAt(deltaLog1, version = 2, expectedUpdateCount = 1, expectedListingCount = 1)
+      var deltaLog2 = DeltaLog.forTable(spark, tablePath)
+      Seq(deltaLog1, deltaLog2).foreach { log => assert(log.unsafeVolatileSnapshot.version === 2) }
+      DeltaLog.clearCache()
+
+      // Part-2: Validate that the getSnapshotAt API works as expected for managed-commit tables
+      // when the switch is made.
+      // commit 3
+      enableManagedCommit(DeltaLog.forTable(spark, tablePath), "tracking-in-memory")
+      commitStore.registerTable(deltaLog1.logPath, maxCommitVersion = 3)
+      // commit 4
+      Seq(1).toDF.write.format("delta").mode("overwrite").save(tablePath)
+      // the old deltaLog objects still point to version 2
+      Seq(deltaLog1, deltaLog2).foreach { log => assert(log.unsafeVolatileSnapshot.version === 2) }
+      // deltaLog1 points to version 2. So, we will do listing to get v1. A snapshot update is not
+      // needed as what we are looking for is less than what deltaLog1 points to.
+      checkGetSnapshotAt(deltaLog1, version = 1, expectedUpdateCount = 0, expectedListingCount = 1)
+      // deltaLog1.unsafeVolatileSnapshot.version points to v2 - return it directly.
+      checkGetSnapshotAt(deltaLog1, version = 2, expectedUpdateCount = 0, expectedListingCount = 0)
+      // We are asking for v3 although the deltaLog1.unsafeVolatileSnapshot is for v2. So this will
+      // need deltaLog.update() to get the latest snapshot first - this update itself internally
+      // will do 2 rounds of listing as we are discovering a commit store after the first round of
+      // listing. Once the update finishes, deltaLog1 will point to v4. So we need another round of
+      // listing to get just v3.
+      checkGetSnapshotAt(deltaLog1, version = 3, expectedUpdateCount = 1, expectedListingCount = 3)
+      // Ask for v3 again - this time deltaLog1.unsafeVolatileSnapshot points to v4.
+      // So we don't need deltaLog.update() as the version we are asking for is less than the
+      // pinned version. Just do listing and get the snapshot.
+      checkGetSnapshotAt(deltaLog1, version = 3, expectedUpdateCount = 0, expectedListingCount = 1)
+      // deltaLog1.unsafeVolatileSnapshot.version points to v4 - return it directly.
+      checkGetSnapshotAt(deltaLog1, version = 4, expectedUpdateCount = 0, expectedListingCount = 0)
+      // We are asking for v4 although the deltaLog2.unsafeVolatileSnapshot is for v2. So this will
+      // need deltaLog.update() to get the latest snapshot first - this update itself internally
+      // will do 2 rounds of listing as we are discovering a commit store after the first round of
+      // listing. Once the update finishes, deltaLog2 will point to v4. It can be returned directly.
+      checkGetSnapshotAt(deltaLog2, version = 4, expectedUpdateCount = 1, expectedListingCount = 2)
+
+      // Part-3: Validate that the getSnapshotAt API works as expected for managed-commit tables.
+      Seq(1).toDF.write.format("delta").mode("overwrite").save(tablePath) // v5
+      deltaLog2 = DeltaLog.forTable(spark, tablePath)
+      DeltaLog.clearCache()
+      Seq(1).toDF.write.format("delta").mode("overwrite").save(tablePath) // v6
+      Seq(1).toDF.write.format("delta").mode("overwrite").save(tablePath) // v7
+      assert(deltaLog2.unsafeVolatileSnapshot.version === 5)
+      checkGetSnapshotAt(deltaLog2, version = 1, expectedUpdateCount = 0, expectedListingCount = 1)
+      checkGetSnapshotAt(deltaLog2, version = 2, expectedUpdateCount = 0, expectedListingCount = 1)
+      checkGetSnapshotAt(deltaLog2, version = 4, expectedUpdateCount = 0, expectedListingCount = 1)
+      checkGetSnapshotAt(deltaLog2, version = 5, expectedUpdateCount = 0, expectedListingCount = 0)
+      checkGetSnapshotAt(deltaLog2, version = 6, expectedUpdateCount = 1, expectedListingCount = 2)
+    }
+  }
+
+  private def enableManagedCommit(deltaLog: DeltaLog, commitOwner: String): Unit = {
+    val oldMetadata = deltaLog.update().metadata
+    val commitOwnerConf = (MANAGED_COMMIT_OWNER_NAME.key -> commitOwner)
+    val newMetadata = oldMetadata.copy(configuration = oldMetadata.configuration + commitOwnerConf)
+    deltaLog.startTransaction().commitManually(newMetadata)
   }
 }
```
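
A quick way to sanity-check the update and listing counts asserted above (this accounting is inferred from the test's assertions, not stated anywhere in the commit):

```scala
// Editorial sketch of the accounting behind checkGetSnapshotAt's expectations:
//
//   expectedUpdateCount  = 1 if requested version > cached version, else 0
//   expectedListingCount = per update(): 2 listings if a commit store is
//                          discovered during that update, else 1
//                        + 1 listing if the returned snapshot must be built
//                          separately (requested version < resulting version)
//
// Example (cached = v2, table converted to managed commit at v3, request = v3):
//   update() lists, discovers the commit store, lists again -> 2 listings, lands on v4;
//   v4 != v3, so one more listing builds v3 -> expectedUpdateCount = 1, expectedListingCount = 3.
```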

spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.delta.managedcommit

 import org.apache.spark.sql.delta.{DeltaConfigs, DeltaTestUtilsBase}
+import org.apache.spark.sql.delta.DeltaConfigs.MANAGED_COMMIT_OWNER_NAME
 import org.apache.spark.sql.delta.storage.LogStore
 import org.apache.spark.sql.delta.util.JsonUtils
 import org.apache.hadoop.conf.Configuration
@@ -28,6 +29,18 @@ import org.apache.spark.sql.test.SharedSparkSession
 trait ManagedCommitTestUtils
   extends DeltaTestUtilsBase { self: SparkFunSuite with SharedSparkSession =>

+  def testWithoutManagedCommits(testName: String)(f: => Unit): Unit = {
+    test(testName) {
+      val oldCommitOwnerValue = spark.conf.get(MANAGED_COMMIT_OWNER_NAME.defaultTablePropertyKey)
+      try {
+        spark.conf.unset(MANAGED_COMMIT_OWNER_NAME.defaultTablePropertyKey)
+        f
+      } finally {
+        spark.conf.set(MANAGED_COMMIT_OWNER_NAME.defaultTablePropertyKey, oldCommitOwnerValue)
+      }
+    }
+  }
+
   /** Run the test with different backfill batch sizes: 1, 2, 10 */
   def testWithDifferentBackfillInterval(testName: String)(f: Int => Unit): Unit = {
     Seq(1, 2, 10).foreach { backfillBatchSize =>
```
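
For orientation, here is how the new helper reads at a call site. The first real usage is the converted test in ManagedCommitSuite.scala above; the body below is a hedged, self-contained sketch of the same pattern (the test name is hypothetical):

```scala
// Sketch: a test that must start from a plain filesystem (non-managed-commit)
// table wraps its body so the session-level default commit owner is unset for
// the duration of the test and restored afterwards, even if an assertion throws.
testWithoutManagedCommits("starts as a filesystem table") {
  withTempDir { dir =>
    val tablePath = dir.getAbsolutePath
    Seq(1).toDF.write.format("delta").save(tablePath) // v0, no commit owner set
    val deltaLog = DeltaLog.forTable(spark, tablePath)
    assert(deltaLog.update().commitStoreOpt.isEmpty)
  }
}
```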
