diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java index 1e982ad25c2..4c6a8c7a451 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java @@ -92,33 +92,51 @@ private static DataFileStatistics constructFileStats( .collect(toImmutableMap( identity(), key -> mergeMetadataList(metadataForColumn.get(key)))); - Map minValues = new HashMap<>(); Map maxValues = new HashMap<>(); Map nullCounts = new HashMap<>(); for (Column statsColumn : statsColumns) { - Optional> stats = statsForColumn.get(statsColumn); DataType columnType = getDataType(dataSchema, statsColumn); - if (stats == null || !stats.isPresent() || !isStatsSupportedDataType(columnType)) { + + Optional> stats; + if (columnType instanceof VariantType) { + // Hack: Parquet stores Variant types as a struct with two binary child fields so + // the stats are also only stored on the child fields. Because the value and + // metadata fields are null iff the variant is null, the null count stat is + // retrieved by inspecting the "value" field. + int variantColNameLength = statsColumn.getNames().length; + String[] variantColNameList = + Arrays.copyOf(statsColumn.getNames(), variantColNameLength + 1); + variantColNameList[variantColNameLength] = "value"; + stats = statsForColumn.get(new Column(variantColNameList)); + } else { + stats = statsForColumn.get(statsColumn); + } + if (stats == null || !stats.isPresent()) { continue; } Statistics statistics = stats.get(); - Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null; - nullCounts.put(statsColumn, numNulls); - - if (numNulls != null && rowCount == numNulls) { - // If all values are null, then min and max are also null - minValues.put(statsColumn, Literal.ofNull(columnType)); - maxValues.put(statsColumn, Literal.ofNull(columnType)); - continue; + if (isNullStatSupportedDataType(columnType)) { + Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null; + nullCounts.put(statsColumn, numNulls); } - Literal minValue = decodeMinMaxStat(columnType, statistics, true /* decodeMin */); - minValues.put(statsColumn, minValue); + if (isMinMaxStatSupportedDataType(columnType)) { + Long numNulls = statistics.isNumNullsSet() ? 
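A minimal sketch of the Variant stats-lookup redirection introduced above (the `Column` import path is assumed; the class, its `getNames()` accessor, and the `String[]` constructor are taken from the hunk itself). Parquet stores a Variant as a struct with binary value and metadata child fields, and the "value" child is null exactly when the variant is null, so the null-count stat is looked up on that child rather than on the Variant column itself.

```scala
import io.delta.kernel.expressions.Column

// Redirect a Variant column's stats lookup to its "value" child field,
// mirroring the Arrays.copyOf-based append in ParquetStatsReader.
def variantStatsLookupColumn(statsColumn: Column): Column =
  new Column(statsColumn.getNames :+ "value")

// e.g. a top-level column `v` maps to `v.value`,
//      a nested column `struct_v.v` maps to `struct_v.v.value`.
```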
statistics.getNumNulls() : null; + if (numNulls != null && rowCount == numNulls) { + // If all values are null, then min and max are also null + minValues.put(statsColumn, Literal.ofNull(columnType)); + maxValues.put(statsColumn, Literal.ofNull(columnType)); + continue; + } + + Literal minValue = decodeMinMaxStat(columnType, statistics, true /* decodeMin */); + minValues.put(statsColumn, minValue); - Literal maxValue = decodeMinMaxStat(columnType, statistics, false /* decodeMin */); - maxValues.put(statsColumn, maxValue); + Literal maxValue = decodeMinMaxStat(columnType, statistics, false /* decodeMin */); + maxValues.put(statsColumn, maxValue); + } } return new DataFileStatistics(rowCount, minValues, maxValues, nullCounts); @@ -220,7 +238,7 @@ private static boolean hasInvalidStatistics(Collection meta }); } - private static boolean isStatsSupportedDataType(DataType dataType) { + public static boolean isMinMaxStatSupportedDataType(DataType dataType) { return dataType instanceof BooleanType || dataType instanceof ByteType || dataType instanceof ShortType || @@ -236,6 +254,22 @@ private static boolean isStatsSupportedDataType(DataType dataType) { // Add support later. } + public static boolean isNullStatSupportedDataType(DataType dataType) { + return dataType instanceof BooleanType || + dataType instanceof ByteType || + dataType instanceof ShortType || + dataType instanceof IntegerType || + dataType instanceof LongType || + dataType instanceof FloatType || + dataType instanceof DoubleType || + dataType instanceof DecimalType || + dataType instanceof DateType || + dataType instanceof StringType || + dataType instanceof BinaryType || + dataType instanceof VariantType; + // TODO: add timestamp support later. + } + private static byte[] getBinaryStat(Statistics statistics, boolean decodeMin) { return decodeMin ? 
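A hedged usage sketch of the two predicates the hunk above makes public: a caller can now ask, per Kernel data type, whether min/max values and/or a null count will be collected. The method names and package come from the diff; the specific types shown are only illustrative.

```scala
import io.delta.kernel.defaults.internal.parquet.ParquetStatsReader
import io.delta.kernel.types.{DataType, IntegerType, VariantType}

// Report which stats categories a column type participates in.
def describeStatsSupport(dt: DataType): String = {
  val minMax = ParquetStatsReader.isMinMaxStatSupportedDataType(dt)
  val nulls  = ParquetStatsReader.isNullStatSupportedDataType(dt)
  s"minMax=$minMax, nullCount=$nulls"
}

// describeStatsSupport(IntegerType.INTEGER)  // appears in both lists
// describeStatsSupport(VariantType.VARIANT)  // appears in the new null-count list above
```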
statistics.getMinBytes() : statistics.getMaxBytes(); } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala index e08adaadbde..1ba20d2b91a 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala @@ -1386,6 +1386,44 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with } } + test("data skipping - basic variant type") { + withTable("test_table") { + spark.range(0, 10, 1, 1) + .selectExpr( + "parse_json(cast(id as string)) as basic_v", + "cast(null as variant) as basic_null_v", + "named_struct('v', parse_json(cast(id as string))) as basic_struct_v", + "named_struct('v', cast(null as variant)) as basic_struct_null_v", + ) + .write + .format("delta") + .mode("overwrite") + .saveAsTable("test_table") + + val filePath = spark.sql("describe table extended `test_table`") + .where("col_name = 'Location'") + .collect()(0) + .getString(1) + .replace("file:", "") + + checkSkipping( + filePath, + hits = Seq( + isNotNull(col("basic_v")), + not(isNotNull(col("basic_null_v"))), + isNotNull(nestedCol("basic_struct_v.v")), + not(isNotNull(nestedCol("basic_struct_null_v.v"))) + ), + misses = Seq( + not(isNotNull(col("basic_v"))), + isNotNull(col("basic_null_v")), + not(isNotNull(nestedCol("basic_struct_v.v"))), + isNotNull(nestedCol("basic_struct_null_v.v")) + ) + ) + } + } + test("data skipping - is not null with DVs in file with non-nulls") { withSQLConf(("spark.databricks.delta.properties.defaults.enableDeletionVectors", "true")) { withTempDir { tempDir => diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala index d6d5ccb10cf..f912fbd78a9 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala @@ -16,8 +16,12 @@ package io.delta.kernel.defaults.internal.parquet import java.lang.{Double => DoubleJ, Float => FloatJ} +import java.util.ArrayList +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row => SparkRow} import io.delta.golden.GoldenTableUtils.{goldenTableFile, goldenTablePath} import io.delta.kernel.data.{ColumnarBatch, FilteredColumnarBatch} @@ -191,7 +195,9 @@ class ParquetFileWriterSuite extends AnyFunSuite } } - def testWrite(testName: String)(df: => DataFrame): Unit = { + def testWrite( + testName: String, + statsColsAndKernelType: Seq[(Column, DataType)])(df: => DataFrame): Unit = { test(testName) { withTable("test_table") { withTempDir { writeDir => @@ -215,14 +221,68 @@ class ParquetFileWriterSuite extends AnyFunSuite val readData = readParquetUsingKernelAsColumnarBatches(filePath, physicalSchema) .map(_.toFiltered(Option.empty[Predicate])) val writePath = writeDir.getAbsolutePath - val writeOutput = writeToParquetUsingKernel(readData, writePath) + val writeOutput = writeToParquetUsingKernel( + readData, + writePath, + statsColumns = statsColsAndKernelType.map { case (col, dt) => col} + ) verifyContentUsingKernelReader(writePath, readData) + + val numMinMaxStatCols = 
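A hedged sketch of what the `isNotNull`/`nestedCol` helpers used in the new ScanSuite test plausibly build: Kernel predicates over a top-level and a nested Variant column. The predicate name "IS_NOT_NULL" and the list-based `Predicate` construction are assumptions based on the helper names and the Kernel expression API, not shown in this diff.

```scala
import java.util.Collections.singletonList
import io.delta.kernel.expressions.{Column, Expression, Predicate}

// Null-check predicates over the test table's Variant columns.
val topLevelNotNull: Predicate =
  new Predicate("IS_NOT_NULL", singletonList[Expression](new Column("basic_v")))
val nestedNotNull: Predicate =
  new Predicate("IS_NOT_NULL", singletonList[Expression](new Column(Array("basic_struct_v", "v"))))
```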
statsColsAndKernelType.count { case (_, dt) => + ParquetStatsReader.isMinMaxStatSupportedDataType(dt) + } + val numNullStatCols = statsColsAndKernelType.count { case (_, dt) => + ParquetStatsReader.isNullStatSupportedDataType(dt) + } + verifyVariantStatsUsingKernel( + writePath, + writeOutput, + schema, + statsColsAndKernelType, + numMinMaxStatCols, + numNullStatCols + ) } } } } - testWrite("basic write variant") { + testWrite( + "basic write verify stats with kernel", + Seq( + (new Column("boolcol"), BooleanType.BOOLEAN), + (new Column("bytecol"), ByteType.BYTE), + (new Column("intcol"), IntegerType.INTEGER), + (new Column("longcol"), LongType.LONG), + (new Column("shortcol"), ShortType.SHORT), + (new Column("floatcol"), FloatType.FLOAT), + (new Column("doublecol"), DoubleType.DOUBLE), + (new Column("stringcol"), StringType.STRING), + (new Column("binarycol"), BinaryType.BINARY), + (new Column("decimalcol"), DecimalType.USER_DEFAULT), + (new Column("datecol"), DateType.DATE) + )) { + spark.range(0, 100, 1, 1).selectExpr( + "cast(id as boolean) as boolcol", + "cast(id as byte) as bytecol", + "cast(5 as int) as intcol", + "cast(id as long) as longcol", + "cast(id as short) as shortcol", + "cast(id as float) as floatcol", + "cast(id as double) as doublecol", + "cast(id as string) as stringcol", + "cast(id as binary) as binarycol", + "cast(id as decimal(10, 0)) as decimalcol", + "current_date() as datecol" + ) + } + + testWrite( + "basic write variant", + Seq( + (new Column("basic_v"), VariantType.VARIANT), + (new Column(Array("struct_v", "v")), VariantType.VARIANT) + )) { spark.range(0, 10, 1, 1).selectExpr( "parse_json(cast(id as string)) as basic_v", "named_struct('v', parse_json(cast(id as string))) as struct_v", @@ -236,7 +296,12 @@ class ParquetFileWriterSuite extends AnyFunSuite ) } - testWrite("basic write null variant") { + testWrite( + "basic write null variant", + Seq( + (new Column("basic_v"), VariantType.VARIANT), + (new Column(Array("struct_v", "v")), VariantType.VARIANT) + )) { spark.range(0, 10, 1, 1).selectExpr( "cast(null as variant) basic_v", "named_struct('v', cast(null as variant)) as struct_v", @@ -372,6 +437,89 @@ class ParquetFileWriterSuite extends AnyFunSuite } } + def verifyVariantStatsUsingKernel( + actualFileDir: String, + actualFileStatuses: Seq[DataFileStatus], + schema: StructType, + statsColumnsAndTypes: Seq[(Column, DataType)], + expMinMaxStatsCount: Int, + expNullCountStatsCount: Int): Unit = { + if (statsColumnsAndTypes.isEmpty) return + + val statFields = ArrayBuffer( + new StructField("location", StringType.STRING, true), + new StructField("fileSize", LongType.LONG, true), + new StructField("lastModifiedTime", LongType.LONG, true), + new StructField("rowCount", LongType.LONG, true) + ) + + val statsColumns = statsColumnsAndTypes.map { _._1 } + statsColumnsAndTypes.foreach { case (column, dt) => + val colName = column.getNames().toSeq.mkString("__") + statFields ++= ArrayBuffer( + new StructField(s"min_$colName", dt, true), + new StructField(s"max_$colName", dt, true), + new StructField(s"nullCount_$colName", LongType.LONG, true) + ) + } + + val kernelStatsSchema = new StructType(new ArrayList(statFields.asJava)) + + val actualStatsOutput = actualFileStatuses + .map { fileStatus => + // validate there are the expected number of stats columns + assert(fileStatus.getStatistics.isPresent) + assert(fileStatus.getStatistics.get().getMinValues.size() === expMinMaxStatsCount) + assert(fileStatus.getStatistics.get().getMaxValues.size() === expMinMaxStatsCount) + 
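A small sketch of the flattened stat-field naming convention `verifyVariantStatsUsingKernel` relies on: nested column parts are joined with "__" and prefixed with `min_`/`max_`/`nullCount_`, matching the `StructField` names appended to the stats schema above. The helper is illustrative; only the naming scheme comes from the test.

```scala
import io.delta.kernel.expressions.Column

// Produce the three flattened stat-field names for one stats column.
def statFieldNames(col: Column): Seq[String] = {
  val base = col.getNames.mkString("__")
  Seq(s"min_$base", s"max_$base", s"nullCount_$base")
}

// statFieldNames(new Column(Array("struct_v", "v")))
//   -> Seq("min_struct_v__v", "max_struct_v__v", "nullCount_struct_v__v")
```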
assert(fileStatus.getStatistics.get().getNullCounts.size() === expNullCountStatsCount) + + // Convert to Spark row for comparison with the actual values computing using Spark. + TestRow.toSparkRow( + fileStatus.toTestRow(statsColumns), + kernelStatsSchema) + } + + val statDf = spark.createDataFrame( + spark.sparkContext.parallelize(actualStatsOutput), + kernelStatsSchema.toSpark).createOrReplaceTempView("stats_table") + + val readRows = readParquetFilesUsingKernel(actualFileDir, schema) + val sparkRows = readRows.map { row => + TestRow.toSparkRow(row, schema) + } + val readDf = spark.createDataFrame(spark.sparkContext.parallelize(sparkRows), schema.toSpark) + readDf.createOrReplaceTempView("data_table") + statsColumnsAndTypes.foreach { case (col, dt) => + val realColName = col.getNames().mkString(".") + val statsColNameSuffix = col.getNames().mkString("__") + + if (ParquetStatsReader.isNullStatSupportedDataType(dt)) { + val actualNullCount = readDf.where(s"$realColName is null").count() + val statsNullCount = spark + .sql(s"select nullCount_$statsColNameSuffix from stats_table") + .collect()(0) + .getLong(0) + assert(actualNullCount == statsNullCount) + } + + if (ParquetStatsReader.isMinMaxStatSupportedDataType(dt)) { + spark.sql(s"""select min($realColName) as minVal, max($realColName) as maxVal + from data_table""").createOrReplaceTempView("min_max_table") + + // Verify that the real min/max values are equal to the reported min/max values. + val minAssertionDf = spark.sql(s""" + select first(minVal) from min_max_table d, stats_table s + where d.minVal = s.min_$statsColNameSuffix""") + assert(minAssertionDf.count() == 1) + + val maxAssertionDf = spark.sql(s""" + select first(maxVal) from min_max_table d, stats_table s + where d.maxVal = s.max_$statsColNameSuffix""") + assert(maxAssertionDf.count() == 1) + } + } + } + def verifyStatsUsingSpark( actualFileDir: String, actualFileStatuses: Seq[DataFileStatus], diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index 171a769c333..cd533d8e51c 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -19,7 +19,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.{Row => SparkRow} import org.apache.spark.unsafe.types.VariantVal -import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} +import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row, VariantValue} import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue import io.delta.kernel.types._ @@ -47,6 +47,8 @@ import java.time.LocalDate */ class TestRow(val values: Array[Any]) { + def isNullAt(i: Int): Boolean = values(i) == null + def length: Int = values.length def get(i: Int): Any = values(i) @@ -175,6 +177,43 @@ object TestRow { }) } + def toSparkRow(row: TestRow, schema: StructType): SparkRow = + SparkRow(schema.fields.asScala.toSeq.zipWithIndex.map { case (field, i) => + if (row.isNullAt(i)) { + null + } else { + field.getDataType() match { + case BooleanType.BOOLEAN => row.get(i).asInstanceOf[Boolean] + case ByteType.BYTE => row.get(i).asInstanceOf[Byte] + case IntegerType.INTEGER => row.get(i).asInstanceOf[Int] + case LongType.LONG => row.get(i).asInstanceOf[Long] + case ShortType.SHORT => row.get(i).asInstanceOf[Short] + 
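A hedged sketch of the null-count cross-check performed above, pulled out as a standalone helper: the NULL count Spark computes over the written data ("data_table") should equal the `nullCount_<col>` value recorded in the Kernel-reported statistics ("stats_table"). The view names follow the test; the helper and its SQL formulation are illustrative.

```scala
import org.apache.spark.sql.SparkSession

// Compare Spark's observed NULL count with the null count in the kernel stats view.
def checkNullCount(spark: SparkSession, realColName: String, statsColSuffix: String): Unit = {
  val actual = spark.sql(s"select count(*) from data_table where $realColName is null")
    .collect()(0).getLong(0)
  val reported = spark.sql(s"select nullCount_$statsColSuffix from stats_table")
    .collect()(0).getLong(0)
  assert(actual == reported)
}
```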
case DateType.DATE => + val rowVal = row.get(i).asInstanceOf[Int] + new java.sql.Date(java.util.concurrent.TimeUnit.DAYS.toMillis(rowVal)) + case TimestampType.TIMESTAMP => + val rowVal = row.get(i).asInstanceOf[Long] + new java.sql.Timestamp(java.util.concurrent.TimeUnit.MICROSECONDS.toMillis(rowVal)) + case FloatType.FLOAT => row.get(i).asInstanceOf[Float] + case DoubleType.DOUBLE => row.get(i).asInstanceOf[Double] + case StringType.STRING => row.get(i).asInstanceOf[String] + case BinaryType.BINARY => row.get(i).asInstanceOf[Array[Byte]] + case _: DecimalType => row.get(i).asInstanceOf[java.math.BigDecimal] + case at: ArrayType => row.get(i).asInstanceOf[Seq[Any]] + case _: MapType => row.get(i).asInstanceOf[Map[Any, Any]] + case st: StructType => + toSparkRow( + row.get(i).asInstanceOf[TestRow], + new StructType(field.getDataType().asInstanceOf[StructType].fields()) + ) + case VariantType.VARIANT => + val variantVal = row.get(i).asInstanceOf[VariantValue] + new VariantVal(variantVal.getValue(), variantVal.getMetadata()) + case _ => throw new UnsupportedOperationException("unrecognized data type") + } + } + }: _*) + /** * Retrieves the value at `rowId` in the column vector as it's corresponding scala type. * See the [[TestRow]] docs for details.
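A minimal sketch of the Variant leg of `toSparkRow` above: a Kernel `VariantValue` exposes the encoded value and metadata byte arrays, which map directly onto Spark's `VariantVal` (the same pairing the pattern match performs).

```scala
import io.delta.kernel.data.VariantValue
import org.apache.spark.unsafe.types.VariantVal

// Convert a Kernel Variant value into Spark's unsafe VariantVal representation.
def kernelVariantToSpark(v: VariantValue): VariantVal =
  new VariantVal(v.getValue(), v.getMetadata())
```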