@@ -2,13 +2,16 @@ package io.treeverse.clients

import io.treeverse.lakefs.catalog.Entry
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.InvalidJobConfException
import org.apache.spark.SparkContext
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.util.SerializableConfiguration
import org.slf4j.{Logger, LoggerFactory}

import java.io.File
import java.util.concurrent.TimeUnit

object LakeFSJobParams {
@@ -61,6 +64,8 @@ class LakeFSJobParams private (
}

object LakeFSContext {
private val logger: Logger = LoggerFactory.getLogger(getClass.toString)

val LAKEFS_CONF_API_URL_KEY = "lakefs.api.url"
val LAKEFS_CONF_API_ACCESS_KEY_KEY = "lakefs.api.access_key"
val LAKEFS_CONF_API_SECRET_KEY_KEY = "lakefs.api.secret_key"
@@ -71,6 +76,9 @@ object LakeFSContext
val LAKEFS_CONF_JOB_COMMIT_IDS_KEY = "lakefs.job.commit_ids"
val LAKEFS_CONF_JOB_SOURCE_NAME_KEY = "lakefs.job.source_name"

// Read parallelism. Defaults to default parallelism.
val LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM = "lakefs.job.range_read_parallelism"

val LAKEFS_CONF_GC_NUM_COMMIT_PARTITIONS = "lakefs.gc.commit.num_partitions"
val LAKEFS_CONF_GC_NUM_RANGE_PARTITIONS = "lakefs.gc.range.num_partitions"
val LAKEFS_CONF_GC_NUM_ADDRESS_PARTITIONS = "lakefs.gc.address.num_partitions"
@@ -108,15 +116,14 @@ object LakeFSContext
val DEFAULT_LAKEFS_CONF_GC_S3_MIN_BACKOFF_SECONDS = 1
val DEFAULT_LAKEFS_CONF_GC_S3_MAX_BACKOFF_SECONDS = 120

val metarangeReaderGetter = SSTableReader.forMetaRange _

def newRDD(
sc: SparkContext,
params: LakeFSJobParams
): RDD[(Array[Byte], WithIdentifier[Entry])] = {
val inputFormatClass =
if (params.commitIDs.nonEmpty) classOf[LakeFSCommitInputFormat]
else classOf[LakeFSAllRangesInputFormat]
val conf = sc.hadoopConfiguration

val conf = new Configuration(sc.hadoopConfiguration)
conf.set(LAKEFS_CONF_JOB_REPO_NAME_KEY, params.repoName)
conf.setStrings(LAKEFS_CONF_JOB_COMMIT_IDS_KEY, params.commitIDs.toArray: _*)

@@ -131,12 +138,77 @@ object LakeFSContext
throw new InvalidJobConfException(s"$LAKEFS_CONF_API_SECRET_KEY_KEY must not be empty")
}
conf.set(LAKEFS_CONF_JOB_SOURCE_NAME_KEY, params.sourceName)
sc.newAPIHadoopRDD(
conf,
inputFormatClass,
classOf[Array[Byte]],
classOf[WithIdentifier[Entry]]

val apiConf = APIConfigurations(
conf.get(LAKEFS_CONF_API_URL_KEY),
conf.get(LAKEFS_CONF_API_ACCESS_KEY_KEY),
conf.get(LAKEFS_CONF_API_SECRET_KEY_KEY),
conf.get(LAKEFS_CONF_API_CONNECTION_TIMEOUT_SEC_KEY),
conf.get(LAKEFS_CONF_API_READ_TIMEOUT_SEC_KEY),
conf.get(LAKEFS_CONF_JOB_SOURCE_NAME_KEY, "input_format")
)
val repoName = conf.get(LAKEFS_CONF_JOB_REPO_NAME_KEY)

// This can go to executors.
val serializedConf = new SerializableConfiguration(conf)

val parallelism = conf.getInt(LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM, sc.defaultParallelism)

// ApiClient is not serializable, so create a new one for each partition on its executor.
// (If we called X.flatMap directly, we would fetch the client from the cache for each
// range, which is a bit too much.)

// TODO(ariels): Unify with similar code in LakeFSInputFormat.getSplits
val ranges = sc
.parallelize(params.commitIDs.toSeq, parallelism)
.mapPartitions(commits => {
val apiClient = ApiClient.get(apiConf)
val conf = serializedConf.value
commits.flatMap(commitID => {
val metaRangeURL = apiClient.getMetaRangeURL(repoName, commitID)
if (metaRangeURL == "") {
// a commit with no meta range is an empty commit.
// this only happens for the first commit in the repository.
None
} else {
val rangesReader = metarangeReaderGetter(conf, metaRangeURL, true)
rangesReader
.newIterator()
.map(rd => new Range(new String(rd.id), rd.message.estimatedSize))
}
})
})
.distinct

ranges.mapPartitions(ranges => {
val apiClient = ApiClient.get(apiConf)
val conf = serializedConf.value
ranges.flatMap((range: Range) => {
val path = new Path(apiClient.getRangeURL(repoName, range.id))
val fs = path.getFileSystem(conf)
val localFile = File.createTempFile("lakefs.", ".range")

fs.copyToLocalFile(false, path, new Path(localFile.getAbsolutePath), true)
val companion = Entry.messageCompanion
// localFile owned by sstableReader which will delete it when closed.
val sstableReader = new SSTableReader(localFile.getAbsolutePath, companion, true)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener((tc: TaskContext) => {
try {
sstableReader.close()
} catch {
case e: Exception => {
logger.warn(s"close SSTable reader for $localFile (keep going): $e")
}
}
tc
}))
// TODO(ariels): Do we need to validate that this reader is good? Assume _not_, this is
// not InputFormat code so it should have slightly nicer error reports.
sstableReader
.newIterator()
.map((entry) => (entry.key, new WithIdentifier(entry.id, entry.message, range.id)))
})
})
}

/** Returns all entries in all ranges of the given commit, as an RDD.
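A note on the `newRDD` hunk above: it relies on a common Spark pattern, namely shipping only a `SerializableConfiguration` to executors and constructing the non-serializable `ApiClient` once per partition inside `mapPartitions`, with the read parallelism taken from `lakefs.job.range_read_parallelism`. The sketch below illustrates that pattern in isolation; `ExampleClient`, the `example.*` keys, and the endpoint URL are hypothetical stand-ins, not part of this PR.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.SerializableConfiguration

// Hypothetical stand-in for a non-serializable, per-executor client such as ApiClient.
class ExampleClient(endpoint: String) {
  def fetch(id: String): String = s"$endpoint/$id" // placeholder for a real API call
}

object PerPartitionClientSketch {
  def main(args: Array[String]): Unit = {
    val spark =
      SparkSession.builder().master("local[*]").appName("per-partition-client").getOrCreate()
    val sc = spark.sparkContext

    // Mirrors lakefs.job.range_read_parallelism: bound the read parallelism,
    // falling back to Spark's default parallelism when the key is unset.
    val parallelism =
      sc.hadoopConfiguration.getInt("example.range_read_parallelism", sc.defaultParallelism)

    // Copy the Hadoop configuration so job-local settings do not leak back into the
    // shared SparkContext configuration, then wrap it for shipping to executors.
    val conf = new Configuration(sc.hadoopConfiguration)
    conf.set("example.endpoint", "https://example.invalid/api")
    val serializedConf = new SerializableConfiguration(conf)

    val commitIds = sc.parallelize(Seq("c1", "c2", "c3", "c4"), parallelism)
    val results = commitIds.mapPartitions { ids =>
      // One client per partition, created on the executor; the client itself is never serialized.
      val executorConf = serializedConf.value
      val client = new ExampleClient(executorConf.get("example.endpoint"))
      ids.map(client.fetch)
    }
    results.collect().foreach(println)
    spark.stop()
  }
}
```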
@@ -169,7 +169,10 @@ abstract class LakeFSBaseInputFormat extends InputFormat[Array[Byte], WithIdenti
new EntryRecordReader(Entry.messageCompanion)
}
}
private class Range(val id: String, val estimatedSize: Long) {

class Range(val id: String, val estimatedSize: Long) extends Serializable {
// non-private so Spark will serialize it.

override def hashCode(): Int = {
id.hashCode()
}
@@ -9,6 +9,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.TaskContext
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}

import org.slf4j.{Logger, LoggerFactory}
import java.io.{ByteArrayInputStream, Closeable, DataInputStream, File}

class Item[T](val key: Array[Byte], val id: Array[Byte], val message: T)
@@ -101,6 +102,8 @@ class SSTableReader[Proto <: GeneratedMessage with scalapb.Message[Proto]] priva
val companion: GeneratedMessageCompanion[Proto],
val own: Boolean = true
) extends Closeable {
private val logger: Logger = LoggerFactory.getLogger(getClass.toString)

private val fp = new java.io.RandomAccessFile(file, "r")
private val reader = new BlockReadableFile(fp)

@@ -110,7 +113,13 @@
def close(): Unit = {
fp.close()
if (own) {
file.delete()
try {
file.delete()
} catch {
case e: Exception => {
logger.warn(s"delete owned file ${file.getName()} (keep going): $e")
}
}
}
}

9 changes: 5 additions & 4 deletions clients/spark/src/main/scala/io/treeverse/gc/DataLister.scala
@@ -11,11 +11,11 @@ import scala.collection.mutable.ListBuffer
*/
abstract class DataLister {
@transient lazy val spark: SparkSession = SparkSession.active
def listData(configMapper: ConfigMapper, path: Path): DataFrame
def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame
}

class NaiveDataLister extends DataLister {
override def listData(configMapper: ConfigMapper, path: Path): DataFrame = {
override def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame = {
import spark.implicits._
val fs = path.getFileSystem(configMapper.configuration)
val dataIt = fs.listFiles(path, false)
@@ -24,7 +24,7 @@ class NaiveDataLister extends DataLister {
val fileStatus = dataIt.next()
dataList += ((fileStatus.getPath.getName, fileStatus.getModificationTime))
}
dataList.toDF("base_address", "last_modified")
dataList.toDF("base_address", "last_modified").repartition(parallelism)
}
}

@@ -47,7 +47,7 @@ class ParallelDataLister extends DataLister with Serializable {
}
}

override def listData(configMapper: ConfigMapper, path: Path): DataFrame = {
override def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame = {
import spark.implicits._
val slices = listPath(configMapper, path)
val objectsPath = if (path.toString.endsWith("/")) path.toString else path.toString + "/"
@@ -63,6 +63,7 @@
.map(_.path)
.toSeq
.toDF("slice_id")
.repartition(parallelism)
.withColumn("udf", explode(objectsUDF(col("slice_id"))))
.withColumn("base_address", concat(col("slice_id"), lit("/"), col("udf._1")))
.withColumn("last_modified", col("udf._2"))
@@ -6,6 +6,7 @@ import org.apache.commons.lang3.time.DateUtils
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.json4s.JsonDSL._
import org.json4s._
import org.json4s.native.JsonMethods._
@@ -43,6 +44,8 @@ object GarbageCollection {
val sc = spark.sparkContext
val oldDataPath = new Path(storageNamespace)
val dataPath = new Path(storageNamespace, DATA_PREFIX)
val parallelism =
sc.hadoopConfiguration.getInt(LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM, sc.defaultParallelism)

val configMapper = new ConfigMapper(
sc.broadcast(
@@ -54,7 +57,7 @@
)
)
// Read objects from data path (new repository structure)
var dataDF = new ParallelDataLister().listData(configMapper, dataPath)
var dataDF = new ParallelDataLister().listData(configMapper, dataPath, parallelism)
dataDF = dataDF
.withColumn(
"address",
@@ -65,7 +68,7 @@

// TODO (niro): implement parallel lister for old repositories (https://github.com/treeverse/lakeFS/issues/4620)
val oldDataDF = new NaiveDataLister()
.listData(configMapper, oldDataPath)
.listData(configMapper, oldDataPath, parallelism)
.withColumn("address", col("base_address"))
.filter(!col("address").isin(excludeFromOldData: _*))
dataDF = dataDF.union(oldDataDF).filter(col("last_modified") < before.getTime)
@@ -195,7 +198,7 @@ object GarbageCollection {
.repartition(dataDF.col("address"))
.except(committedDF)
.except(uncommittedDF)
.cache()
.persist(StorageLevel.MEMORY_AND_DISK)
Member:

Can you explain the difference between persist and cache?

Contributor Author:

Sure. (Linking to PySpark docs for a new version because it's easiest to find these online. But it's been like this for... ever.) Man says:

> Persist this RDD with the default storage level (MEMORY_ONLY).

So it only works for small RDDs. But we mostly care about large RDDs.

Personally I think that if you have "persist", and you name a shortcut to it "cache", then you have $\ge 1$ naming problems.
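
To make the distinction concrete, here is a minimal RDD-level sketch (the RDD API is what the quoted documentation describes; the data and names below are illustrative only, not part of this PR):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

object PersistVsCacheSketch {
  def main(args: Array[String]): Unit = {
    val spark =
      SparkSession.builder().master("local[*]").appName("persist-vs-cache").getOrCreate()
    val sc = spark.sparkContext

    val addresses = sc.parallelize(1L to 1000000L).map(i => s"object$i")

    // For RDDs, cache() is shorthand for persist(StorageLevel.MEMORY_ONLY): partitions
    // that do not fit in executor memory are dropped and recomputed from the lineage
    // every time the RDD is reused.
    val cached = addresses.cache()

    // persist(MEMORY_AND_DISK) spills partitions that do not fit in memory to local
    // disk instead of recomputing them, which is the safer choice for large datasets.
    val spilled = addresses.map(_.toUpperCase).persist(StorageLevel.MEMORY_AND_DISK)

    println(cached.count())
    println(spilled.count())

    // Release storage explicitly once downstream consumers are done, as the change
    // above does with committedDF.unpersist() and uncommittedDF.unpersist().
    cached.unpersist()
    spilled.unpersist()
    spark.stop()
  }
}
```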


committedDF.unpersist()
uncommittedDF.unpersist()
@@ -52,7 +52,7 @@ class ParallelDataListerSpec
)
)
val df =
new ParallelDataLister().listData(configMapper, path).sort("base_address")
new ParallelDataLister().listData(configMapper, path, 3).sort("base_address")
df.count should be(100)
val slices =
df.select(substring(col("base_address"), 0, 7).as("slice_id"))
@@ -81,7 +81,7 @@
)
)
val df =
new ParallelDataLister().listData(configMapper, path).sort("base_address")
new ParallelDataLister().listData(configMapper, path, 3).sort("base_address")
df.count() should be(1)
df.head.getString(0) should be(s"$sliceID/$filename")
})
@@ -127,7 +127,7 @@ class NaiveDataListerSpec
HadoopUtils.getHadoopConfigurationValues(spark.sparkContext.hadoopConfiguration)
)
)
val df = new NaiveDataLister().listData(configMapper, path).sort("base_address")
val df = new NaiveDataLister().listData(configMapper, path, 3).sort("base_address")
df.count should be(10)
df.sort("base_address").head.getString(0) should be("object01")
df.head.getString(0) should be("object01")