From 53fce552e58862b76cccbf4b363724dc218c0210 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 11 Mar 2024 23:14:53 -0700 Subject: [PATCH 1/3] init no tests yet --- .../internal/parquet/VariantConverter.java | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java new file mode 100644 index 00000000000..dc4b61d28be --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java @@ -0,0 +1,120 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.parquet; + +import java.util.*; + +import org.apache.parquet.io.api.Converter; +import org.apache.parquet.io.api.GroupConverter; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.types.BinaryType; +import io.delta.kernel.types.VariantType; +import static io.delta.kernel.internal.util.Preconditions.checkArgument; + +import io.delta.kernel.defaults.internal.data.vector.DefaultVariantVector; +import io.delta.kernel.defaults.internal.parquet.ParquetConverters.BinaryColumnConverter; + +class VariantConverter + extends GroupConverter + implements ParquetConverters.BaseConverter { + private final BinaryColumnConverter valueConverter; + private final BinaryColumnConverter metadataConverter; + + // Working state + private boolean isCurrentValueNull = true; + private int currentRowIndex; + private boolean[] nullability; + + /** + * Create converter for {@link VariantType} column. + * + * @param initialBatchSize Estimate of initial row batch size. Used in memory allocations. 
+ */ + VariantConverter(int initialBatchSize) { + checkArgument(initialBatchSize > 0, "invalid initialBatchSize: %s", initialBatchSize); + // Initialize the working state + this.nullability = ParquetConverters.initNullabilityVector(initialBatchSize); + + int parquetOrdinal = 0; + this.valueConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); + this.metadataConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); + } + + @Override + public Converter getConverter(int fieldIndex) { + checkArgument( + fieldIndex >= 0 && fieldIndex < 2, + "variant type is represented by a struct with 2 fields"); + if (fieldIndex == 0) { + return valueConverter; + } else { + return metadataConverter; + } + } + + @Override + public void start() { + isCurrentValueNull = false; + } + + @Override + public void end() { + } + + @Override + public void finalizeCurrentRow(long currentRowIndex) { + resizeIfNeeded(); + finalizeLastRowInConverters(currentRowIndex); + nullability[this.currentRowIndex] = isCurrentValueNull; + isCurrentValueNull = true; + + this.currentRowIndex++; + } + + public ColumnVector getDataColumnVector(int batchSize) { + ColumnVector vector = new DefaultVariantVector( + batchSize, + VariantType.VARIANT, + Optional.of(nullability), + valueConverter.getDataColumnVector(batchSize), + metadataConverter.getDataColumnVector(batchSize) + ); + resetWorkingState(); + return vector; + } + + @Override + public void resizeIfNeeded() { + if (nullability.length == currentRowIndex) { + int newSize = nullability.length * 2; + this.nullability = Arrays.copyOf(this.nullability, newSize); + ParquetConverters.setNullabilityToTrue(this.nullability, newSize / 2, newSize); + } + } + + @Override + public void resetWorkingState() { + this.currentRowIndex = 0; + this.isCurrentValueNull = true; + this.nullability = ParquetConverters.initNullabilityVector(this.nullability.length); + } + + private void finalizeLastRowInConverters(long prevRowIndex) { + valueConverter.finalizeCurrentRow(prevRowIndex); + metadataConverter.finalizeCurrentRow(prevRowIndex); + } +} From 4422020ed18924fdac9bdd897c40c76fa9c2e9ad Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 25 Mar 2024 16:25:25 -0700 Subject: [PATCH 2/3] init --- .../internal/parquet/ParquetStatsReader.java | 66 ++++++-- .../internal/parquet/VariantConverter.java | 120 -------------- .../io/delta/kernel/defaults/ScanSuite.scala | 38 +++++ .../parquet/ParquetFileWriterSuite.scala | 156 +++++++++++++++++- .../delta/kernel/defaults/utils/TestRow.scala | 41 ++++- .../kernel/defaults/utils/TestUtils.scala | 2 +- 6 files changed, 281 insertions(+), 142 deletions(-) delete mode 100644 kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java index 1e982ad25c2..4c6a8c7a451 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java @@ -92,33 +92,51 @@ private static DataFileStatistics constructFileStats( .collect(toImmutableMap( identity(), key -> mergeMetadataList(metadataForColumn.get(key)))); - Map minValues = new HashMap<>(); Map maxValues = new HashMap<>(); Map nullCounts = new HashMap<>(); for (Column statsColumn : 
statsColumns) { - Optional> stats = statsForColumn.get(statsColumn); DataType columnType = getDataType(dataSchema, statsColumn); - if (stats == null || !stats.isPresent() || !isStatsSupportedDataType(columnType)) { + + Optional> stats; + if (columnType instanceof VariantType) { + // Hack: Parquet stores Variant types as a struct with two binary child fields so + // the stats are also only stored on the child fields. Because the value and + // metadata fields are null iff the variant is null, the null count stat is + // retrieved by inspecting the "value" field. + int variantColNameLength = statsColumn.getNames().length; + String[] variantColNameList = + Arrays.copyOf(statsColumn.getNames(), variantColNameLength + 1); + variantColNameList[variantColNameLength] = "value"; + stats = statsForColumn.get(new Column(variantColNameList)); + } else { + stats = statsForColumn.get(statsColumn); + } + if (stats == null || !stats.isPresent()) { continue; } Statistics statistics = stats.get(); - Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null; - nullCounts.put(statsColumn, numNulls); - - if (numNulls != null && rowCount == numNulls) { - // If all values are null, then min and max are also null - minValues.put(statsColumn, Literal.ofNull(columnType)); - maxValues.put(statsColumn, Literal.ofNull(columnType)); - continue; + if (isNullStatSupportedDataType(columnType)) { + Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null; + nullCounts.put(statsColumn, numNulls); } - Literal minValue = decodeMinMaxStat(columnType, statistics, true /* decodeMin */); - minValues.put(statsColumn, minValue); + if (isMinMaxStatSupportedDataType(columnType)) { + Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null; + if (numNulls != null && rowCount == numNulls) { + // If all values are null, then min and max are also null + minValues.put(statsColumn, Literal.ofNull(columnType)); + maxValues.put(statsColumn, Literal.ofNull(columnType)); + continue; + } + + Literal minValue = decodeMinMaxStat(columnType, statistics, true /* decodeMin */); + minValues.put(statsColumn, minValue); - Literal maxValue = decodeMinMaxStat(columnType, statistics, false /* decodeMin */); - maxValues.put(statsColumn, maxValue); + Literal maxValue = decodeMinMaxStat(columnType, statistics, false /* decodeMin */); + maxValues.put(statsColumn, maxValue); + } } return new DataFileStatistics(rowCount, minValues, maxValues, nullCounts); @@ -220,7 +238,7 @@ private static boolean hasInvalidStatistics(Collection meta }); } - private static boolean isStatsSupportedDataType(DataType dataType) { + public static boolean isMinMaxStatSupportedDataType(DataType dataType) { return dataType instanceof BooleanType || dataType instanceof ByteType || dataType instanceof ShortType || @@ -236,6 +254,22 @@ private static boolean isStatsSupportedDataType(DataType dataType) { // Add support later. } + public static boolean isNullStatSupportedDataType(DataType dataType) { + return dataType instanceof BooleanType || + dataType instanceof ByteType || + dataType instanceof ShortType || + dataType instanceof IntegerType || + dataType instanceof LongType || + dataType instanceof FloatType || + dataType instanceof DoubleType || + dataType instanceof DecimalType || + dataType instanceof DateType || + dataType instanceof StringType || + dataType instanceof BinaryType || + dataType instanceof VariantType; + // TODO: add timestamp support later. 
+ } + private static byte[] getBinaryStat(Statistics statistics, boolean decodeMin) { return decodeMin ? statistics.getMinBytes() : statistics.getMaxBytes(); } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java deleted file mode 100644 index dc4b61d28be..00000000000 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright (2024) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.delta.kernel.defaults.internal.parquet; - -import java.util.*; - -import org.apache.parquet.io.api.Converter; -import org.apache.parquet.io.api.GroupConverter; - -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.types.BinaryType; -import io.delta.kernel.types.VariantType; -import static io.delta.kernel.internal.util.Preconditions.checkArgument; - -import io.delta.kernel.defaults.internal.data.vector.DefaultVariantVector; -import io.delta.kernel.defaults.internal.parquet.ParquetConverters.BinaryColumnConverter; - -class VariantConverter - extends GroupConverter - implements ParquetConverters.BaseConverter { - private final BinaryColumnConverter valueConverter; - private final BinaryColumnConverter metadataConverter; - - // Working state - private boolean isCurrentValueNull = true; - private int currentRowIndex; - private boolean[] nullability; - - /** - * Create converter for {@link VariantType} column. - * - * @param initialBatchSize Estimate of initial row batch size. Used in memory allocations. 
- */ - VariantConverter(int initialBatchSize) { - checkArgument(initialBatchSize > 0, "invalid initialBatchSize: %s", initialBatchSize); - // Initialize the working state - this.nullability = ParquetConverters.initNullabilityVector(initialBatchSize); - - int parquetOrdinal = 0; - this.valueConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); - this.metadataConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); - } - - @Override - public Converter getConverter(int fieldIndex) { - checkArgument( - fieldIndex >= 0 && fieldIndex < 2, - "variant type is represented by a struct with 2 fields"); - if (fieldIndex == 0) { - return valueConverter; - } else { - return metadataConverter; - } - } - - @Override - public void start() { - isCurrentValueNull = false; - } - - @Override - public void end() { - } - - @Override - public void finalizeCurrentRow(long currentRowIndex) { - resizeIfNeeded(); - finalizeLastRowInConverters(currentRowIndex); - nullability[this.currentRowIndex] = isCurrentValueNull; - isCurrentValueNull = true; - - this.currentRowIndex++; - } - - public ColumnVector getDataColumnVector(int batchSize) { - ColumnVector vector = new DefaultVariantVector( - batchSize, - VariantType.VARIANT, - Optional.of(nullability), - valueConverter.getDataColumnVector(batchSize), - metadataConverter.getDataColumnVector(batchSize) - ); - resetWorkingState(); - return vector; - } - - @Override - public void resizeIfNeeded() { - if (nullability.length == currentRowIndex) { - int newSize = nullability.length * 2; - this.nullability = Arrays.copyOf(this.nullability, newSize); - ParquetConverters.setNullabilityToTrue(this.nullability, newSize / 2, newSize); - } - } - - @Override - public void resetWorkingState() { - this.currentRowIndex = 0; - this.isCurrentValueNull = true; - this.nullability = ParquetConverters.initNullabilityVector(this.nullability.length); - } - - private void finalizeLastRowInConverters(long prevRowIndex) { - valueConverter.finalizeCurrentRow(prevRowIndex); - metadataConverter.finalizeCurrentRow(prevRowIndex); - } -} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala index e08adaadbde..1ba20d2b91a 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala @@ -1386,6 +1386,44 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with } } + test("data skipping - basic variant type") { + withTable("test_table") { + spark.range(0, 10, 1, 1) + .selectExpr( + "parse_json(cast(id as string)) as basic_v", + "cast(null as variant) as basic_null_v", + "named_struct('v', parse_json(cast(id as string))) as basic_struct_v", + "named_struct('v', cast(null as variant)) as basic_struct_null_v", + ) + .write + .format("delta") + .mode("overwrite") + .saveAsTable("test_table") + + val filePath = spark.sql("describe table extended `test_table`") + .where("col_name = 'Location'") + .collect()(0) + .getString(1) + .replace("file:", "") + + checkSkipping( + filePath, + hits = Seq( + isNotNull(col("basic_v")), + not(isNotNull(col("basic_null_v"))), + isNotNull(nestedCol("basic_struct_v.v")), + not(isNotNull(nestedCol("basic_struct_null_v.v"))) + ), + misses = Seq( + not(isNotNull(col("basic_v"))), + isNotNull(col("basic_null_v")), + not(isNotNull(nestedCol("basic_struct_v.v"))), + 
isNotNull(nestedCol("basic_struct_null_v.v")) + ) + ) + } + } + test("data skipping - is not null with DVs in file with non-nulls") { withSQLConf(("spark.databricks.delta.properties.defaults.enableDeletionVectors", "true")) { withTempDir { tempDir => diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala index d6d5ccb10cf..f912fbd78a9 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala @@ -16,8 +16,12 @@ package io.delta.kernel.defaults.internal.parquet import java.lang.{Double => DoubleJ, Float => FloatJ} +import java.util.ArrayList +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row => SparkRow} import io.delta.golden.GoldenTableUtils.{goldenTableFile, goldenTablePath} import io.delta.kernel.data.{ColumnarBatch, FilteredColumnarBatch} @@ -191,7 +195,9 @@ class ParquetFileWriterSuite extends AnyFunSuite } } - def testWrite(testName: String)(df: => DataFrame): Unit = { + def testWrite( + testName: String, + statsColsAndKernelType: Seq[(Column, DataType)])(df: => DataFrame): Unit = { test(testName) { withTable("test_table") { withTempDir { writeDir => @@ -215,14 +221,68 @@ class ParquetFileWriterSuite extends AnyFunSuite val readData = readParquetUsingKernelAsColumnarBatches(filePath, physicalSchema) .map(_.toFiltered(Option.empty[Predicate])) val writePath = writeDir.getAbsolutePath - val writeOutput = writeToParquetUsingKernel(readData, writePath) + val writeOutput = writeToParquetUsingKernel( + readData, + writePath, + statsColumns = statsColsAndKernelType.map { case (col, dt) => col} + ) verifyContentUsingKernelReader(writePath, readData) + + val numMinMaxStatCols = statsColsAndKernelType.count { case (_, dt) => + ParquetStatsReader.isMinMaxStatSupportedDataType(dt) + } + val numNullStatCols = statsColsAndKernelType.count { case (_, dt) => + ParquetStatsReader.isNullStatSupportedDataType(dt) + } + verifyVariantStatsUsingKernel( + writePath, + writeOutput, + schema, + statsColsAndKernelType, + numMinMaxStatCols, + numNullStatCols + ) } } } } - testWrite("basic write variant") { + testWrite( + "basic write verify stats with kernel", + Seq( + (new Column("boolcol"), BooleanType.BOOLEAN), + (new Column("bytecol"), ByteType.BYTE), + (new Column("intcol"), IntegerType.INTEGER), + (new Column("longcol"), LongType.LONG), + (new Column("shortcol"), ShortType.SHORT), + (new Column("floatcol"), FloatType.FLOAT), + (new Column("doublecol"), DoubleType.DOUBLE), + (new Column("stringcol"), StringType.STRING), + (new Column("binarycol"), BinaryType.BINARY), + (new Column("decimalcol"), DecimalType.USER_DEFAULT), + (new Column("datecol"), DateType.DATE) + )) { + spark.range(0, 100, 1, 1).selectExpr( + "cast(id as boolean) as boolcol", + "cast(id as byte) as bytecol", + "cast(5 as int) as intcol", + "cast(id as long) as longcol", + "cast(id as short) as shortcol", + "cast(id as float) as floatcol", + "cast(id as double) as doublecol", + "cast(id as string) as stringcol", + "cast(id as binary) as binarycol", + "cast(id as decimal(10, 0)) as decimalcol", + "current_date() as datecol" + ) + } + + testWrite( + "basic write variant", + Seq( + (new 
Column("basic_v"), VariantType.VARIANT), + (new Column(Array("struct_v", "v")), VariantType.VARIANT) + )) { spark.range(0, 10, 1, 1).selectExpr( "parse_json(cast(id as string)) as basic_v", "named_struct('v', parse_json(cast(id as string))) as struct_v", @@ -236,7 +296,12 @@ class ParquetFileWriterSuite extends AnyFunSuite ) } - testWrite("basic write null variant") { + testWrite( + "basic write null variant", + Seq( + (new Column("basic_v"), VariantType.VARIANT), + (new Column(Array("struct_v", "v")), VariantType.VARIANT) + )) { spark.range(0, 10, 1, 1).selectExpr( "cast(null as variant) basic_v", "named_struct('v', cast(null as variant)) as struct_v", @@ -372,6 +437,89 @@ class ParquetFileWriterSuite extends AnyFunSuite } } + def verifyVariantStatsUsingKernel( + actualFileDir: String, + actualFileStatuses: Seq[DataFileStatus], + schema: StructType, + statsColumnsAndTypes: Seq[(Column, DataType)], + expMinMaxStatsCount: Int, + expNullCountStatsCount: Int): Unit = { + if (statsColumnsAndTypes.isEmpty) return + + val statFields = ArrayBuffer( + new StructField("location", StringType.STRING, true), + new StructField("fileSize", LongType.LONG, true), + new StructField("lastModifiedTime", LongType.LONG, true), + new StructField("rowCount", LongType.LONG, true) + ) + + val statsColumns = statsColumnsAndTypes.map { _._1 } + statsColumnsAndTypes.foreach { case (column, dt) => + val colName = column.getNames().toSeq.mkString("__") + statFields ++= ArrayBuffer( + new StructField(s"min_$colName", dt, true), + new StructField(s"max_$colName", dt, true), + new StructField(s"nullCount_$colName", LongType.LONG, true) + ) + } + + val kernelStatsSchema = new StructType(new ArrayList(statFields.asJava)) + + val actualStatsOutput = actualFileStatuses + .map { fileStatus => + // validate there are the expected number of stats columns + assert(fileStatus.getStatistics.isPresent) + assert(fileStatus.getStatistics.get().getMinValues.size() === expMinMaxStatsCount) + assert(fileStatus.getStatistics.get().getMaxValues.size() === expMinMaxStatsCount) + assert(fileStatus.getStatistics.get().getNullCounts.size() === expNullCountStatsCount) + + // Convert to Spark row for comparison with the actual values computing using Spark. + TestRow.toSparkRow( + fileStatus.toTestRow(statsColumns), + kernelStatsSchema) + } + + val statDf = spark.createDataFrame( + spark.sparkContext.parallelize(actualStatsOutput), + kernelStatsSchema.toSpark).createOrReplaceTempView("stats_table") + + val readRows = readParquetFilesUsingKernel(actualFileDir, schema) + val sparkRows = readRows.map { row => + TestRow.toSparkRow(row, schema) + } + val readDf = spark.createDataFrame(spark.sparkContext.parallelize(sparkRows), schema.toSpark) + readDf.createOrReplaceTempView("data_table") + statsColumnsAndTypes.foreach { case (col, dt) => + val realColName = col.getNames().mkString(".") + val statsColNameSuffix = col.getNames().mkString("__") + + if (ParquetStatsReader.isNullStatSupportedDataType(dt)) { + val actualNullCount = readDf.where(s"$realColName is null").count() + val statsNullCount = spark + .sql(s"select nullCount_$statsColNameSuffix from stats_table") + .collect()(0) + .getLong(0) + assert(actualNullCount == statsNullCount) + } + + if (ParquetStatsReader.isMinMaxStatSupportedDataType(dt)) { + spark.sql(s"""select min($realColName) as minVal, max($realColName) as maxVal + from data_table""").createOrReplaceTempView("min_max_table") + + // Verify that the real min/max values are equal to the reported min/max values. 
+ val minAssertionDf = spark.sql(s""" + select first(minVal) from min_max_table d, stats_table s + where d.minVal = s.min_$statsColNameSuffix""") + assert(minAssertionDf.count() == 1) + + val maxAssertionDf = spark.sql(s""" + select first(maxVal) from min_max_table d, stats_table s + where d.maxVal = s.max_$statsColNameSuffix""") + assert(maxAssertionDf.count() == 1) + } + } + } + def verifyStatsUsingSpark( actualFileDir: String, actualFileStatuses: Seq[DataFileStatus], diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index 171a769c333..cd533d8e51c 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -19,7 +19,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.{Row => SparkRow} import org.apache.spark.unsafe.types.VariantVal -import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} +import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row, VariantValue} import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue import io.delta.kernel.types._ @@ -47,6 +47,8 @@ import java.time.LocalDate */ class TestRow(val values: Array[Any]) { + def isNullAt(i: Int): Boolean = values(i) == null + def length: Int = values.length def get(i: Int): Any = values(i) @@ -175,6 +177,43 @@ object TestRow { }) } + def toSparkRow(row: TestRow, schema: StructType): SparkRow = + SparkRow(schema.fields.asScala.toSeq.zipWithIndex.map { case (field, i) => + if (row.isNullAt(i)) { + null + } else { + field.getDataType() match { + case BooleanType.BOOLEAN => row.get(i).asInstanceOf[Boolean] + case ByteType.BYTE => row.get(i).asInstanceOf[Byte] + case IntegerType.INTEGER => row.get(i).asInstanceOf[Int] + case LongType.LONG => row.get(i).asInstanceOf[Long] + case ShortType.SHORT => row.get(i).asInstanceOf[Short] + case DateType.DATE => + val rowVal = row.get(i).asInstanceOf[Int] + new java.sql.Date(java.util.concurrent.TimeUnit.DAYS.toMillis(rowVal)) + case TimestampType.TIMESTAMP => + val rowVal = row.get(i).asInstanceOf[Long] + new java.sql.Timestamp(java.util.concurrent.TimeUnit.MICROSECONDS.toMillis(rowVal)) + case FloatType.FLOAT => row.get(i).asInstanceOf[Float] + case DoubleType.DOUBLE => row.get(i).asInstanceOf[Double] + case StringType.STRING => row.get(i).asInstanceOf[String] + case BinaryType.BINARY => row.get(i).asInstanceOf[Array[Byte]] + case _: DecimalType => row.get(i).asInstanceOf[java.math.BigDecimal] + case at: ArrayType => row.get(i).asInstanceOf[Seq[Any]] + case _: MapType => row.get(i).asInstanceOf[Map[Any, Any]] + case st: StructType => + toSparkRow( + row.get(i).asInstanceOf[TestRow], + new StructType(field.getDataType().asInstanceOf[StructType].fields()) + ) + case VariantType.VARIANT => + val variantVal = row.get(i).asInstanceOf[VariantValue] + new VariantVal(variantVal.getValue(), variantVal.getMetadata()) + case _ => throw new UnsupportedOperationException("unrecognized data type") + } + } + }: _*) + /** * Retrieves the value at `rowId` in the column vector as it's corresponding scala type. * See the [[TestRow]] docs for details. 
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index a14d2497bec..c0689b3593d 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -605,7 +605,7 @@ trait TestUtils extends Assertions with SQLHelper { /** * Converts a Delta DataType to a Spark DataType. */ - private def toSparkType(deltaType: DataType): sparktypes.DataType = { + def toSparkType(deltaType: DataType): sparktypes.DataType = { deltaType match { case BooleanType.BOOLEAN => sparktypes.DataTypes.BooleanType case ByteType.BYTE => sparktypes.DataTypes.ByteType From 85733ab84db2467750a784fe6824fbd90d3e3d0c Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 25 Mar 2024 16:54:13 -0700 Subject: [PATCH 3/3] add back private --- .../test/scala/io/delta/kernel/defaults/utils/TestUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index c0689b3593d..a14d2497bec 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -605,7 +605,7 @@ trait TestUtils extends Assertions with SQLHelper { /** * Converts a Delta DataType to a Spark DataType. */ - def toSparkType(deltaType: DataType): sparktypes.DataType = { + private def toSparkType(deltaType: DataType): sparktypes.DataType = { deltaType match { case BooleanType.BOOLEAN => sparktypes.DataTypes.BooleanType case ByteType.BYTE => sparktypes.DataTypes.ByteType
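
Reviewer note on patch 2: the central trick in `ParquetStatsReader.constructFileStats` is that Parquet stores a Variant as a two-field group (`value`, `metadata`), so column statistics only exist on those leaf fields; the null count for a Variant column is therefore looked up by extending the logical column path with a trailing `"value"` segment. Below is a minimal, self-contained sketch of that path extension — plain `String[]` arrays stand in for the kernel's `Column`, and the method name is illustrative, not part of the kernel API:

```java
import java.util.Arrays;

public class VariantStatsPathSketch {
    /**
     * Given the dotted path of a Variant column (e.g. ["struct_v", "v"]),
     * returns the path of the physical child field that actually carries the
     * Parquet null-count statistic (e.g. ["struct_v", "v", "value"]).
     */
    static String[] statsPathForVariant(String[] variantColumnPath) {
        int length = variantColumnPath.length;
        // Copy the logical path and append the "value" leaf, mirroring the
        // Arrays.copyOf(...) + assignment done in ParquetStatsReader.
        String[] statsPath = Arrays.copyOf(variantColumnPath, length + 1);
        statsPath[length] = "value";
        return statsPath;
    }

    public static void main(String[] args) {
        String[] topLevel = statsPathForVariant(new String[] {"basic_v"});
        String[] nested = statsPathForVariant(new String[] {"struct_v", "v"});
        System.out.println(String.join(".", topLevel)); // basic_v.value
        System.out.println(String.join(".", nested));   // struct_v.v.value
    }
}
```

Only the null count is taken from this leaf: `value` and `metadata` are null exactly when the Variant itself is null, while binary min/max over the encoded bytes would not be meaningful for a Variant. This is why the patch splits the old `isStatsSupportedDataType` into `isMinMaxStatSupportedDataType` (which excludes `VariantType`) and `isNullStatSupportedDataType` (which includes it).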