
Commit 4ccd82f

fix reads
1 parent fc4d8a1 commit 4ccd82f

File tree

4 files changed: +77 −4 lines changed


kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java

Lines changed: 0 additions & 1 deletion
@@ -49,7 +49,6 @@ public DefaultVariantVector(
       ColumnVector value,
       ColumnVector metadata) {
     super(size, type, nullability);
-    // checkArgument(offsets.length >= size + 1, "invalid offset array size");
     this.valueVector = requireNonNull(value, "value is null");
     this.metadataVector = requireNonNull(metadata, "metadata is null");
   }

kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala

Lines changed: 54 additions & 0 deletions
@@ -16,6 +16,9 @@
 package io.delta.kernel.defaults.internal.parquet
 
 import java.math.BigDecimal
+
+import org.apache.spark.sql.DataFrame
+
 import io.delta.golden.GoldenTableUtils.goldenTableFile
 import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestRow, VectorTestUtils}
 import io.delta.kernel.types._
@@ -139,4 +142,55 @@ class ParquetFileReaderSuite extends AnyFunSuite
 
     checkAnswer(actResult2, expResult2)
   }
+
+  private def testReadVariant(testName: String)(df: => DataFrame): Unit = {
+    test(testName) {
+      withTable("test_variant_table") {
+        df.write
+          .format("delta")
+          .mode("overwrite")
+          .saveAsTable("test_variant_table")
+        val path = spark.sql("describe table extended `test_variant_table`")
+          .where("col_name = 'Location'")
+          .collect()(0)
+          .getString(1)
+          .replace("file:", "")
+
+        val kernelSchema = tableSchema(path)
+        val actResult = readParquetFilesUsingKernel(path, kernelSchema)
+        val expResult = readParquetFilesUsingSpark(path, kernelSchema)
+        checkAnswer(actResult, expResult)
+      }
+    }
+  }
+
+  testReadVariant("basic read variant") {
+    spark.range(0, 10, 1, 1).selectExpr(
+      "parse_json(cast(id as string)) as basic_v",
+      "named_struct('v', parse_json(cast(id as string))) as struct_v",
+      """array(
+        parse_json(cast(id as string)),
+        parse_json(cast(id as string)),
+        parse_json(cast(id as string))
+      ) as array_v""",
+      "map('test', parse_json(cast(id as string))) as map_value_v",
+      "map(parse_json(cast(id as string)), parse_json(cast(id as string))) as map_key_v"
+    )
+  }
+
+  testReadVariant("basic null variant") {
+    spark.range(0, 10, 1, 1).selectExpr(
+      "cast(null as variant) basic_v",
+      "named_struct('v', cast(null as variant)) as struct_v",
+      """array(
+        parse_json(cast(id as string)),
+        parse_json(cast(id as string)),
+        null
+      ) as array_v""",
+      "map('test', cast(null as variant)) as map_value_v",
+      "map(cast(null as variant), parse_json(cast(id as string))) as map_key_v",
+    )
+  }
+
+  // TODO(richardc-db): Add nested variant tests once `parse_json` expression is implemented.
 }
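
The new `testReadVariant` helper follows a write-then-compare pattern: the DataFrame is saved as a Delta table, the table's physical location is taken from `describe table extended`, and the Parquet files under that path are read back with both the Kernel default reader and Spark before `checkAnswer` compares the rows. Further cases can reuse the helper directly; the sketch below is illustrative only (the test name, column name, and payload are hypothetical, not part of this commit) and feeds a JSON object payload instead of a bare scalar, so the variant also carries field names in its metadata.

  // Hypothetical extra case reusing testReadVariant; each row holds a small JSON object.
  testReadVariant("read variant with object payload") {
    spark.range(0, 10, 1, 1).selectExpr(
      """parse_json(concat('{"id": ', cast(id as string), '}')) as obj_v"""
    )
  }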

kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala

Lines changed: 10 additions & 3 deletions
@@ -18,6 +18,7 @@ package io.delta.kernel.defaults.utils
 import scala.collection.JavaConverters._
 import org.apache.spark.sql.{types => sparktypes}
 import org.apache.spark.sql.{Row => SparkRow}
+import org.apache.spark.unsafe.types.VariantVal
 import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row}
 import io.delta.kernel.types._
 
@@ -44,7 +45,7 @@ import java.time.{Instant, LocalDate, LocalDateTime, ZoneOffset}
  * - ArrayType --> Seq[Any]
  * - MapType --> Map[Any, Any]
  * - StructType --> TestRow
- *
+ * - VariantType --> VariantVal
  * For complex types array and map, the inner elements types should align with this mapping.
  */
 class TestRow(val values: Array[Any]) {
@@ -108,7 +109,9 @@ object TestRow {
       case _: ArrayType => arrayValueToScalaSeq(row.getArray(i))
       case _: MapType => mapValueToScalaMap(row.getMap(i))
       case _: StructType => TestRow(row.getStruct(i))
-      case _: VariantType => row.getVariant(i)
+      case _: VariantType =>
+        val kernelVariant = row.getVariant(i)
+        new VariantVal(kernelVariant.getValue(), kernelVariant.getMetadata())
       case _ => throw new UnsupportedOperationException("unrecognized data type")
     }
   }.toSeq)
@@ -141,6 +144,7 @@ object TestRow {
        decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v)
      }
      case _: sparktypes.StructType => TestRow(obj.asInstanceOf[SparkRow])
+     case _: sparktypes.VariantType => obj.asInstanceOf[VariantVal]
      case _ => throw new UnsupportedOperationException("unrecognized data type")
    }
  }
@@ -174,7 +178,7 @@ object TestRow {
        decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v)
      }
      case _: sparktypes.StructType => TestRow(row.getStruct(i))
-     case _: sparktypes.VariantType => row.getAs[Row](i)
+     case _: sparktypes.VariantType => row.getAs[VariantVal](i)
      case _ => throw new UnsupportedOperationException("unrecognized data type")
    }
  })
@@ -206,6 +210,9 @@ object TestRow {
     TestRow.fromSeq(Seq.range(0, dataType.length()).map { ordinal =>
       getAsTestObject(vector.getChild(ordinal), rowId)
     })
+   case _: VariantType =>
+     val kernelVariant = vector.getVariant(rowId)
+     new VariantVal(kernelVariant.getValue(), kernelVariant.getMetadata())
    case _ => throw new UnsupportedOperationException("unrecognized data type")
   }
 }
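
Each Variant branch added above performs the same conversion: the Kernel variant exposes its raw value and metadata byte buffers through getValue()/getMetadata(), and Spark's VariantVal is built directly from those two arrays. A minimal standalone sketch of that mapping, assuming the Kernel variant type is io.delta.kernel.data.VariantValue (the import path is inferred from the getter names in the diff, not stated by this commit):

  import io.delta.kernel.data.VariantValue // assumed package for the Kernel variant type
  import org.apache.spark.unsafe.types.VariantVal

  // Rebuild Spark's representation from the Kernel one; both carry the variant
  // encoding as a value buffer plus a metadata buffer.
  def toSparkVariant(kernelVariant: VariantValue): VariantVal =
    new VariantVal(kernelVariant.getValue(), kernelVariant.getMetadata())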

kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala

Lines changed: 13 additions & 0 deletions
@@ -41,6 +41,7 @@ import org.apache.hadoop.shaded.org.apache.commons.io.FileUtils
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.{types => sparktypes}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.unsafe.types.VariantVal
 import org.scalatest.Assertions
 
 trait TestUtils extends Assertions with SQLHelper {
@@ -117,6 +118,17 @@ trait TestUtils extends Assertions with SQLHelper {
     lazy val classLoader: ClassLoader = ResourceLoader.getClass.getClassLoader
   }
 
+  /**
+   * Drops the tables `tableNames` after calling `f`.
+   */
+  def withTable(tableNames: String*)(f: => Unit): Unit = {
+    try f finally {
+      tableNames.foreach { name =>
+        spark.sql(s"DROP TABLE IF EXISTS $name")
+      }
+    }
+  }
+
   def withGoldenTable(tableName: String)(testFunc: String => Unit): Unit = {
     val tablePath = GoldenTableUtils.goldenTablePath(tableName)
     testFunc(tablePath)
@@ -404,6 +416,7 @@ trait TestUtils extends Assertions with SQLHelper {
         java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
       case (a: Float, b: Float) =>
         java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
+      case (a: VariantVal, b: VariantVal) => a.debugString() == b.debugString()
       case (a, b) =>
         if (!a.equals(b)) {
           val sds = 200;
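
The `withTable` helper mirrors the utility of the same name in Spark's test framework: the body runs first and every named table is dropped in the finally block, so a failing assertion cannot leak tables into later tests. A minimal usage sketch (the table name and data are made up for illustration):

  // Illustrative only: create a throwaway Delta table with a variant column,
  // assert on it, and let withTable drop it even if the assertion throws.
  withTable("tmp_variant_tbl") {
    spark.range(3)
      .selectExpr("parse_json(cast(id as string)) as v")
      .write.format("delta").saveAsTable("tmp_variant_tbl")
    assert(spark.table("tmp_variant_tbl").count() == 3)
  }

The VariantVal case added to the row comparison checks equality through debugString(), i.e. on the decoded textual form Spark exposes rather than on the raw value/metadata buffers.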
