@@ -92,33 +92,51 @@ private static DataFileStatistics constructFileStats(
.collect(toImmutableMap(
identity(),
key -> mergeMetadataList(metadataForColumn.get(key))));

Map<Column, Literal> minValues = new HashMap<>();
Map<Column, Literal> maxValues = new HashMap<>();
Map<Column, Long> nullCounts = new HashMap<>();
for (Column statsColumn : statsColumns) {
Optional<Statistics<?>> stats = statsForColumn.get(statsColumn);
DataType columnType = getDataType(dataSchema, statsColumn);
if (stats == null || !stats.isPresent() || !isStatsSupportedDataType(columnType)) {

Optional<Statistics<?>> stats;
if (columnType instanceof VariantType) {
// Hack: Parquet stores Variant types as a struct with two binary child fields so
// the stats are also only stored on the child fields. Because the value and
// metadata fields are null iff the variant is null, the null count stat is
// retrieved by inspecting the "value" field.
int variantColNameLength = statsColumn.getNames().length;
String[] variantColNameList =
Arrays.copyOf(statsColumn.getNames(), variantColNameLength + 1);
variantColNameList[variantColNameLength] = "value";
stats = statsForColumn.get(new Column(variantColNameList));
} else {
stats = statsForColumn.get(statsColumn);
}
if (stats == null || !stats.isPresent()) {
continue;
}
Statistics<?> statistics = stats.get();

Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null;
nullCounts.put(statsColumn, numNulls);

if (numNulls != null && rowCount == numNulls) {
// If all values are null, then min and max are also null
minValues.put(statsColumn, Literal.ofNull(columnType));
maxValues.put(statsColumn, Literal.ofNull(columnType));
continue;
if (isNullStatSupportedDataType(columnType)) {
Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null;
nullCounts.put(statsColumn, numNulls);
}

Literal minValue = decodeMinMaxStat(columnType, statistics, true /* decodeMin */);
minValues.put(statsColumn, minValue);
if (isMinMaxStatSupportedDataType(columnType)) {
Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null;
if (numNulls != null && rowCount == numNulls) {
// If all values are null, then min and max are also null
minValues.put(statsColumn, Literal.ofNull(columnType));
maxValues.put(statsColumn, Literal.ofNull(columnType));
continue;
}

Literal minValue = decodeMinMaxStat(columnType, statistics, true /* decodeMin */);
minValues.put(statsColumn, minValue);

Literal maxValue = decodeMinMaxStat(columnType, statistics, false /* decodeMin */);
maxValues.put(statsColumn, maxValue);
Literal maxValue = decodeMinMaxStat(columnType, statistics, false /* decodeMin */);
maxValues.put(statsColumn, maxValue);
}
}

return new DataFileStatistics(rowCount, minValues, maxValues, nullCounts);
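
For context, a minimal Scala sketch of the variant stats remapping described in the comment above. This is illustrative only: the values and names here are hypothetical, and only Column, getNames(), and the "value" child-field convention come from this change.

import io.delta.kernel.expressions.Column

// Illustrative only: the stats lookup key for a variant column gains a trailing
// "value" name part, mirroring the Arrays.copyOf logic in constructFileStats.
val statsColumn = new Column(Array("struct_v", "v"))               // logical variant column
val variantStatsKey = new Column(statsColumn.getNames :+ "value")  // physical stats column
// statsForColumn is then queried with variantStatsKey, and its numNulls is
// reported as the variant column's null count.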
@@ -220,7 +238,7 @@ private static boolean hasInvalidStatistics(Collection<ColumnChunkMetaData> meta
});
}

private static boolean isStatsSupportedDataType(DataType dataType) {
public static boolean isMinMaxStatSupportedDataType(DataType dataType) {
return dataType instanceof BooleanType ||
dataType instanceof ByteType ||
dataType instanceof ShortType ||
@@ -236,6 +254,22 @@ private static boolean isStatsSupportedDataType(DataType dataType) {
// Add support later.
}

public static boolean isNullStatSupportedDataType(DataType dataType) {
return dataType instanceof BooleanType ||
dataType instanceof ByteType ||
dataType instanceof ShortType ||
dataType instanceof IntegerType ||
dataType instanceof LongType ||
dataType instanceof FloatType ||
dataType instanceof DoubleType ||
dataType instanceof DecimalType ||
dataType instanceof DateType ||
dataType instanceof StringType ||
dataType instanceof BinaryType ||
dataType instanceof VariantType;
// TODO: add timestamp support later.
}

private static byte[] getBinaryStat(Statistics<?> statistics, boolean decodeMin) {
return decodeMin ? statistics.getMinBytes() : statistics.getMaxBytes();
}
@@ -1386,6 +1386,44 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with
}
}

test("data skipping - basic variant type") {
withTable("test_table") {
spark.range(0, 10, 1, 1)
.selectExpr(
"parse_json(cast(id as string)) as basic_v",
"cast(null as variant) as basic_null_v",
"named_struct('v', parse_json(cast(id as string))) as basic_struct_v",
"named_struct('v', cast(null as variant)) as basic_struct_null_v",
)
.write
.format("delta")
.mode("overwrite")
.saveAsTable("test_table")

val filePath = spark.sql("describe table extended `test_table`")
.where("col_name = 'Location'")
.collect()(0)
.getString(1)
.replace("file:", "")

checkSkipping(
filePath,
hits = Seq(
isNotNull(col("basic_v")),
not(isNotNull(col("basic_null_v"))),
isNotNull(nestedCol("basic_struct_v.v")),
not(isNotNull(nestedCol("basic_struct_null_v.v")))
),
misses = Seq(
not(isNotNull(col("basic_v"))),
isNotNull(col("basic_null_v")),
not(isNotNull(nestedCol("basic_struct_v.v"))),
isNotNull(nestedCol("basic_struct_null_v.v"))
)
)
}
}
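
To make the skipping decision this test exercises concrete, a small Scala sketch follows; the helper is hypothetical and not part of the PR. A file can be pruned for an IS NOT NULL predicate exactly when its null count equals its row count, which is the statistic the variant columns above now provide.

// Illustrative only: skipping decision for `col IS NOT NULL` from per-file stats.
def canSkipIsNotNull(rowCount: Long, nullCount: Option[Long]): Boolean =
  nullCount.contains(rowCount) // every value is null => no row can satisfy IS NOT NULL

assert(canSkipIsNotNull(rowCount = 10, nullCount = Some(10))) // basic_null_v: nullCount == rowCount
assert(!canSkipIsNotNull(rowCount = 10, nullCount = Some(0))) // basic_v: nullCount == 0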

test("data skipping - is not null with DVs in file with non-nulls") {
withSQLConf(("spark.databricks.delta.properties.defaults.enableDeletionVectors", "true")) {
withTempDir { tempDir =>
@@ -16,8 +16,12 @@
package io.delta.kernel.defaults.internal.parquet

import java.lang.{Double => DoubleJ, Float => FloatJ}
import java.util.ArrayList
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{Row => SparkRow}

import io.delta.golden.GoldenTableUtils.{goldenTableFile, goldenTablePath}
import io.delta.kernel.data.{ColumnarBatch, FilteredColumnarBatch}
@@ -191,7 +195,9 @@ class ParquetFileWriterSuite extends AnyFunSuite
}
}

def testWrite(testName: String)(df: => DataFrame): Unit = {
def testWrite(
testName: String,
statsColsAndKernelType: Seq[(Column, DataType)])(df: => DataFrame): Unit = {
test(testName) {
withTable("test_table") {
withTempDir { writeDir =>
@@ -215,14 +221,68 @@
val readData = readParquetUsingKernelAsColumnarBatches(filePath, physicalSchema)
.map(_.toFiltered(Option.empty[Predicate]))
val writePath = writeDir.getAbsolutePath
val writeOutput = writeToParquetUsingKernel(readData, writePath)
val writeOutput = writeToParquetUsingKernel(
readData,
writePath,
statsColumns = statsColsAndKernelType.map(_._1)
)
verifyContentUsingKernelReader(writePath, readData)

val numMinMaxStatCols = statsColsAndKernelType.count { case (_, dt) =>
ParquetStatsReader.isMinMaxStatSupportedDataType(dt)
}
val numNullStatCols = statsColsAndKernelType.count { case (_, dt) =>
ParquetStatsReader.isNullStatSupportedDataType(dt)
}
verifyVariantStatsUsingKernel(
writePath,
writeOutput,
schema,
statsColsAndKernelType,
numMinMaxStatCols,
numNullStatCols
)
}
}
}
}

testWrite("basic write variant") {
testWrite(
"basic write verify stats with kernel",
Seq(
(new Column("boolcol"), BooleanType.BOOLEAN),
(new Column("bytecol"), ByteType.BYTE),
(new Column("intcol"), IntegerType.INTEGER),
(new Column("longcol"), LongType.LONG),
(new Column("shortcol"), ShortType.SHORT),
(new Column("floatcol"), FloatType.FLOAT),
(new Column("doublecol"), DoubleType.DOUBLE),
(new Column("stringcol"), StringType.STRING),
(new Column("binarycol"), BinaryType.BINARY),
(new Column("decimalcol"), DecimalType.USER_DEFAULT),
(new Column("datecol"), DateType.DATE)
)) {
spark.range(0, 100, 1, 1).selectExpr(
"cast(id as boolean) as boolcol",
"cast(id as byte) as bytecol",
"cast(5 as int) as intcol",
"cast(id as long) as longcol",
"cast(id as short) as shortcol",
"cast(id as float) as floatcol",
"cast(id as double) as doublecol",
"cast(id as string) as stringcol",
"cast(id as binary) as binarycol",
"cast(id as decimal(10, 0)) as decimalcol",
"current_date() as datecol"
)
}

testWrite(
"basic write variant",
Seq(
(new Column("basic_v"), VariantType.VARIANT),
(new Column(Array("struct_v", "v")), VariantType.VARIANT)
)) {
spark.range(0, 10, 1, 1).selectExpr(
"parse_json(cast(id as string)) as basic_v",
"named_struct('v', parse_json(cast(id as string))) as struct_v",
@@ -236,7 +296,12 @@
)
}

testWrite("basic write null variant") {
testWrite(
"basic write null variant",
Seq(
(new Column("basic_v"), VariantType.VARIANT),
(new Column(Array("struct_v", "v")), VariantType.VARIANT)
)) {
spark.range(0, 10, 1, 1).selectExpr(
"cast(null as variant) basic_v",
"named_struct('v', cast(null as variant)) as struct_v",
@@ -372,6 +437,89 @@
}
}

def verifyVariantStatsUsingKernel(
actualFileDir: String,
actualFileStatuses: Seq[DataFileStatus],
schema: StructType,
statsColumnsAndTypes: Seq[(Column, DataType)],
expMinMaxStatsCount: Int,
expNullCountStatsCount: Int): Unit = {
if (statsColumnsAndTypes.isEmpty) return

val statFields = ArrayBuffer(
new StructField("location", StringType.STRING, true),
new StructField("fileSize", LongType.LONG, true),
new StructField("lastModifiedTime", LongType.LONG, true),
new StructField("rowCount", LongType.LONG, true)
)

val statsColumns = statsColumnsAndTypes.map { _._1 }
statsColumnsAndTypes.foreach { case (column, dt) =>
val colName = column.getNames().toSeq.mkString("__")
statFields ++= ArrayBuffer(
new StructField(s"min_$colName", dt, true),
new StructField(s"max_$colName", dt, true),
new StructField(s"nullCount_$colName", LongType.LONG, true)
)
}

val kernelStatsSchema = new StructType(new ArrayList(statFields.asJava))

val actualStatsOutput = actualFileStatuses
.map { fileStatus =>
// Validate that the expected number of stats columns is present.
assert(fileStatus.getStatistics.isPresent)
assert(fileStatus.getStatistics.get().getMinValues.size() === expMinMaxStatsCount)
assert(fileStatus.getStatistics.get().getMaxValues.size() === expMinMaxStatsCount)
assert(fileStatus.getStatistics.get().getNullCounts.size() === expNullCountStatsCount)

// Convert to a Spark row for comparison with the actual values computed using Spark.
TestRow.toSparkRow(
fileStatus.toTestRow(statsColumns),
kernelStatsSchema)
}

spark.createDataFrame(
spark.sparkContext.parallelize(actualStatsOutput),
kernelStatsSchema.toSpark).createOrReplaceTempView("stats_table")

val readRows = readParquetFilesUsingKernel(actualFileDir, schema)
val sparkRows = readRows.map { row =>
TestRow.toSparkRow(row, schema)
}
val readDf = spark.createDataFrame(spark.sparkContext.parallelize(sparkRows), schema.toSpark)
readDf.createOrReplaceTempView("data_table")
statsColumnsAndTypes.foreach { case (col, dt) =>
val realColName = col.getNames().mkString(".")
val statsColNameSuffix = col.getNames().mkString("__")

if (ParquetStatsReader.isNullStatSupportedDataType(dt)) {
val actualNullCount = readDf.where(s"$realColName is null").count()
val statsNullCount = spark
.sql(s"select nullCount_$statsColNameSuffix from stats_table")
.collect()(0)
.getLong(0)
assert(actualNullCount == statsNullCount)
}

if (ParquetStatsReader.isMinMaxStatSupportedDataType(dt)) {
spark.sql(s"""select min($realColName) as minVal, max($realColName) as maxVal
from data_table""").createOrReplaceTempView("min_max_table")

// Verify that the real min/max values are equal to the reported min/max values.
// The join returns a row only when the reported min equals the actual min.
val minAssertionDf = spark.sql(s"""
select minVal from min_max_table d, stats_table s
where d.minVal = s.min_$statsColNameSuffix""")
assert(minAssertionDf.count() == 1)

// Likewise, a row comes back only when the reported max equals the actual max.
val maxAssertionDf = spark.sql(s"""
select maxVal from min_max_table d, stats_table s
where d.maxVal = s.max_$statsColNameSuffix""")
assert(maxAssertionDf.count() == 1)
}
}
}

def verifyStatsUsingSpark(
actualFileDir: String,
actualFileStatuses: Seq[DataFileStatus],