
Commit 7b688e1

fix reads
1 parent 653b4ae commit 7b688e1

4 files changed: 76 additions & 4 deletions

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java

Lines changed: 0 additions & 1 deletion
@@ -49,7 +49,6 @@ public DefaultVariantVector(
       ColumnVector value,
       ColumnVector metadata) {
     super(size, type, nullability);
-    // checkArgument(offsets.length >= size + 1, "invalid offset array size");
     this.valueVector = requireNonNull(value, "value is null");
     this.metadataVector = requireNonNull(metadata, "metadata is null");
   }

kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala

Lines changed: 53 additions & 0 deletions
@@ -17,6 +17,8 @@ package io.delta.kernel.defaults.internal.parquet
 
 import java.math.BigDecimal
 
+import org.apache.spark.sql.DataFrame
+
 import io.delta.golden.GoldenTableUtils.goldenTableFile
 import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestRow, VectorTestUtils}
 import io.delta.kernel.types._
@@ -141,4 +143,55 @@ class ParquetFileReaderSuite extends AnyFunSuite
 
     checkAnswer(actResult2, expResult2)
   }
+
+  private def testReadVariant(testName: String)(df: => DataFrame): Unit = {
+    test(testName) {
+      withTable("test_variant_table") {
+        df.write
+          .format("delta")
+          .mode("overwrite")
+          .saveAsTable("test_variant_table")
+        val path = spark.sql("describe table extended `test_variant_table`")
+          .where("col_name = 'Location'")
+          .collect()(0)
+          .getString(1)
+          .replace("file:", "")
+
+        val kernelSchema = tableSchema(path)
+        val actResult = readParquetFilesUsingKernel(path, kernelSchema)
+        val expResult = readParquetFilesUsingSpark(path, kernelSchema)
+        checkAnswer(actResult, expResult)
+      }
+    }
+  }
+
+  testReadVariant("basic read variant") {
+    spark.range(0, 10, 1, 1).selectExpr(
+      "parse_json(cast(id as string)) as basic_v",
+      "named_struct('v', parse_json(cast(id as string))) as struct_v",
+      """array(
+        parse_json(cast(id as string)),
+        parse_json(cast(id as string)),
+        parse_json(cast(id as string))
+      ) as array_v""",
+      "map('test', parse_json(cast(id as string))) as map_value_v",
+      "map(parse_json(cast(id as string)), parse_json(cast(id as string))) as map_key_v"
+    )
+  }
+
+  testReadVariant("basic null variant") {
+    spark.range(0, 10, 1, 1).selectExpr(
+      "cast(null as variant) basic_v",
+      "named_struct('v', cast(null as variant)) as struct_v",
+      """array(
+        parse_json(cast(id as string)),
+        parse_json(cast(id as string)),
+        null
+      ) as array_v""",
+      "map('test', cast(null as variant)) as map_value_v",
+      "map(cast(null as variant), parse_json(cast(id as string))) as map_key_v",
+    )
+  }
+
+  // TODO(richardc-db): Add nested variant tests once `parse_json` expression is implemented.
 }
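For context, a minimal sketch of how the new `testReadVariant` helper is meant to be extended; the test name and column alias below are illustrative and not part of this commit. Each call writes the DataFrame to a Delta table through Spark, resolves the table's file location from `describe table extended`, and compares the Kernel default Parquet reader's output against Spark's via `checkAnswer`.

  // Hypothetical additional case built on the helper above; it uses only
  // Spark SQL expressions that already appear in these tests
  // (parse_json, named_struct, array).
  testReadVariant("variant inside an array of structs") {
    spark.range(0, 10, 1, 1).selectExpr(
      "array(named_struct('v', parse_json(cast(id as string)))) as array_struct_v"
    )
  }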

kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala

Lines changed: 10 additions & 3 deletions
@@ -18,6 +18,7 @@ package io.delta.kernel.defaults.utils
 import scala.collection.JavaConverters._
 import org.apache.spark.sql.{types => sparktypes}
 import org.apache.spark.sql.{Row => SparkRow}
+import org.apache.spark.unsafe.types.VariantVal
 import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row}
 import io.delta.kernel.types._
 
@@ -40,7 +41,7 @@ import java.time.LocalDate
  * - ArrayType --> Seq[Any]
  * - MapType --> Map[Any, Any]
  * - StructType --> TestRow
- *
+ * - VariantType --> VariantVal
  * For complex types array and map, the inner elements types should align with this mapping.
  */
 class TestRow(val values: Array[Any]) {
@@ -103,7 +104,9 @@ object TestRow {
       case _: ArrayType => arrayValueToScalaSeq(row.getArray(i))
       case _: MapType => mapValueToScalaMap(row.getMap(i))
       case _: StructType => TestRow(row.getStruct(i))
-      case _: VariantType => row.getVariant(i)
+      case _: VariantType =>
+        val kernelVariant = row.getVariant(i)
+        new VariantVal(kernelVariant.getValue(), kernelVariant.getMetadata())
       case _ => throw new UnsupportedOperationException("unrecognized data type")
     }
   }.toSeq)
@@ -134,6 +137,7 @@
         decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v)
       }
       case _: sparktypes.StructType => TestRow(obj.asInstanceOf[SparkRow])
+      case _: sparktypes.VariantType => obj.asInstanceOf[VariantVal]
       case _ => throw new UnsupportedOperationException("unrecognized data type")
     }
   }
@@ -164,7 +168,7 @@
         decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v)
       }
       case _: sparktypes.StructType => TestRow(row.getStruct(i))
-      case _: sparktypes.VariantType => row.getAs[Row](i)
+      case _: sparktypes.VariantType => row.getAs[VariantVal](i)
       case _ => throw new UnsupportedOperationException("unrecognized data type")
     }
   })
@@ -195,6 +199,9 @@
       TestRow.fromSeq(Seq.range(0, dataType.length()).map { ordinal =>
         getAsTestObject(vector.getChild(ordinal), rowId)
       })
+      case _: VariantType =>
+        val kernelVariant = vector.getVariant(rowId)
+        new VariantVal(kernelVariant.getValue(), kernelVariant.getMetadata())
       case _ => throw new UnsupportedOperationException("unrecognized data type")
     }
   }

kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala

Lines changed: 13 additions & 0 deletions
@@ -41,6 +41,7 @@ import org.apache.hadoop.shaded.org.apache.commons.io.FileUtils
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.{types => sparktypes}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.unsafe.types.VariantVal
 import org.scalatest.Assertions
 
 trait TestUtils extends Assertions with SQLHelper {
@@ -117,6 +118,17 @@
     lazy val classLoader: ClassLoader = ResourceLoader.getClass.getClassLoader
   }
 
+  /**
+   * Drops tables `tableNames` after calling `f`.
+   */
+  def withTable(tableNames: String*)(f: => Unit): Unit = {
+    try f finally {
+      tableNames.foreach { name =>
+        spark.sql(s"DROP TABLE IF EXISTS $name")
+      }
+    }
+  }
+
   def withGoldenTable(tableName: String)(testFunc: String => Unit): Unit = {
     val tablePath = GoldenTableUtils.goldenTablePath(tableName)
     testFunc(tablePath)
@@ -396,6 +408,7 @@
         java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
       case (a: Float, b: Float) =>
         java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
+      case (a: VariantVal, b: VariantVal) => a.debugString() == b.debugString()
       case (a, b) =>
         if (!a.equals(b)) {
           val sds = 200;
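The new `withTable` helper mirrors the utility of the same name in Spark's test suites: the body runs inside `try`, and every listed table is dropped in the `finally` block, so cleanup happens even if the body throws. A hypothetical usage sketch (table name illustrative):

  // Creates a table, runs an assertion, and is guaranteed to drop the table
  // afterwards, whether or not the assertion passes.
  withTable("tmp_variant_tbl") {
    spark.range(5).write.format("delta").saveAsTable("tmp_variant_tbl")
    assert(spark.table("tmp_variant_tbl").count() == 5)
  }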
