diff --git a/clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala b/clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala
index 090761ab566..43781b77d3a 100644
--- a/clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala
+++ b/clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala
@@ -2,13 +2,16 @@ package io.treeverse.clients
 
 import io.treeverse.lakefs.catalog.Entry
 import org.apache.commons.lang3.StringUtils
-import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapred.InvalidJobConfException
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.util.SerializableConfiguration
+import org.slf4j.{Logger, LoggerFactory}
 
+import java.io.File
 import java.util.concurrent.TimeUnit
 
 object LakeFSJobParams {
@@ -61,6 +64,8 @@
 }
 
 object LakeFSContext {
+  private val logger: Logger = LoggerFactory.getLogger(getClass.toString)
+
   val LAKEFS_CONF_API_URL_KEY = "lakefs.api.url"
   val LAKEFS_CONF_API_ACCESS_KEY_KEY = "lakefs.api.access_key"
   val LAKEFS_CONF_API_SECRET_KEY_KEY = "lakefs.api.secret_key"
@@ -71,6 +76,9 @@
   val LAKEFS_CONF_JOB_COMMIT_IDS_KEY = "lakefs.job.commit_ids"
   val LAKEFS_CONF_JOB_SOURCE_NAME_KEY = "lakefs.job.source_name"
 
+  // Read parallelism (number of partitions). Defaults to sc.defaultParallelism.
+  val LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM = "lakefs.job.range_read_parallelism"
+
   val LAKEFS_CONF_GC_NUM_COMMIT_PARTITIONS = "lakefs.gc.commit.num_partitions"
   val LAKEFS_CONF_GC_NUM_RANGE_PARTITIONS = "lakefs.gc.range.num_partitions"
   val LAKEFS_CONF_GC_NUM_ADDRESS_PARTITIONS = "lakefs.gc.address.num_partitions"
@@ -108,15 +116,14 @@
   val DEFAULT_LAKEFS_CONF_GC_S3_MIN_BACKOFF_SECONDS = 1
   val DEFAULT_LAKEFS_CONF_GC_S3_MAX_BACKOFF_SECONDS = 120
 
+  val metarangeReaderGetter = SSTableReader.forMetaRange _
+
   def newRDD(
       sc: SparkContext,
       params: LakeFSJobParams
   ): RDD[(Array[Byte], WithIdentifier[Entry])] = {
-    val inputFormatClass =
-      if (params.commitIDs.nonEmpty) classOf[LakeFSCommitInputFormat]
-      else classOf[LakeFSAllRangesInputFormat]
+    val conf = sc.hadoopConfiguration
 
-    val conf = new Configuration(sc.hadoopConfiguration)
     conf.set(LAKEFS_CONF_JOB_REPO_NAME_KEY, params.repoName)
     conf.setStrings(LAKEFS_CONF_JOB_COMMIT_IDS_KEY, params.commitIDs.toArray: _*)
 
@@ -131,12 +138,77 @@
       throw new InvalidJobConfException(s"$LAKEFS_CONF_API_SECRET_KEY_KEY must not be empty")
     }
     conf.set(LAKEFS_CONF_JOB_SOURCE_NAME_KEY, params.sourceName)
-    sc.newAPIHadoopRDD(
-      conf,
-      inputFormatClass,
-      classOf[Array[Byte]],
-      classOf[WithIdentifier[Entry]]
+
+    val apiConf = APIConfigurations(
+      conf.get(LAKEFS_CONF_API_URL_KEY),
+      conf.get(LAKEFS_CONF_API_ACCESS_KEY_KEY),
+      conf.get(LAKEFS_CONF_API_SECRET_KEY_KEY),
+      conf.get(LAKEFS_CONF_API_CONNECTION_TIMEOUT_SEC_KEY),
+      conf.get(LAKEFS_CONF_API_READ_TIMEOUT_SEC_KEY),
+      conf.get(LAKEFS_CONF_JOB_SOURCE_NAME_KEY, "input_format")
     )
+    val repoName = conf.get(LAKEFS_CONF_JOB_REPO_NAME_KEY)
+
+    // Wrap the Hadoop configuration so it can be shipped to executors.
+    val serializedConf = new SerializableConfiguration(conf)
+
+    val parallelism = conf.getInt(LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM, sc.defaultParallelism)
+
+    // ApiClient is not serializable, so create a new one for each partition on its executor.
+    // (If we called flatMap directly, we would fetch the client from the cache for each
+    // range, which is wasteful.)
+
+    // TODO(ariels): Unify with similar code in LakeFSInputFormat.getSplits
+    val ranges = sc
+      .parallelize(params.commitIDs.toSeq, parallelism)
+      .mapPartitions(commits => {
+        val apiClient = ApiClient.get(apiConf)
+        val conf = serializedConf.value
+        commits.flatMap(commitID => {
+          val metaRangeURL = apiClient.getMetaRangeURL(repoName, commitID)
+          if (metaRangeURL == "") {
+            // a commit with no meta range is an empty commit.
+            // this only happens for the first commit in the repository.
+            None
+          } else {
+            val rangesReader = metarangeReaderGetter(conf, metaRangeURL, true)
+            rangesReader
+              .newIterator()
+              .map(rd => new Range(new String(rd.id), rd.message.estimatedSize))
+          }
+        })
+      })
+      .distinct
+
+    ranges.mapPartitions(ranges => {
+      val apiClient = ApiClient.get(apiConf)
+      val conf = serializedConf.value
+      ranges.flatMap((range: Range) => {
+        val path = new Path(apiClient.getRangeURL(repoName, range.id))
+        val fs = path.getFileSystem(conf)
+        val localFile = File.createTempFile("lakefs.", ".range")
+
+        fs.copyToLocalFile(false, path, new Path(localFile.getAbsolutePath), true)
+        val companion = Entry.messageCompanion
+        // localFile is owned by sstableReader, which will delete it when closed.
+        val sstableReader = new SSTableReader(localFile.getAbsolutePath, companion, true)
+        Option(TaskContext.get()).foreach(_.addTaskCompletionListener((tc: TaskContext) => {
+          try {
+            sstableReader.close()
+          } catch {
+            case e: Exception => {
+              logger.warn(s"failed to close SSTable reader for $localFile (continuing): $e")
+            }
+          }
+          tc
+        }))
+        // TODO(ariels): Do we need to validate that this reader is good? Assume _not_, this is
+        // not InputFormat code so it should have slightly nicer error reports.
+        sstableReader
+          .newIterator()
+          .map((entry) => (entry.key, new WithIdentifier(entry.id, entry.message, range.id)))
+      })
+    })
   }
 
   /** Returns all entries in all ranges of the given commit, as an RDD.
diff --git a/clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala b/clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala
index c8ad4f6ad96..a48e1d4e7cd 100644
--- a/clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala
+++ b/clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala
@@ -169,7 +169,10 @@ abstract class LakeFSBaseInputFormat extends InputFormat[Array[Byte], WithIdenti
     new EntryRecordReader(Entry.messageCompanion)
   }
 }
-private class Range(val id: String, val estimatedSize: Long) {
+
+class Range(val id: String, val estimatedSize: Long) extends Serializable {
+  // Non-private and Serializable so Spark can serialize it.
+
   override def hashCode(): Int = {
     id.hashCode()
   }
diff --git a/clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala b/clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala
index 1ef35e7e855..bf80d2821d0 100644
--- a/clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala
+++ b/clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala
@@ -9,6 +9,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.TaskContext
 import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
 
+import org.slf4j.{Logger, LoggerFactory}
 import java.io.{ByteArrayInputStream, Closeable, DataInputStream, File}
 
 class Item[T](val key: Array[Byte], val id: Array[Byte], val message: T)
@@ -101,6 +102,8 @@ class SSTableReader[Proto <: GeneratedMessage with scalapb.Message[Proto]] priva
     val companion: GeneratedMessageCompanion[Proto],
     val own: Boolean = true
 ) extends Closeable {
+  private val logger: Logger = LoggerFactory.getLogger(getClass.toString)
+
   private val fp = new java.io.RandomAccessFile(file, "r")
   private val reader = new BlockReadableFile(fp)
 
@@ -110,7 +113,13 @@
   def close(): Unit = {
     fp.close()
     if (own) {
-      file.delete()
+      try {
+        file.delete()
+      } catch {
+        case e: Exception => {
+          logger.warn(s"failed to delete owned file ${file.getName()} (continuing): $e")
+        }
+      }
     }
   }
diff --git a/clients/spark/src/main/scala/io/treeverse/gc/DataLister.scala b/clients/spark/src/main/scala/io/treeverse/gc/DataLister.scala
index fe23c9b646e..bd582738989 100644
--- a/clients/spark/src/main/scala/io/treeverse/gc/DataLister.scala
+++ b/clients/spark/src/main/scala/io/treeverse/gc/DataLister.scala
@@ -11,11 +11,11 @@ import scala.collection.mutable.ListBuffer
  */
 abstract class DataLister {
   @transient lazy val spark: SparkSession = SparkSession.active
-  def listData(configMapper: ConfigMapper, path: Path): DataFrame
+  def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame
 }
 
 class NaiveDataLister extends DataLister {
-  override def listData(configMapper: ConfigMapper, path: Path): DataFrame = {
+  override def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame = {
     import spark.implicits._
     val fs = path.getFileSystem(configMapper.configuration)
     val dataIt = fs.listFiles(path, false)
@@ -24,7 +24,7 @@ class NaiveDataLister extends DataLister {
       val fileStatus = dataIt.next()
       dataList += ((fileStatus.getPath.getName, fileStatus.getModificationTime))
     }
-    dataList.toDF("base_address", "last_modified")
+    dataList.toDF("base_address", "last_modified").repartition(parallelism)
   }
 }
 
@@ -47,7 +47,7 @@ class ParallelDataLister extends DataLister with Serializable {
     }
   }
 
-  override def listData(configMapper: ConfigMapper, path: Path): DataFrame = {
+  override def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame = {
     import spark.implicits._
     val slices = listPath(configMapper, path)
     val objectsPath = if (path.toString.endsWith("/")) path.toString else path.toString + "/"
@@ -63,6 +63,7 @@
       .map(_.path)
       .toSeq
       .toDF("slice_id")
+      .repartition(parallelism)
      .withColumn("udf", explode(objectsUDF(col("slice_id"))))
       .withColumn("base_address", concat(col("slice_id"), lit("/"), col("udf._1")))
       .withColumn("last_modified", col("udf._2"))
diff --git a/clients/spark/src/main/scala/io/treeverse/gc/GarbageCollection.scala b/clients/spark/src/main/scala/io/treeverse/gc/GarbageCollection.scala
index cad9c48316a..530a0f13faa 100644
--- a/clients/spark/src/main/scala/io/treeverse/gc/GarbageCollection.scala
+++ b/clients/spark/src/main/scala/io/treeverse/gc/GarbageCollection.scala
@@ -6,6 +6,7 @@ import org.apache.commons.lang3.time.DateUtils
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.storage.StorageLevel
 import org.json4s.JsonDSL._
 import org.json4s._
 import org.json4s.native.JsonMethods._
@@ -43,6 +44,8 @@
     val sc = spark.sparkContext
     val oldDataPath = new Path(storageNamespace)
     val dataPath = new Path(storageNamespace, DATA_PREFIX)
+    val parallelism =
+      sc.hadoopConfiguration.getInt(LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM, sc.defaultParallelism)
 
     val configMapper = new ConfigMapper(
       sc.broadcast(
@@ -54,7 +57,7 @@
         )
       )
     // Read objects from data path (new repository structure)
-    var dataDF = new ParallelDataLister().listData(configMapper, dataPath)
+    var dataDF = new ParallelDataLister().listData(configMapper, dataPath, parallelism)
     dataDF = dataDF
       .withColumn(
         "address",
@@ -65,7 +68,7 @@
 
     // TODO (niro): implement parallel lister for old repositories (https://github.com/treeverse/lakeFS/issues/4620)
     val oldDataDF = new NaiveDataLister()
-      .listData(configMapper, oldDataPath)
+      .listData(configMapper, oldDataPath, parallelism)
       .withColumn("address", col("base_address"))
       .filter(!col("address").isin(excludeFromOldData: _*))
     dataDF = dataDF.union(oldDataDF).filter(col("last_modified") < before.getTime)
@@ -195,7 +198,7 @@
       .repartition(dataDF.col("address"))
       .except(committedDF)
       .except(uncommittedDF)
-      .cache()
+      .persist(StorageLevel.MEMORY_AND_DISK)
 
     committedDF.unpersist()
     uncommittedDF.unpersist()
diff --git a/clients/spark/src/test/scala/io/treeverse/gc/DataListerSpec.scala b/clients/spark/src/test/scala/io/treeverse/gc/DataListerSpec.scala
index 7afed3ded25..c72139b83be 100644
--- a/clients/spark/src/test/scala/io/treeverse/gc/DataListerSpec.scala
+++ b/clients/spark/src/test/scala/io/treeverse/gc/DataListerSpec.scala
@@ -52,7 +52,7 @@ class ParallelDataListerSpec
         )
       )
       val df =
-        new ParallelDataLister().listData(configMapper, path).sort("base_address")
+        new ParallelDataLister().listData(configMapper, path, 3).sort("base_address")
       df.count should be(100)
 
       val slices = df.select(substring(col("base_address"), 0, 7).as("slice_id"))
@@ -81,7 +81,7 @@ class ParallelDataListerSpec
         )
       )
       val df =
-        new ParallelDataLister().listData(configMapper, path).sort("base_address")
+        new ParallelDataLister().listData(configMapper, path, 3).sort("base_address")
       df.count() should be(1)
       df.head.getString(0) should be(s"$sliceID/$filename")
     })
@@ -127,7 +127,7 @@ class NaiveDataListerSpec
       HadoopUtils.getHadoopConfigurationValues(spark.sparkContext.hadoopConfiguration)
       )
     )
-    val df = new NaiveDataLister().listData(configMapper, path).sort("base_address")
+    val df = new NaiveDataLister().listData(configMapper, path, 3).sort("base_address")
    df.count should be(10)
     df.sort("base_address").head.getString(0) should be("object01")
     df.head.getString(0) should be("object01")
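
Note on the new lakefs.job.range_read_parallelism key (a caller-side sketch, not part of the patch): LakeFSContext.newRDD and GarbageCollection read it from the job's Hadoop configuration and fall back to sc.defaultParallelism when it is unset. A minimal, hypothetical tuning example, assuming a live SparkContext named sc:

    import io.treeverse.clients.LakeFSContext

    // Use 64 partitions when parallelizing commits into ranges and when listing data for GC.
    // Leaving the key unset keeps the default of sc.defaultParallelism.
    sc.hadoopConfiguration.setInt(LakeFSContext.LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM, 64)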