From 8c7131e0f36e08a1ea0f7243281cb4c9491dcccb Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 11 Mar 2024 22:20:22 -0700 Subject: [PATCH 01/16] minimal support --- .../apache/spark/sql/delta/TableFeature.scala | 11 +- .../spark/sql/delta/schema/SchemaUtils.scala | 8 ++ .../spark/sql/delta/util/PartitionUtils.scala | 2 +- .../spark/sql/delta/DeltaVariantSuite.scala | 112 ++++++++++++++++++ 4 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/TableFeature.scala b/spark/src/main/scala/org/apache/spark/sql/delta/TableFeature.scala index f58666d946a..edaf646b396 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/TableFeature.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/TableFeature.scala @@ -359,7 +359,8 @@ object TableFeature { // managed-commits are under development and only available in testing. ManagedCommitTableFeature, InCommitTimestampTableFeature, - TypeWideningTableFeature) + TypeWideningTableFeature, + VariantTypeTableFeature) } val featureMap = features.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap require(features.size == featureMap.size, "Lowercase feature names must not duplicate.") @@ -494,6 +495,14 @@ object IdentityColumnsTableFeature } } +object VariantTypeTableFeature extends ReaderWriterFeature(name = "variantType-dev") + with FeatureAutomaticallyEnabledByMetadata { + override def metadataRequiresFeatureToBeEnabled( + metadata: Metadata, spark: SparkSession): Boolean = { + SchemaUtils.checkForVariantTypeColumnsRecursively(metadata.schema) + } +} + object TimestampNTZTableFeature extends ReaderWriterFeature(name = "timestampNtz") with FeatureAutomaticallyEnabledByMetadata { override def metadataRequiresFeatureToBeEnabled( diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala index c8dd66af470..37d5ad5a2c8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala @@ -1263,6 +1263,13 @@ def normalizeColumnNamesInDataType( unsupportedDataTypes.toSeq } + /** + * Find VariantType columns in the table schema. + */ + def checkForVariantTypeColumnsRecursively(schema: StructType): Boolean = { + SchemaUtils.typeExistsRecursively(schema)(_.isInstanceOf[VariantType]) + } + /** * Find TimestampNTZ columns in the table schema. 
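   * Returns true if any field, at any level of nesting, uses the TimestampNTZ type.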
*/ @@ -1302,6 +1309,7 @@ def normalizeColumnNamesInDataType( case DateType => case TimestampType => case TimestampNTZType => + case VariantType => case BinaryType => case _: DecimalType => case a: ArrayType => diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/util/PartitionUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/util/PartitionUtils.scala index 6f870e30ca5..698f89904b3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/util/PartitionUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/util/PartitionUtils.scala @@ -605,7 +605,7 @@ private[delta] object PartitionUtils { partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach { field => field.dataType match { - case _: AtomicType => // OK + case a: AtomicType if !a.isInstanceOf[VariantType] => // OK case _ => throw DeltaErrors.cannotUseDataTypeForPartitionColumnError(field) } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala new file mode 100644 index 00000000000..1334bfa6923 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala @@ -0,0 +1,112 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta + +import org.apache.spark.SparkThrowable +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.delta.actions.{Protocol, TableFeatureProtocolUtils} +import org.apache.spark.sql.delta.test.DeltaSQLCommandTest +import org.apache.spark.sql.types.StructType + +class DeltaVariantSuite + extends QueryTest + with DeltaSQLCommandTest { + + private def getProtocolForTable(table: String): Protocol = { + val deltaLog = DeltaLog.forTable(spark, TableIdentifier(table)) + deltaLog.unsafeVolatileSnapshot.protocol + } + + test("create a new table with Variant, higher protocol and feature should be picked.") { + withTable("tbl") { + sql("CREATE TABLE tbl(s STRING, v VARIANT) USING DELTA") + sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))") + // TODO(r.chen): Enable once `parse_json` is properly implemented in OSS Spark. 
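+      // Note: `v::int` is shorthand for CAST(v AS INT); on a variant it extracts the
+      // stored JSON number, so the row written above should read back as 99.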
+ // assert(spark.table("tbl").selectExpr("v::int").head == Row(99)) + assert( + getProtocolForTable("tbl") == + VariantTypeTableFeature.minProtocolVersion.withFeature(VariantTypeTableFeature) + ) + } + } + + test("creating a table without Variant should use the usual minimum protocol") { + withTable("tbl") { + sql("CREATE TABLE tbl(s STRING, i INTEGER) USING DELTA") + assert(getProtocolForTable("tbl") == Protocol(1, 2)) + + val deltaLog = DeltaLog.forTable(spark, TableIdentifier("tbl")) + assert( + !deltaLog.unsafeVolatileSnapshot.protocol.isFeatureSupported(VariantTypeTableFeature), + s"Table tbl contains VariantTypeFeature descriptor when its not supposed to" + ) + } + } + + test("add a new Variant column should upgrade to the correct protocol versions") { + withTable("tbl") { + sql("CREATE TABLE tbl(s STRING) USING delta") + assert(getProtocolForTable("tbl") == Protocol(1, 2)) + + // Should throw error + val e = intercept[SparkThrowable] { + sql("ALTER TABLE tbl ADD COLUMN v VARIANT") + } + // capture the existing protocol here. + // we will check the error message later in this test as we need to compare the + // expected schema and protocol + val deltaLog = DeltaLog.forTable(spark, TableIdentifier("tbl")) + val currentProtocol = deltaLog.unsafeVolatileSnapshot.protocol + val currentFeatures = currentProtocol.implicitlyAndExplicitlySupportedFeatures + .map(_.name) + .toSeq + .sorted + .mkString(", ") + + // add table feature + sql( + s"ALTER TABLE tbl " + + s"SET TBLPROPERTIES('delta.feature.variantType-dev' = 'supported')" + ) + + sql("ALTER TABLE tbl ADD COLUMN v VARIANT") + + // check previously thrown error message + checkError( + e, + errorClass = "DELTA_FEATURES_REQUIRE_MANUAL_ENABLEMENT", + parameters = Map( + "unsupportedFeatures" -> VariantTypeTableFeature.name, + "supportedFeatures" -> currentFeatures + ) + ) + + sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))") + // TODO(r.chen): Enable once `parse_json` is properly implemented in OSS Spark. + // assert(spark.table("tbl").selectExpr("v::int").head == Row(99)) + + assert( + getProtocolForTable("tbl") == + VariantTypeTableFeature.minProtocolVersion + .withFeature(VariantTypeTableFeature) + .withFeature(InvariantsTableFeature) + .withFeature(AppendOnlyTableFeature) + ) + } + } +} From a84e8f4f03599b92ad4be47783653fd2340c5e66 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Tue, 19 Mar 2024 11:37:01 -0700 Subject: [PATCH 02/16] test --- .../scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala index 1334bfa6923..b19e29b52fa 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala @@ -36,7 +36,7 @@ class DeltaVariantSuite withTable("tbl") { sql("CREATE TABLE tbl(s STRING, v VARIANT) USING DELTA") sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))") - // TODO(r.chen): Enable once `parse_json` is properly implemented in OSS Spark. + // TODO(r.chen): Enable once variant casting is properly implemented in OSS Spark. 
// assert(spark.table("tbl").selectExpr("v::int").head == Row(99)) assert( getProtocolForTable("tbl") == @@ -97,7 +97,7 @@ class DeltaVariantSuite ) sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))") - // TODO(r.chen): Enable once `parse_json` is properly implemented in OSS Spark. + // TODO(r.chen): Enable once variant casting is properly implemented in OSS Spark. // assert(spark.table("tbl").selectExpr("v::int").head == Row(99)) assert( From 5c01b9a54cc6afed858e05833b357c86e3e5fc2f Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Thu, 11 Apr 2024 22:19:25 -0700 Subject: [PATCH 03/16] cross compile --- build.sbt | 2 + .../scala-spark-3.5/shims/VariantShim.scala | 26 +++++ .../shims/VariantShim.scala | 23 ++++ .../spark/sql/delta/schema/SchemaUtils.scala | 4 +- .../spark/sql/delta/util/PartitionUtils.scala | 2 +- .../shims/DeltaVariantSparkOnlyTests.scala | 19 +++ .../shims/DeltaVariantSparkOnlyTests.scala | 109 ++++++++++++++++++ .../spark/sql/delta/DeltaVariantSuite.scala | 95 +-------------- 8 files changed, 183 insertions(+), 97 deletions(-) create mode 100644 spark/src/main/scala-spark-3.5/shims/VariantShim.scala create mode 100644 spark/src/main/scala-spark-master/shims/VariantShim.scala create mode 100644 spark/src/test/scala-spark-3.5/shims/DeltaVariantSparkOnlyTests.scala create mode 100644 spark/src/test/scala-spark-master/shims/DeltaVariantSparkOnlyTests.scala diff --git a/build.sbt b/build.sbt index 508cc4b5dfc..3e79b965754 100644 --- a/build.sbt +++ b/build.sbt @@ -138,6 +138,7 @@ def crossSparkSettings(): Seq[Setting[_]] = getSparkVersion() match { // For adding staged Spark RC versions, e.g.: // resolvers += "Apache Spark 3.5.0 (RC1) Staging" at "https://repository.apache.org/content/repositories/orgapachespark-1444/", Compile / unmanagedSourceDirectories += (Compile / baseDirectory).value / "src" / "main" / "scala-spark-3.5", + Test / unmanagedSourceDirectories += (Compile / baseDirectory).value / "src" / "test" / "scala-spark-3.5", Antlr4 / antlr4Version := "4.9.3", // Java-/Scala-/Uni-Doc Settings @@ -153,6 +154,7 @@ def crossSparkSettings(): Seq[Setting[_]] = getSparkVersion() match { targetJvm := "17", resolvers += "Spark master staging" at "https://repository.apache.org/content/groups/snapshots/", Compile / unmanagedSourceDirectories += (Compile / baseDirectory).value / "src" / "main" / "scala-spark-master", + Test / unmanagedSourceDirectories += (Compile / baseDirectory).value / "src" / "test" / "scala-spark-master", Antlr4 / antlr4Version := "4.13.1", Test / javaOptions ++= Seq( // Copied from SparkBuild.scala to support Java 17 for unit tests (see apache/spark#34153) diff --git a/spark/src/main/scala-spark-3.5/shims/VariantShim.scala b/spark/src/main/scala-spark-3.5/shims/VariantShim.scala new file mode 100644 index 00000000000..d802146eba1 --- /dev/null +++ b/spark/src/main/scala-spark-3.5/shims/VariantShim.scala @@ -0,0 +1,26 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +object VariantShim { + + /** + * Spark's variant type is implemented for Spark 4.0 and is not implemented in Spark 3.5. Thus, + * any Spark 3.5 DataType cannot be a variant type. + */ + def isTypeVariant(dt: DataType): Boolean = false +} diff --git a/spark/src/main/scala-spark-master/shims/VariantShim.scala b/spark/src/main/scala-spark-master/shims/VariantShim.scala new file mode 100644 index 00000000000..7853eb593c6 --- /dev/null +++ b/spark/src/main/scala-spark-master/shims/VariantShim.scala @@ -0,0 +1,23 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +object VariantShim { + + /** Spark's variant type is only implemented in Spark 4.0 and above.*/ + def isTypeVariant(dt: DataType): Boolean = dt.isInstanceOf[VariantType] +} diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala index 37d5ad5a2c8..dde29afbbd9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala @@ -1267,7 +1267,7 @@ def normalizeColumnNamesInDataType( * Find VariantType columns in the table schema. 
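   * Returns true when any field, however deeply nested in structs, arrays or maps,
   * is a variant; the result decides whether the variant table feature is required.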
*/ def checkForVariantTypeColumnsRecursively(schema: StructType): Boolean = { - SchemaUtils.typeExistsRecursively(schema)(_.isInstanceOf[VariantType]) + SchemaUtils.typeExistsRecursively(schema)(VariantShim.isTypeVariant(_)) } /** @@ -1309,7 +1309,7 @@ def normalizeColumnNamesInDataType( case DateType => case TimestampType => case TimestampNTZType => - case VariantType => + case dt if VariantShim.isTypeVariant(dt) => case BinaryType => case _: DecimalType => case a: ArrayType => diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/util/PartitionUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/util/PartitionUtils.scala index 698f89904b3..87f72fc1a59 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/util/PartitionUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/util/PartitionUtils.scala @@ -605,7 +605,7 @@ private[delta] object PartitionUtils { partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach { field => field.dataType match { - case a: AtomicType if !a.isInstanceOf[VariantType] => // OK + case a: AtomicType if !VariantShim.isTypeVariant(a) => // OK case _ => throw DeltaErrors.cannotUseDataTypeForPartitionColumnError(field) } } diff --git a/spark/src/test/scala-spark-3.5/shims/DeltaVariantSparkOnlyTests.scala b/spark/src/test/scala-spark-3.5/shims/DeltaVariantSparkOnlyTests.scala new file mode 100644 index 00000000000..a7edc237bca --- /dev/null +++ b/spark/src/test/scala-spark-3.5/shims/DeltaVariantSparkOnlyTests.scala @@ -0,0 +1,19 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta + +trait DeltaVariantSparkOnlyTests { self: DeltaVariantSuite => } diff --git a/spark/src/test/scala-spark-master/shims/DeltaVariantSparkOnlyTests.scala b/spark/src/test/scala-spark-master/shims/DeltaVariantSparkOnlyTests.scala new file mode 100644 index 00000000000..09ddabbee48 --- /dev/null +++ b/spark/src/test/scala-spark-master/shims/DeltaVariantSparkOnlyTests.scala @@ -0,0 +1,109 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.delta
+
+import org.apache.spark.SparkThrowable
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.delta.actions.{Protocol, TableFeatureProtocolUtils}
+import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
+import org.apache.spark.sql.types.StructType
+
+trait DeltaVariantSparkOnlyTests
+  extends QueryTest
+  with DeltaSQLCommandTest { self: DeltaVariantSuite =>
+  private def getProtocolForTable(table: String): Protocol = {
+    val deltaLog = DeltaLog.forTable(spark, TableIdentifier(table))
+    deltaLog.unsafeVolatileSnapshot.protocol
+  }
+
+  test("create a new table with Variant, higher protocol and feature should be picked.") {
+    withTable("tbl") {
+      sql("CREATE TABLE tbl(s STRING, v VARIANT) USING DELTA")
+      sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))")
+      assert(spark.table("tbl").selectExpr("v::int").head == Row(99))
+      assert(
+        getProtocolForTable("tbl") ==
+        VariantTypeTableFeature.minProtocolVersion.withFeature(VariantTypeTableFeature)
+      )
+    }
+  }
+
+  test("creating a table without Variant should use the usual minimum protocol") {
+    withTable("tbl") {
+      sql("CREATE TABLE tbl(s STRING, i INTEGER) USING DELTA")
+      assert(getProtocolForTable("tbl") == Protocol(1, 2))
+
+      val deltaLog = DeltaLog.forTable(spark, TableIdentifier("tbl"))
+      assert(
+        !deltaLog.unsafeVolatileSnapshot.protocol.isFeatureSupported(VariantTypeTableFeature),
+        s"Table tbl contains VariantTypeFeature descriptor when it's not supposed to"
+      )
+    }
+  }
+
+  test("add a new Variant column should upgrade to the correct protocol versions") {
+    withTable("tbl") {
+      sql("CREATE TABLE tbl(s STRING) USING delta")
+      assert(getProtocolForTable("tbl") == Protocol(1, 2))
+
+      // Adding a variant column before the table feature is enabled should throw.
+      val e = intercept[SparkThrowable] {
+        sql("ALTER TABLE tbl ADD COLUMN v VARIANT")
+      }
+      // capture the existing protocol here.
+ // we will check the error message later in this test as we need to compare the + // expected schema and protocol + val deltaLog = DeltaLog.forTable(spark, TableIdentifier("tbl")) + val currentProtocol = deltaLog.unsafeVolatileSnapshot.protocol + val currentFeatures = currentProtocol.implicitlyAndExplicitlySupportedFeatures + .map(_.name) + .toSeq + .sorted + .mkString(", ") + + // add table feature + sql( + s"ALTER TABLE tbl " + + s"SET TBLPROPERTIES('delta.feature.variantType-dev' = 'supported')" + ) + + sql("ALTER TABLE tbl ADD COLUMN v VARIANT") + + // check previously thrown error message + checkError( + e, + errorClass = "DELTA_FEATURES_REQUIRE_MANUAL_ENABLEMENT", + parameters = Map( + "unsupportedFeatures" -> VariantTypeTableFeature.name, + "supportedFeatures" -> currentFeatures + ) + ) + + sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))") + assert(spark.table("tbl").selectExpr("v::int").head == Row(99)) + + assert( + getProtocolForTable("tbl") == + VariantTypeTableFeature.minProtocolVersion + .withFeature(VariantTypeTableFeature) + .withFeature(InvariantsTableFeature) + .withFeature(AppendOnlyTableFeature) + ) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala index b19e29b52fa..6a95026d335 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala @@ -16,97 +16,4 @@ package org.apache.spark.sql.delta -import org.apache.spark.SparkThrowable -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.delta.actions.{Protocol, TableFeatureProtocolUtils} -import org.apache.spark.sql.delta.test.DeltaSQLCommandTest -import org.apache.spark.sql.types.StructType - -class DeltaVariantSuite - extends QueryTest - with DeltaSQLCommandTest { - - private def getProtocolForTable(table: String): Protocol = { - val deltaLog = DeltaLog.forTable(spark, TableIdentifier(table)) - deltaLog.unsafeVolatileSnapshot.protocol - } - - test("create a new table with Variant, higher protocol and feature should be picked.") { - withTable("tbl") { - sql("CREATE TABLE tbl(s STRING, v VARIANT) USING DELTA") - sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))") - // TODO(r.chen): Enable once variant casting is properly implemented in OSS Spark. 
- // assert(spark.table("tbl").selectExpr("v::int").head == Row(99)) - assert( - getProtocolForTable("tbl") == - VariantTypeTableFeature.minProtocolVersion.withFeature(VariantTypeTableFeature) - ) - } - } - - test("creating a table without Variant should use the usual minimum protocol") { - withTable("tbl") { - sql("CREATE TABLE tbl(s STRING, i INTEGER) USING DELTA") - assert(getProtocolForTable("tbl") == Protocol(1, 2)) - - val deltaLog = DeltaLog.forTable(spark, TableIdentifier("tbl")) - assert( - !deltaLog.unsafeVolatileSnapshot.protocol.isFeatureSupported(VariantTypeTableFeature), - s"Table tbl contains VariantTypeFeature descriptor when its not supposed to" - ) - } - } - - test("add a new Variant column should upgrade to the correct protocol versions") { - withTable("tbl") { - sql("CREATE TABLE tbl(s STRING) USING delta") - assert(getProtocolForTable("tbl") == Protocol(1, 2)) - - // Should throw error - val e = intercept[SparkThrowable] { - sql("ALTER TABLE tbl ADD COLUMN v VARIANT") - } - // capture the existing protocol here. - // we will check the error message later in this test as we need to compare the - // expected schema and protocol - val deltaLog = DeltaLog.forTable(spark, TableIdentifier("tbl")) - val currentProtocol = deltaLog.unsafeVolatileSnapshot.protocol - val currentFeatures = currentProtocol.implicitlyAndExplicitlySupportedFeatures - .map(_.name) - .toSeq - .sorted - .mkString(", ") - - // add table feature - sql( - s"ALTER TABLE tbl " + - s"SET TBLPROPERTIES('delta.feature.variantType-dev' = 'supported')" - ) - - sql("ALTER TABLE tbl ADD COLUMN v VARIANT") - - // check previously thrown error message - checkError( - e, - errorClass = "DELTA_FEATURES_REQUIRE_MANUAL_ENABLEMENT", - parameters = Map( - "unsupportedFeatures" -> VariantTypeTableFeature.name, - "supportedFeatures" -> currentFeatures - ) - ) - - sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))") - // TODO(r.chen): Enable once variant casting is properly implemented in OSS Spark. - // assert(spark.table("tbl").selectExpr("v::int").head == Row(99)) - - assert( - getProtocolForTable("tbl") == - VariantTypeTableFeature.minProtocolVersion - .withFeature(VariantTypeTableFeature) - .withFeature(InvariantsTableFeature) - .withFeature(AppendOnlyTableFeature) - ) - } - } -} +class DeltaVariantSuite extends DeltaVariantSparkOnlyTests From 265e886ec3ba7cd1f34c98246a5a9bba1415eeb0 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Thu, 11 Apr 2024 22:28:47 -0700 Subject: [PATCH 04/16] style --- spark/src/main/scala-spark-master/shims/VariantShim.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala-spark-master/shims/VariantShim.scala b/spark/src/main/scala-spark-master/shims/VariantShim.scala index 7853eb593c6..63285b58466 100644 --- a/spark/src/main/scala-spark-master/shims/VariantShim.scala +++ b/spark/src/main/scala-spark-master/shims/VariantShim.scala @@ -18,6 +18,6 @@ package org.apache.spark.sql.types object VariantShim { - /** Spark's variant type is only implemented in Spark 4.0 and above.*/ + /** Spark's variant type is only implemented in Spark 4.0 and above. 
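+   * The matching Spark 3.5 shim always returns false, so shared code can probe for
+   * variant columns without ever referencing [[VariantType]] on a 3.5 classpath.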
*/ def isTypeVariant(dt: DataType): Boolean = dt.isInstanceOf[VariantType] } From 7fddedfaef115fef00b0ac0637e117a27f8a32be Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Tue, 19 Mar 2024 18:55:51 -0700 Subject: [PATCH 05/16] fix compile --- .../io/delta/kernel/internal/SnapshotManagerSuite.scala | 4 ++-- .../io/delta/kernel/defaults/DeltaTableReadsSuite.scala | 4 ++-- .../io/delta/kernel/defaults/LogReplayMetricsSuite.scala | 2 +- .../scala/io/delta/kernel/defaults/utils/TestRow.scala | 2 +- .../scala/io/delta/kernel/defaults/utils/TestUtils.scala | 8 ++++---- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala index 71c2a256eee..6afe5a72939 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala @@ -617,7 +617,7 @@ class SnapshotManagerSuite extends AnyFunSuite with MockFileSystemClientUtils { // corrupt incomplete multi-part checkpoint val corruptedCheckpointStatuses = FileNames.checkpointFileWithParts(logPath, 10, 5).asScala .map(p => FileStatus.of(p.toString, 10, 10)) - .take(4) + .take(4).toSeq val deltas = deltaFileStatuses(10L to 13L) testExpectedError[RuntimeException]( corruptedCheckpointStatuses ++ deltas, @@ -666,7 +666,7 @@ class SnapshotManagerSuite extends AnyFunSuite with MockFileSystemClientUtils { // _last_checkpoint refers to incomplete multi-part checkpoint val corruptedCheckpointStatuses = FileNames.checkpointFileWithParts(logPath, 20, 5).asScala .map(p => FileStatus.of(p.toString, 10, 10)) - .take(4) + .take(4).toSeq testExpectedError[RuntimeException]( files = corruptedCheckpointStatuses ++ deltaFileStatuses(10L to 20L) ++ singularCheckpointFileStatuses(Seq(10L)), diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala index 73017f4eb2a..d58bc5149c2 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala @@ -296,12 +296,12 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils { Seq(TestRow(2), TestRow(2), TestRow(2)), TestRow("2", "2", TestRow(2, 2L)), "2" - ) :: Nil) + ) :: Nil).toSeq checkTable( path = path, expectedAnswer = expectedAnswer, - readCols = readCols + readCols = readCols.toSeq ) } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayMetricsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayMetricsSuite.scala index 51193d3ac90..f82eefe7b65 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayMetricsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayMetricsSuite.scala @@ -342,7 +342,7 @@ trait FileReadMetrics { self: Object => } } - def getVersionsRead: Seq[Long] = versionsRead + def getVersionsRead: Seq[Long] = versionsRead.toSeq def resetMetrics(): Unit = { versionsRead.clear() diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index 661d286a3c9..d63ffa8f8eb 100644 --- 
a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -110,7 +110,7 @@ object TestRow { case _: StructType => TestRow(row.getStruct(i)) case _ => throw new UnsupportedOperationException("unrecognized data type") } - }) + }.toSeq) } def apply(row: SparkRow): TestRow = { diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index a3c9576aeb7..92fdd02a2e1 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -73,7 +73,7 @@ trait TestUtils extends Assertions with SQLHelper { while (iter.hasNext) { result.append(iter.next()) } - result + result.toSeq } finally { iter.close() } @@ -153,7 +153,7 @@ trait TestUtils extends Assertions with SQLHelper { // for all primitive types Seq(new Column((basePath :+ field.getName).asJava.toArray(new Array[String](0)))); case _ => Seq.empty - } + }.toSeq } def collectScanFileRows(scan: Scan, tableClient: TableClient = defaultTableClient): Seq[Row] = { @@ -231,7 +231,7 @@ trait TestUtils extends Assertions with SQLHelper { } } } - result + result.toSeq } /** @@ -626,7 +626,7 @@ trait TestUtils extends Assertions with SQLHelper { toSparkType(field.getDataType), field.isNullable ) - }) + }.toSeq) } } From d24a0b158ceb6fa9821c20e056d8e07deac36d79 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Thu, 11 Apr 2024 23:38:01 -0700 Subject: [PATCH 06/16] cross compile spark test dep for kernel defaults --- build.sbt | 53 +++++++++++-------- .../delta/kernel/defaults/utils/TestRow.scala | 3 +- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/build.sbt b/build.sbt index 3e79b965754..a2f5b91eaa5 100644 --- a/build.sbt +++ b/build.sbt @@ -38,6 +38,8 @@ val LATEST_RELEASED_SPARK_VERSION = "3.5.0" val SPARK_MASTER_VERSION = "4.0.0-SNAPSHOT" val sparkVersion = settingKey[String]("Spark version") spark / sparkVersion := getSparkVersion() +kernelDefaults / sparkVersion := getSparkVersion() +goldenTables / sparkVersion := getSparkVersion() // Dependent library versions val defaultSparkVersion = LATEST_RELEASED_SPARK_VERSION @@ -126,6 +128,25 @@ lazy val commonSettings = Seq( unidocSourceFilePatterns := Nil, ) +/** + * Java-/Scala-/Uni-Doc settings aren't working yet against Spark Master. + 1) delta-spark on Spark Master uses JDK 17. delta-iceberg uses JDK 8 or 11. For some reason, + generating delta-spark unidoc compiles delta-iceberg + 2) delta-spark unidoc fails to compile. spark 3.5 is on its classpath. likely due to iceberg + issue above. + */ +def crossSparkProjectSettings(): Seq[Setting[_]] = getSparkVersion() match { + case LATEST_RELEASED_SPARK_VERSION => Seq( + // Java-/Scala-/Uni-Doc Settings + scalacOptions ++= Seq( + "-P:genjavadoc:strictVisibility=true" // hide package private types and methods in javadoc + ), + unidocSourceFilePatterns := Seq(SourceFilePattern("io/delta/tables/", "io/delta/exceptions/")) + ) + + case SPARK_MASTER_VERSION => Seq() +} + /** * Note: we cannot access sparkVersion.value here, since that can only be used within a task or * setting macro. 
@@ -140,12 +161,6 @@ def crossSparkSettings(): Seq[Setting[_]] = getSparkVersion() match { Compile / unmanagedSourceDirectories += (Compile / baseDirectory).value / "src" / "main" / "scala-spark-3.5", Test / unmanagedSourceDirectories += (Compile / baseDirectory).value / "src" / "test" / "scala-spark-3.5", Antlr4 / antlr4Version := "4.9.3", - - // Java-/Scala-/Uni-Doc Settings - scalacOptions ++= Seq( - "-P:genjavadoc:strictVisibility=true" // hide package private types and methods in javadoc - ), - unidocSourceFilePatterns := Seq(SourceFilePattern("io/delta/tables/", "io/delta/exceptions/")) ) case SPARK_MASTER_VERSION => Seq( @@ -170,13 +185,6 @@ def crossSparkSettings(): Seq[Setting[_]] = getSparkVersion() match { "--add-opens=java.base/sun.security.action=ALL-UNNAMED", "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED" ) - - // Java-/Scala-/Uni-Doc Settings - // This isn't working yet against Spark Master. - // 1) delta-spark on Spark Master uses JDK 17. delta-iceberg uses JDK 8 or 11. For some reason, - // generating delta-spark unidoc compiles delta-iceberg - // 2) delta-spark unidoc fails to compile. spark 3.5 is on its classpath. likely due to iceberg - // issue above. ) } @@ -190,6 +198,7 @@ lazy val spark = (project in file("spark")) sparkMimaSettings, releaseSettings, crossSparkSettings(), + crossSparkProjectSettings(), libraryDependencies ++= Seq( // Adding test classifier seems to break transitive resolution of the core dependencies "org.apache.spark" %% "spark-hive" % sparkVersion.value % "provided", @@ -357,6 +366,7 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults")) scalaStyleSettings, javaOnlyReleaseSettings, Test / javaOptions ++= Seq("-ea"), + crossSparkSettings(), libraryDependencies ++= Seq( "org.apache.hadoop" % "hadoop-client-runtime" % hadoopVersion, "com.fasterxml.jackson.core" % "jackson-databind" % "2.13.5", @@ -373,10 +383,10 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults")) "org.openjdk.jmh" % "jmh-core" % "1.37" % "test", "org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" % "test", - "org.apache.spark" %% "spark-hive" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-core" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-catalyst" % defaultSparkVersion % "test" classifier "tests", + "org.apache.spark" %% "spark-hive" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-core" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-catalyst" % sparkVersion.value % "test" classifier "tests", ), javaCheckstyleSettings("kernel/dev/checkstyle.xml"), // Unidoc settings @@ -1071,14 +1081,15 @@ lazy val goldenTables = (project in file("connectors/golden-tables")) name := "golden-tables", commonSettings, skipReleaseSettings, + crossSparkSettings(), libraryDependencies ++= Seq( // Test Dependencies "org.scalatest" %% "scalatest" % scalaTestVersion % "test", "commons-io" % "commons-io" % "2.8.0" % "test", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test", - "org.apache.spark" %% "spark-catalyst" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-core" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test" classifier "tests" + 
"org.apache.spark" %% "spark-sql" % sparkVersion.value % "test", + "org.apache.spark" %% "spark-catalyst" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-core" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "test" classifier "tests" ) ) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index d63ffa8f8eb..4a7da88d842 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -16,6 +16,7 @@ package io.delta.kernel.defaults.utils import scala.collection.JavaConverters._ +import scala.collection.mutable.{Seq => MutableSeq} import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.{Row => SparkRow} import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} @@ -133,7 +134,7 @@ object TestRow { case _: sparktypes.BinaryType => obj.asInstanceOf[Array[Byte]] case _: sparktypes.DecimalType => obj.asInstanceOf[java.math.BigDecimal] case arrayType: sparktypes.ArrayType => - obj.asInstanceOf[Seq[Any]] + obj.asInstanceOf[MutableSeq[Any]] .map(decodeCellValue(arrayType.elementType, _)) case mapType: sparktypes.MapType => obj.asInstanceOf[Map[Any, Any]].map { case (k, v) => From f6aeb6b51bf8dd1acd5b294a2604b93f0b44a495 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 11 Mar 2024 23:14:53 -0700 Subject: [PATCH 07/16] init no tests yet --- .../io/delta/kernel/data/ColumnVector.java | 8 ++ .../main/java/io/delta/kernel/data/Row.java | 6 + .../io/delta/kernel/data/VariantValue.java | 25 ++++ .../delta/kernel/internal/TableFeatures.java | 1 + .../internal/data/ChildVectorBasedRow.java | 6 + .../kernel/internal/data/GenericRow.java | 8 ++ .../kernel/internal/util/VectorUtils.java | 2 + .../delta/kernel/types/BasePrimitiveType.java | 1 + .../io/delta/kernel/types/VariantType.java | 33 +++++ .../internal/data/DefaultJsonRow.java | 6 + .../data/vector/AbstractColumnVector.java | 6 + .../data/vector/DefaultGenericVector.java | 7 + .../data/vector/DefaultSubFieldVector.java | 7 + .../data/vector/DefaultVariantVector.java | 93 ++++++++++++++ .../data/vector/DefaultViewVector.java | 7 + .../parquet/ParquetColumnReaders.java | 2 + .../internal/parquet/ParquetSchemaUtils.java | 11 ++ .../internal/parquet/VariantConverter.java | 120 ++++++++++++++++++ .../integration/DataBuilderUtils.java | 6 + .../delta/kernel/defaults/utils/TestRow.scala | 2 + .../kernel/defaults/utils/TestUtils.scala | 1 + 21 files changed, 358 insertions(+) create mode 100644 kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java create mode 100644 kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java create mode 100644 kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java create mode 100644 kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java index 55c639a0161..2117794462c 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java @@ -175,6 +175,14 @@ default ArrayValue 
getArray(int rowId) { throw new UnsupportedOperationException("Invalid value request for data type"); } + /** + * Return the variant value located at {@code rowId}. Returns null if the slot for {@code rowId} + * is null + */ + default VariantValue getVariant(int rowId) { + throw new UnsupportedOperationException("Invalid value request for data type"); + } + /** * Get the child vector associated with the given ordinal. This method is applicable only to the * {@code struct} type columns. diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java index adcacbc0f4c..560f3113d7b 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java @@ -117,4 +117,10 @@ public interface Row { * Throws error if the column at given ordinal is not of map type, */ MapValue getMap(int ordinal); + + /** + * Return variant value of the column located at the given ordinal. + * Throws error if the column at given ordinal is not of variant type. + */ + VariantValue getVariant(int ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java new file mode 100644 index 00000000000..abf57d54502 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java @@ -0,0 +1,25 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.data; + +/** + * Abstraction to represent a single Variant value in a {@link ColumnVector}. 
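+ * A variant carries semi-structured data as two binary components: {@code value}, the
+ * encoded datum, and {@code metadata}, the dictionary of field names it references.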
+ */ +public interface VariantValue { + byte[] getValue(); + + byte[] getMetadata(); +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java index 8cc350170f6..2bc107aa247 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java @@ -46,6 +46,7 @@ public static void validateReadSupportedTable(Protocol protocol, Metadata metada break; case "deletionVectors": // fall through case "timestampNtz": // fall through + case "variantType-dev": case "vacuumProtocolCheck": break; default: diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java index ae4793fa479..74e76f979b0 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java @@ -21,6 +21,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.StructType; /** @@ -111,5 +112,10 @@ public MapValue getMap(int ordinal) { return getChild(ordinal).getMap(rowId); } + @Override + public VariantValue getVariant(int ordinal) { + return getChild(ordinal).getVariant(rowId); + } + protected abstract ColumnVector getChild(int ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java index 01a12bb84de..aabf6a56b08 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java @@ -23,6 +23,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.*; /** @@ -134,6 +135,13 @@ public MapValue getMap(int ordinal) { return (MapValue) getValue(ordinal); } + @Override + public VariantValue getVariant(int ordinal) { + // TODO(r.chen): test this path somehow? 
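+        // The declared-type check below guards the unchecked cast that follows.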
+ throwIfUnsafeAccess(ordinal, VariantType.class, "variant"); + return (VariantValue) getValue(ordinal); + } + private Object getValue(int ordinal) { return ordinalToValue.get(ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 70becebb42c..2269fd656b2 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -110,6 +110,8 @@ private static Object getValueAsObject( return toJavaList(columnVector.getArray(rowId)); } else if (dataType instanceof MapType) { return toJavaMap(columnVector.getMap(rowId)); + } else if (dataType instanceof VariantType) { + return columnVector.getVariant(rowId); } else { throw new UnsupportedOperationException("unsupported data type"); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java index 33affe3d8aa..6141d84a139 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java @@ -64,6 +64,7 @@ public static List getAllPrimitiveTypes() { put("timestamp_ntz", TimestampNTZType.TIMESTAMP_NTZ); put("binary", BinaryType.BINARY); put("string", StringType.STRING); + put("variant", VariantType.VARIANT); } }); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java new file mode 100644 index 00000000000..81dbb072c97 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java @@ -0,0 +1,33 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.types; + +import io.delta.kernel.annotation.Evolving; + +/** + * A variant type. + *
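+ * Represents semi-structured values that are carried as a pair of binaries (an encoded
+ * value plus its metadata dictionary) rather than as a parsed tree.
+ *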

+ * todo: more comments + * @since 4.0.0 + */ +@Evolving +public class VariantType extends BasePrimitiveType { + public static final VariantType VARIANT = new VariantType(); + + private VariantType() { + super("variant"); + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java index bbada452c0a..8f60b3848ce 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java @@ -33,6 +33,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.*; import io.delta.kernel.internal.util.InternalUtils; @@ -128,6 +129,11 @@ public MapValue getMap(int ordinal) { return (MapValue) parsedValues[ordinal]; } + @Override + public VariantValue getVariant(int ordinal) { + throw new UnsupportedOperationException("not yet implemented"); + } + private static void throwIfTypeMismatch(String expType, boolean hasExpType, JsonNode jsonNode) { if (!hasExpType) { throw new RuntimeException( diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java index 8196b93a178..55138cecd09 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java @@ -22,6 +22,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -138,6 +139,11 @@ public ArrayValue getArray(int rowId) { throw unsupportedDataAccessException("array"); } + @Override + public VariantValue getVariant(int rowId) { + throw unsupportedDataAccessException("variant"); + } + // TODO no need to override these here; update default implementations in `ColumnVector` // to have a more informative exception message protected UnsupportedOperationException unsupportedDataAccessException(String accessType) { diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java index aad449d1eb0..925a11746db 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java @@ -172,6 +172,13 @@ public ColumnVector getChild(int ordinal) { (rowId) -> (Row) rowIdToValueAccessor.apply(rowId)); } + @Override + public VariantValue getVariant(int rowId) { + assertValidRowId(rowId); + throwIfUnsafeAccess(VariantType.class, "variant"); + return (VariantValue) rowIdToValueAccessor.apply(rowId); + } + private void throwIfUnsafeAccess( Class expDataType, String accessType) { if (!expDataType.isAssignableFrom(dataType.getClass())) { String msg = String.format( diff 
--git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java
index a0aac0f12b1..1656d572f8a 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java
@@ -23,6 +23,7 @@
 import io.delta.kernel.data.ColumnVector;
 import io.delta.kernel.data.MapValue;
 import io.delta.kernel.data.Row;
+import io.delta.kernel.data.VariantValue;
 import io.delta.kernel.types.DataType;
 import io.delta.kernel.types.StructField;
 import io.delta.kernel.types.StructType;
@@ -155,6 +156,12 @@ public ArrayValue getArray(int rowId) {
         return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal);
     }
 
+    @Override
+    public VariantValue getVariant(int rowId) {
+        assertValidRowId(rowId);
+        return rowIdToRowAccessor.apply(rowId).getVariant(columnOrdinal);
+    }
+
     @Override
     public ColumnVector getChild(int childOrdinal) {
         StructType structType = (StructType) dataType;
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java
new file mode 100644
index 00000000000..4f352c158d3
--- /dev/null
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright (2024) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.delta.kernel.defaults.internal.data.vector;
+
+import java.util.Optional;
+import static java.util.Objects.requireNonNull;
+
+import io.delta.kernel.data.ColumnVector;
+import io.delta.kernel.data.VariantValue;
+import io.delta.kernel.types.DataType;
+
+/**
+ * {@link io.delta.kernel.data.ColumnVector} implementation for variant type data.
+ */
+public class DefaultVariantVector
+    extends AbstractColumnVector {
+    private final ColumnVector valueVector;
+    private final ColumnVector metadataVector;
+
+
+    /**
+     * Create an instance of {@link io.delta.kernel.data.ColumnVector} for variant type.
+     *
+     * @param size number of elements in the vector.
+     * @param type the data type of the vector, always {@code VariantType.VARIANT}.
+     * @param nullability Optional array of nullability value for each element in the vector.
+     *                    All values in the vector are considered non-null when parameter is
+     *                    empty.
+     * @param value Vector containing the encoded variant values.
+     * @param metadata Vector containing the variant metadata dictionaries.
+ */ + public DefaultVariantVector( + int size, + DataType type, + Optional nullability, + ColumnVector value, + ColumnVector metadata) { + super(size, type, nullability); + // checkArgument(offsets.length >= size + 1, "invalid offset array size"); + this.valueVector = requireNonNull(value, "value is null"); + this.metadataVector = requireNonNull(metadata, "metadata is null"); + } + + /** + * Get the value at given {@code rowId}. The return value is undefined and can be + * anything, if the slot for {@code rowId} is null. + * + * @param rowId + * @return + */ + @Override + public VariantValue getVariant(int rowId) { + checkValidRowId(rowId); + if (isNullAt(rowId)) { + return null; + } + return new VariantValue() { + private final byte[] value = valueVector.getBinary(rowId); + private final byte[] metadata = metadataVector.getBinary(rowId); + + @Override + public byte[] getValue() { + return value; + } + + @Override + public byte[] getMetadata() { + return metadata; + } + }; + } + + public ColumnVector getValueVector() { + return valueVector; + } + + public ColumnVector getMetadataVector() { + return metadataVector; + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java index 49c1fe00a2b..f256c6300c9 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java @@ -20,6 +20,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -137,6 +138,12 @@ public ArrayValue getArray(int rowId) { return underlyingVector.getArray(offset + rowId); } + @Override + public VariantValue getVariant(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getVariant(offset + rowId); + } + @Override public ColumnVector getChild(int ordinal) { return new DefaultViewVector(underlyingVector.getChild(ordinal), offset, offset + size); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java index dce7c5244c3..516cbef0022 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java @@ -82,6 +82,8 @@ public static Converter createConverter( return createTimestampConverter(initialBatchSize, typeFromFile); } else if (typeFromClient instanceof TimestampNTZType) { return createTimestampNtzConverter(initialBatchSize, typeFromFile); + } else if (typeFromClient instanceof VariantType) { + return new VariantConverter(initialBatchSize); } throw new UnsupportedOperationException(typeFromClient + " is not supported"); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java index a86ee4323d0..a9f22e76967 100644 --- 
a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java @@ -235,6 +235,8 @@ private static Type toParquetType( type = toParquetMapType((MapType) dataType, name, repetition); } else if (dataType instanceof StructType) { type = toParquetStructType((StructType) dataType, name, repetition); + } else if (dataType instanceof VariantType) { + type = toParquetVariantType((VariantType) dataType, name, repetition); } else { throw new UnsupportedOperationException( "Writing given type data to Parquet is not supported: " + dataType); @@ -296,6 +298,15 @@ private static Type toParquetStructType(StructType structType, String name, return new GroupType(repetition, name, fields); } + private static Type toParquetVariantType(VariantType structType, String name, + Repetition repetition) { + List fields = Arrays.asList( + toParquetType(BinaryType.BINARY, "value", REQUIRED, Optional.empty()), + toParquetType(BinaryType.BINARY, "metadata", REQUIRED, Optional.empty()) + ); + return new GroupType(repetition, name, fields); + } + /** * Recursively checks whether the given data type has any Parquet field ids in it. */ diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java new file mode 100644 index 00000000000..dc4b61d28be --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java @@ -0,0 +1,120 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.parquet; + +import java.util.*; + +import org.apache.parquet.io.api.Converter; +import org.apache.parquet.io.api.GroupConverter; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.types.BinaryType; +import io.delta.kernel.types.VariantType; +import static io.delta.kernel.internal.util.Preconditions.checkArgument; + +import io.delta.kernel.defaults.internal.data.vector.DefaultVariantVector; +import io.delta.kernel.defaults.internal.parquet.ParquetConverters.BinaryColumnConverter; + +class VariantConverter + extends GroupConverter + implements ParquetConverters.BaseConverter { + private final BinaryColumnConverter valueConverter; + private final BinaryColumnConverter metadataConverter; + + // Working state + private boolean isCurrentValueNull = true; + private int currentRowIndex; + private boolean[] nullability; + + /** + * Create converter for {@link VariantType} column. + * + * @param initialBatchSize Estimate of initial row batch size. Used in memory allocations. 
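+     *
+     * <p>The two child converters collect the {@code value} and {@code metadata}
+     * binaries of the Parquet group; {@code getDataColumnVector} then assembles them
+     * into a {@link DefaultVariantVector}.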
+ */ + VariantConverter(int initialBatchSize) { + checkArgument(initialBatchSize > 0, "invalid initialBatchSize: %s", initialBatchSize); + // Initialize the working state + this.nullability = ParquetConverters.initNullabilityVector(initialBatchSize); + + int parquetOrdinal = 0; + this.valueConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); + this.metadataConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); + } + + @Override + public Converter getConverter(int fieldIndex) { + checkArgument( + fieldIndex >= 0 && fieldIndex < 2, + "variant type is represented by a struct with 2 fields"); + if (fieldIndex == 0) { + return valueConverter; + } else { + return metadataConverter; + } + } + + @Override + public void start() { + isCurrentValueNull = false; + } + + @Override + public void end() { + } + + @Override + public void finalizeCurrentRow(long currentRowIndex) { + resizeIfNeeded(); + finalizeLastRowInConverters(currentRowIndex); + nullability[this.currentRowIndex] = isCurrentValueNull; + isCurrentValueNull = true; + + this.currentRowIndex++; + } + + public ColumnVector getDataColumnVector(int batchSize) { + ColumnVector vector = new DefaultVariantVector( + batchSize, + VariantType.VARIANT, + Optional.of(nullability), + valueConverter.getDataColumnVector(batchSize), + metadataConverter.getDataColumnVector(batchSize) + ); + resetWorkingState(); + return vector; + } + + @Override + public void resizeIfNeeded() { + if (nullability.length == currentRowIndex) { + int newSize = nullability.length * 2; + this.nullability = Arrays.copyOf(this.nullability, newSize); + ParquetConverters.setNullabilityToTrue(this.nullability, newSize / 2, newSize); + } + } + + @Override + public void resetWorkingState() { + this.currentRowIndex = 0; + this.isCurrentValueNull = true; + this.nullability = ParquetConverters.initNullabilityVector(this.nullability.length); + } + + private void finalizeLastRowInConverters(long prevRowIndex) { + valueConverter.finalizeCurrentRow(prevRowIndex); + metadataConverter.finalizeCurrentRow(prevRowIndex); + } +} diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java index 1474bd2db67..5b76dc88687 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java @@ -26,6 +26,7 @@ import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.StructType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -165,5 +166,10 @@ public MapValue getMap(int ordinal) { throw new UnsupportedOperationException( "map type unsupported for TestColumnBatchBuilder; use scala test utilities"); } + + @Override + public VariantValue getVariant(int ordinal) { + return (VariantValue) values.get(ordinal); + } } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index 4a7da88d842..3c6b55e9912 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -109,6 +109,7 @@ 
object TestRow { case _: ArrayType => arrayValueToScalaSeq(row.getArray(i)) case _: MapType => mapValueToScalaMap(row.getMap(i)) case _: StructType => TestRow(row.getStruct(i)) + case _: VariantType => row.getVariant(i) case _ => throw new UnsupportedOperationException("unrecognized data type") } }.toSeq) @@ -174,6 +175,7 @@ object TestRow { decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v) } case _: sparktypes.StructType => TestRow(row.getStruct(i)) + case _: sparktypes.VariantType => row.getAs[Row](i) case _ => throw new UnsupportedOperationException("unrecognized data type") } }) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index 92fdd02a2e1..491f40fc93a 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -627,6 +627,7 @@ trait TestUtils extends Assertions with SQLHelper { field.isNullable ) }.toSeq) + case VariantType.VARIANT => sparktypes.DataTypes.VariantType } } From c1bf2b1b174e7624a92f0801f9963779bd66d591 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Thu, 14 Mar 2024 17:50:09 -0600 Subject: [PATCH 08/16] fix reads --- .../data/vector/DefaultVariantVector.java | 1 - .../parquet/ParquetFileReaderSuite.scala | 54 +++++++++++++++++++ .../delta/kernel/defaults/utils/TestRow.scala | 13 +++-- .../kernel/defaults/utils/TestUtils.scala | 13 +++++ 4 files changed, 77 insertions(+), 4 deletions(-) diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java index 4f352c158d3..f2ceb1a592c 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java @@ -49,7 +49,6 @@ public DefaultVariantVector( ColumnVector value, ColumnVector metadata) { super(size, type, nullability); - // checkArgument(offsets.length >= size + 1, "invalid offset array size"); this.valueVector = requireNonNull(value, "value is null"); this.metadataVector = requireNonNull(metadata, "metadata is null"); } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala index 941e2672cf0..a0cc9063d98 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala @@ -16,6 +16,9 @@ package io.delta.kernel.defaults.internal.parquet import java.math.BigDecimal + +import org.apache.spark.sql.DataFrame + import io.delta.golden.GoldenTableUtils.goldenTableFile import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestRow, VectorTestUtils} import io.delta.kernel.types._ @@ -139,4 +142,55 @@ class ParquetFileReaderSuite extends AnyFunSuite checkAnswer(actResult2, expResult2) } + + private def testReadVariant(testName: String)(df: => DataFrame): Unit = { + test(testName) { + withTable("test_variant_table") { + df.write + .format("delta") + 
.mode("overwrite") + .saveAsTable("test_variant_table") + val path = spark.sql("describe table extended `test_variant_table`") + .where("col_name = 'Location'") + .collect()(0) + .getString(1) + .replace("file:", "") + + val kernelSchema = tableSchema(path) + val actResult = readParquetFilesUsingKernel(path, kernelSchema) + val expResult = readParquetFilesUsingSpark(path, kernelSchema) + checkAnswer(actResult, expResult) + } + } + } + + testReadVariant("basic read variant") { + spark.range(0, 10, 1, 1).selectExpr( + "parse_json(cast(id as string)) as basic_v", + "named_struct('v', parse_json(cast(id as string))) as struct_v", + """array( + parse_json(cast(id as string)), + parse_json(cast(id as string)), + parse_json(cast(id as string)) + ) as array_v""", + "map('test', parse_json(cast(id as string))) as map_value_v", + "map(parse_json(cast(id as string)), parse_json(cast(id as string))) as map_key_v" + ) + } + + testReadVariant("basic null variant") { + spark.range(0, 10, 1, 1).selectExpr( + "cast(null as variant) basic_v", + "named_struct('v', cast(null as variant)) as struct_v", + """array( + parse_json(cast(id as string)), + parse_json(cast(id as string)), + null + ) as array_v""", + "map('test', cast(null as variant)) as map_value_v", + "map(cast(null as variant), parse_json(cast(id as string))) as map_key_v", + ) + } + + // TODO(richardc-db): Add nested variant tests once `parse_json` expression is implemented. } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index 3c6b55e9912..af6cdb2734a 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -19,6 +19,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.{Seq => MutableSeq} import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.{Row => SparkRow} +import org.apache.spark.unsafe.types.VariantVal import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} import io.delta.kernel.types._ @@ -45,7 +46,7 @@ import java.time.{Instant, LocalDate, LocalDateTime, ZoneOffset} * - ArrayType --> Seq[Any] * - MapType --> Map[Any, Any] * - StructType --> TestRow - * + * - VariantType --> VariantVal * For complex types array and map, the inner elements types should align with this mapping. 
*/ class TestRow(val values: Array[Any]) { @@ -109,7 +110,9 @@ object TestRow { case _: ArrayType => arrayValueToScalaSeq(row.getArray(i)) case _: MapType => mapValueToScalaMap(row.getMap(i)) case _: StructType => TestRow(row.getStruct(i)) - case _: VariantType => row.getVariant(i) + case _: VariantType => + val kernelVariant = row.getVariant(i) + new VariantVal(kernelVariant.getValue(), kernelVariant.getMetadata()) case _ => throw new UnsupportedOperationException("unrecognized data type") } }.toSeq) @@ -142,6 +145,7 @@ object TestRow { decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v) } case _: sparktypes.StructType => TestRow(obj.asInstanceOf[SparkRow]) + case _: sparktypes.VariantType => obj.asInstanceOf[VariantVal] case _ => throw new UnsupportedOperationException("unrecognized data type") } } @@ -175,7 +179,7 @@ object TestRow { decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v) } case _: sparktypes.StructType => TestRow(row.getStruct(i)) - case _: sparktypes.VariantType => row.getAs[Row](i) + case _: sparktypes.VariantType => row.getAs[VariantVal](i) case _ => throw new UnsupportedOperationException("unrecognized data type") } }) @@ -207,6 +211,9 @@ object TestRow { TestRow.fromSeq(Seq.range(0, dataType.length()).map { ordinal => getAsTestObject(vector.getChild(ordinal), rowId) }) + case _: VariantType => + val kernelVariant = vector.getVariant(rowId) + new VariantVal(kernelVariant.getValue(), kernelVariant.getMetadata()) case _ => throw new UnsupportedOperationException("unrecognized data type") } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index 491f40fc93a..c76eb34f36e 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -41,6 +41,7 @@ import org.apache.hadoop.shaded.org.apache.commons.io.FileUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.unsafe.types.VariantVal import org.scalatest.Assertions trait TestUtils extends Assertions with SQLHelper { @@ -117,6 +118,17 @@ trait TestUtils extends Assertions with SQLHelper { lazy val classLoader: ClassLoader = ResourceLoader.getClass.getClassLoader } + /** + * Drops table `tableName` after calling `f`. 
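+ * Accepts multiple table names; the drops run in a `finally` block, so cleanup
+ * happens even when `f` throws. Example: `withTable("t1", "t2") { ... }`.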
+ */ + def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + def withGoldenTable(tableName: String)(testFunc: String => Unit): Unit = { val tablePath = GoldenTableUtils.goldenTablePath(tableName) testFunc(tablePath) @@ -404,6 +416,7 @@ trait TestUtils extends Assertions with SQLHelper { java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) case (a: Float, b: Float) => java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) + case (a: VariantVal, b: VariantVal) => a.debugString() == b.debugString() case (a, b) => if (!a.equals(b)) { val sds = 200; From 809e11be60e6029c1c2521bdcaaf1066858f6092 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Thu, 14 Mar 2024 21:00:35 -0600 Subject: [PATCH 09/16] changename --- .../parquet/ParquetFileReaderSuite.scala | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala index a0cc9063d98..1c0bb2d8d83 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala @@ -143,14 +143,18 @@ class ParquetFileReaderSuite extends AnyFunSuite checkAnswer(actResult2, expResult2) } - private def testReadVariant(testName: String)(df: => DataFrame): Unit = { + /** + * Writes a table using Spark, reads it back using the Delta Kernel implementation, and asserts + * that the results are the same. 
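+ * The data location is parsed out of the `Location` row of
+ * `DESCRIBE TABLE EXTENDED`, so the Kernel side reads the table's Parquet files
+ * directly instead of going through Spark.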
+ */ + private def testRead(testName: String)(df: => DataFrame): Unit = { test(testName) { - withTable("test_variant_table") { + withTable("test_table") { df.write .format("delta") .mode("overwrite") - .saveAsTable("test_variant_table") - val path = spark.sql("describe table extended `test_variant_table`") + .saveAsTable("test_table") + val path = spark.sql("describe table extended `test_table`") .where("col_name = 'Location'") .collect()(0) .getString(1) @@ -164,7 +168,7 @@ class ParquetFileReaderSuite extends AnyFunSuite } } - testReadVariant("basic read variant") { + testRead("basic read variant") { spark.range(0, 10, 1, 1).selectExpr( "parse_json(cast(id as string)) as basic_v", "named_struct('v', parse_json(cast(id as string))) as struct_v", @@ -178,7 +182,7 @@ class ParquetFileReaderSuite extends AnyFunSuite ) } - testReadVariant("basic null variant") { + testRead("basic null variant") { spark.range(0, 10, 1, 1).selectExpr( "cast(null as variant) basic_v", "named_struct('v', cast(null as variant)) as struct_v", From 627ed6f36484f64cb94101fe1c70773d0492f2bb Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Sun, 17 Mar 2024 15:24:31 -0600 Subject: [PATCH 10/16] use get child instead --- .../internal/data/vector/DefaultVariantVector.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java index f2ceb1a592c..68e5e188a40 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java @@ -82,11 +82,13 @@ public byte[] getMetadata() { }; } - public ColumnVector getValueVector() { - return valueVector; - } - - public ColumnVector getMetadataVector() { - return metadataVector; + @Override + public ColumnVector getChild(int ordinal) { + checkArgument(ordinal >= 0 && ordinal < 2, "Invalid ordinal " + ordinal); + if (ordinal == 0) { + return valueVector; + } else { + return metadataVector; + } } } From d8d355b5160a7e48605713c56b56aa86f95b197f Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 18 Mar 2024 12:15:53 -0600 Subject: [PATCH 11/16] fix compile --- .../defaults/internal/data/vector/DefaultVariantVector.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java index 68e5e188a40..10a95fc5677 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java @@ -22,6 +22,8 @@ import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; +import static io.delta.kernel.internal.util.Preconditions.checkArgument; + /** * {@link io.delta.kernel.data.ColumnVector} implementation for variant type data. 
*/ From cf85e0870f02e652724a750818f87411dbfc28e0 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 18 Mar 2024 12:36:07 -0600 Subject: [PATCH 12/16] fix --- .../internal/parquet/ParquetSchemaUtils.java | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java index a9f22e76967..149e1892479 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java @@ -236,7 +236,7 @@ private static Type toParquetType( } else if (dataType instanceof StructType) { type = toParquetStructType((StructType) dataType, name, repetition); } else if (dataType instanceof VariantType) { - type = toParquetVariantType((VariantType) dataType, name, repetition); + type = toParquetVariantType(name, repetition); } else { throw new UnsupportedOperationException( "Writing given type data to Parquet is not supported: " + dataType); @@ -298,13 +298,11 @@ private static Type toParquetStructType(StructType structType, String name, return new GroupType(repetition, name, fields); } - private static Type toParquetVariantType(VariantType structType, String name, - Repetition repetition) { - List fields = Arrays.asList( - toParquetType(BinaryType.BINARY, "value", REQUIRED, Optional.empty()), - toParquetType(BinaryType.BINARY, "metadata", REQUIRED, Optional.empty()) - ); - return new GroupType(repetition, name, fields); + private static Type toParquetVariantType(String name, Repetition repetition) { + return Types.buildGroup(repetition) + .addField(toParquetType(BinaryType.BINARY, "value", REQUIRED, Optional.empty())) + .addField(toParquetType(BinaryType.BINARY, "metadata", REQUIRED, Optional.empty())) + .named(name); } /** From e01a0538bc783ee81a34c2608b1978b7fdc23573 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Tue, 19 Mar 2024 22:44:35 -0700 Subject: [PATCH 13/16] changes --- .../io/delta/kernel/internal/data/GenericRow.java | 1 - .../java/io/delta/kernel/types/VariantType.java | 4 +--- .../internal/data/vector/DefaultVariantVector.java | 13 ++++++++++--- .../defaults/internal/parquet/VariantConverter.java | 8 ++++---- .../internal/parquet/ParquetFileReaderSuite.scala | 5 +---- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java index aabf6a56b08..c4d6aeaf8ca 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java @@ -137,7 +137,6 @@ public MapValue getMap(int ordinal) { @Override public VariantValue getVariant(int ordinal) { - // TODO(r.chen): test this path somehow? 
        throwIfUnsafeAccess(ordinal, VariantType.class, "variant");
         return (VariantValue) getValue(ordinal);
     }
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java
index 81dbb072c97..71a84cdb718 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java
@@ -18,9 +18,7 @@ import io.delta.kernel.annotation.Evolving;
 /**
- * A variant type.
- * <p>
- * todo: more comments + * A logical variant type. * @since 4.0.0 */ @Evolving diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java index 10a95fc5677..3a21e8481a0 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java @@ -37,12 +37,12 @@ public class DefaultVariantVector * Create an instance of {@link io.delta.kernel.data.ColumnVector} for array type. * * @param size number of elements in the vector. + * @param type {@code variant} datatype definition. * @param nullability Optional array of nullability value for each element in the vector. * All values in the vector are considered non-null when parameter is * empty. - * @param offsets Offsets into element vector on where the index of particular row - * values start and end. - * @param elementVector Vector containing the array elements. + * @param value The child binary column vector representing each variant's values. + * @param metadata The child binary column vector representing each variant's metadata. */ public DefaultVariantVector( int size, @@ -84,6 +84,13 @@ public byte[] getMetadata() { }; } + /** + * Get the child column vector at the given {@code ordinal}. Variants should only have two + * child vectors, one for value and one for metadata. + * + * @param ordinal + * @return + */ @Override public ColumnVector getChild(int ordinal) { checkArgument(ordinal >= 0 && ordinal < 2, "Invalid ordinal " + ordinal); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java index dc4b61d28be..b099953ec2f 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java @@ -34,10 +34,12 @@ class VariantConverter private final BinaryColumnConverter valueConverter; private final BinaryColumnConverter metadataConverter; - // Working state - private boolean isCurrentValueNull = true; + // working state private int currentRowIndex; private boolean[] nullability; + // If the value is null, start/end never get called which is a signal for null + // Set the initial state to true and when start() is called set it to false. + private boolean isCurrentValueNull = true; /** * Create converter for {@link VariantType} column. 
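For orientation, parquet-mr drives this class through the GroupConverter contract. Below is a rough sketch (not part of the patch) of the call sequence for one non-null row; valueBytes/metadataBytes are illustrative placeholders, and asPrimitiveConverter/addBinary come from org.apache.parquet.io.api:

    import org.apache.parquet.io.api.Binary;

    VariantConverter conv = new VariantConverter(1024);
    conv.start();                                  // the row has a variant value
    conv.getConverter(0).asPrimitiveConverter()
        .addBinary(Binary.fromConstantByteArray(valueBytes));     // "value" field
    conv.getConverter(1).asPrimitiveConverter()
        .addBinary(Binary.fromConstantByteArray(metadataBytes));  // "metadata" field
    conv.end();
    conv.finalizeCurrentRow(0);                    // latches row 0 as non-null

For a null row, start() and end() are never invoked, which is why isCurrentValueNull begins as true and is reset to true after every finalizeCurrentRow call.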
@@ -46,10 +48,8 @@ class VariantConverter */ VariantConverter(int initialBatchSize) { checkArgument(initialBatchSize > 0, "invalid initialBatchSize: %s", initialBatchSize); - // Initialize the working state this.nullability = ParquetConverters.initNullabilityVector(initialBatchSize); - int parquetOrdinal = 0; this.valueConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); this.metadataConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala index 1c0bb2d8d83..b76b8d0ec01 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala @@ -191,10 +191,7 @@ class ParquetFileReaderSuite extends AnyFunSuite parse_json(cast(id as string)), null ) as array_v""", - "map('test', cast(null as variant)) as map_value_v", - "map(cast(null as variant), parse_json(cast(id as string))) as map_key_v", + "map('test', cast(null as variant)) as map_value_v" ) } - - // TODO(richardc-db): Add nested variant tests once `parse_json` expression is implemented. } From 5f5519fb7326358c5e397abcdac8a06e9987b0f4 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Wed, 20 Mar 2024 20:46:46 -0700 Subject: [PATCH 14/16] use defaultvariantvalue --- .../data/value/DefaultVariantValue.java | 61 +++++++++++++++++++ .../data/vector/DefaultVariantVector.java | 27 ++++---- .../delta/kernel/defaults/utils/TestRow.scala | 9 +-- .../kernel/defaults/utils/TestUtils.scala | 2 - 4 files changed, 81 insertions(+), 18 deletions(-) create mode 100644 kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java new file mode 100644 index 00000000000..5b2e7329b73 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java @@ -0,0 +1,61 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.data.value; + +import java.util.Arrays; + +import io.delta.kernel.data.VariantValue; + +/** + * Default implementation of a Delta kernel VariantValue. 
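+ * Holds the raw variant encoding as two byte arrays and hands them back as-is;
+ * decoding into typed values is left entirely to the caller.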
+ */ +public class DefaultVariantValue implements VariantValue { + private final byte[] value; + private final byte[] metadata; + + public DefaultVariantValue(byte[] value, byte[] metadata) { + this.value = value; + this.metadata = metadata; + } + + @Override + public byte[] getValue() { + return value; + } + + @Override + public byte[] getMetadata() { + return metadata; + } + + public String debugString() { + return "VariantValue{value=" + Arrays.toString(value) + + ", metadata=" + Arrays.toString(metadata) + '}'; + } + /** + * Compare two variants in bytes. The variant equality is more complex than it, and we haven't + * supported it in the user surface yet. This method is only intended for tests. + */ + @Override + public boolean equals(Object other) { + if (other instanceof DefaultVariantValue) { + return Arrays.equals(value, ((DefaultVariantValue) other).getValue()) && + Arrays.equals(metadata, ((DefaultVariantValue) other).getMetadata()); + } else { + return false; + } + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java index 3a21e8481a0..f4f35edd991 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java @@ -23,6 +23,7 @@ import io.delta.kernel.types.DataType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; +import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue; /** * {@link io.delta.kernel.data.ColumnVector} implementation for variant type data. 
@@ -68,20 +69,22 @@ public VariantValue getVariant(int rowId) { if (isNullAt(rowId)) { return null; } - return new VariantValue() { - private final byte[] value = valueVector.getBinary(rowId); - private final byte[] metadata = metadataVector.getBinary(rowId); + // return new VariantValue() { + // private final byte[] value = valueVector.getBinary(rowId); + // private final byte[] metadata = metadataVector.getBinary(rowId); - @Override - public byte[] getValue() { - return value; - } + // @Override + // public byte[] getValue() { + // return value; + // } - @Override - public byte[] getMetadata() { - return metadata; - } - }; + // @Override + // public byte[] getMetadata() { + // return metadata; + // } + // }; + return new DefaultVariantValue( + valueVector.getBinary(rowId), metadataVector.getBinary(rowId)); } /** diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index af6cdb2734a..c2f6cf0398e 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.{Row => SparkRow} import org.apache.spark.unsafe.types.VariantVal import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} +import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue import io.delta.kernel.types._ import java.sql.Timestamp @@ -110,9 +111,7 @@ object TestRow { case _: ArrayType => arrayValueToScalaSeq(row.getArray(i)) case _: MapType => mapValueToScalaMap(row.getMap(i)) case _: StructType => TestRow(row.getStruct(i)) - case _: VariantType => - val kernelVariant = row.getVariant(i) - new VariantVal(kernelVariant.getValue(), kernelVariant.getMetadata()) + case _: VariantType => row.getVariant(i) case _ => throw new UnsupportedOperationException("unrecognized data type") } }.toSeq) @@ -179,7 +178,9 @@ object TestRow { decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v) } case _: sparktypes.StructType => TestRow(row.getStruct(i)) - case _: sparktypes.VariantType => row.getAs[VariantVal](i) + case _: sparktypes.VariantType => + val sparkVariant = row.getAs[VariantVal](i) + new DefaultVariantValue(sparkVariant.getValue(), sparkVariant.getMetadata()) case _ => throw new UnsupportedOperationException("unrecognized data type") } }) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index c76eb34f36e..7a5720ec778 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -41,7 +41,6 @@ import org.apache.hadoop.shaded.org.apache.commons.io.FileUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.unsafe.types.VariantVal import org.scalatest.Assertions trait TestUtils extends Assertions with SQLHelper { @@ -416,7 +415,6 @@ trait TestUtils extends Assertions with SQLHelper { java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) case (a: Float, b: Float) => java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) - 
case (a: VariantVal, b: VariantVal) => a.debugString() == b.debugString() case (a, b) => if (!a.equals(b)) { val sds = 200; From 5722e8da238891dfa633ca70e480247e55be48b4 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 25 Mar 2024 14:45:18 -0700 Subject: [PATCH 15/16] rebase --- .../data/value/DefaultVariantValue.java | 4 +++- .../data/vector/DefaultVariantVector.java | 13 ----------- .../parquet/ParquetColumnReaders.java | 2 +- ...onverter.java => VariantColumnReader.java} | 23 ++++++++++--------- 4 files changed, 16 insertions(+), 26 deletions(-) rename kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/{VariantConverter.java => VariantColumnReader.java} (80%) diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java index 5b2e7329b73..c4c283e9609 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java @@ -41,10 +41,12 @@ public byte[] getMetadata() { return metadata; } - public String debugString() { + @Override + public String toString() { return "VariantValue{value=" + Arrays.toString(value) + ", metadata=" + Arrays.toString(metadata) + '}'; } + /** * Compare two variants in bytes. The variant equality is more complex than it, and we haven't * supported it in the user surface yet. This method is only intended for tests. diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java index f4f35edd991..a12c426fda3 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java @@ -69,20 +69,7 @@ public VariantValue getVariant(int rowId) { if (isNullAt(rowId)) { return null; } - // return new VariantValue() { - // private final byte[] value = valueVector.getBinary(rowId); - // private final byte[] metadata = metadataVector.getBinary(rowId); - // @Override - // public byte[] getValue() { - // return value; - // } - - // @Override - // public byte[] getMetadata() { - // return metadata; - // } - // }; return new DefaultVariantValue( valueVector.getBinary(rowId), metadataVector.getBinary(rowId)); } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java index 516cbef0022..05308eebc1e 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java @@ -83,7 +83,7 @@ public static Converter createConverter( } else if (typeFromClient instanceof TimestampNTZType) { return createTimestampNtzConverter(initialBatchSize, typeFromFile); } else if (typeFromClient instanceof VariantType) { - return new VariantConverter(initialBatchSize); + return new VariantColumnReader(initialBatchSize); } throw new UnsupportedOperationException(typeFromClient + " is not 
supported"); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantColumnReader.java similarity index 80% rename from kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java rename to kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantColumnReader.java index b099953ec2f..24132058337 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantConverter.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/VariantColumnReader.java @@ -26,13 +26,14 @@ import static io.delta.kernel.internal.util.Preconditions.checkArgument; import io.delta.kernel.defaults.internal.data.vector.DefaultVariantVector; -import io.delta.kernel.defaults.internal.parquet.ParquetConverters.BinaryColumnConverter; +import io.delta.kernel.defaults.internal.parquet.ParquetColumnReaders.BaseColumnReader; +import io.delta.kernel.defaults.internal.parquet.ParquetColumnReaders.BinaryColumnReader; -class VariantConverter +class VariantColumnReader extends GroupConverter - implements ParquetConverters.BaseConverter { - private final BinaryColumnConverter valueConverter; - private final BinaryColumnConverter metadataConverter; + implements BaseColumnReader { + private final BinaryColumnReader valueConverter; + private final BinaryColumnReader metadataConverter; // working state private int currentRowIndex; @@ -46,12 +47,12 @@ class VariantConverter * * @param initialBatchSize Estimate of initial row batch size. Used in memory allocations. */ - VariantConverter(int initialBatchSize) { + VariantColumnReader(int initialBatchSize) { checkArgument(initialBatchSize > 0, "invalid initialBatchSize: %s", initialBatchSize); - this.nullability = ParquetConverters.initNullabilityVector(initialBatchSize); + this.nullability = ParquetColumnReaders.initNullabilityVector(initialBatchSize); - this.valueConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); - this.metadataConverter = new BinaryColumnConverter(BinaryType.BINARY, initialBatchSize); + this.valueConverter = new BinaryColumnReader(BinaryType.BINARY, initialBatchSize); + this.metadataConverter = new BinaryColumnReader(BinaryType.BINARY, initialBatchSize); } @Override @@ -102,7 +103,7 @@ public void resizeIfNeeded() { if (nullability.length == currentRowIndex) { int newSize = nullability.length * 2; this.nullability = Arrays.copyOf(this.nullability, newSize); - ParquetConverters.setNullabilityToTrue(this.nullability, newSize / 2, newSize); + ParquetColumnReaders.setNullabilityToTrue(this.nullability, newSize / 2, newSize); } } @@ -110,7 +111,7 @@ public void resizeIfNeeded() { public void resetWorkingState() { this.currentRowIndex = 0; this.isCurrentValueNull = true; - this.nullability = ParquetConverters.initNullabilityVector(this.nullability.length); + this.nullability = ParquetColumnReaders.initNullabilityVector(this.nullability.length); } private void finalizeLastRowInConverters(long prevRowIndex) { From cebc36f39fe13ba75188b664cb774107a9751dda Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Mon, 15 Apr 2024 18:12:32 -0700 Subject: [PATCH 16/16] poc --- build.sbt | 2 + .../src/main/java/io/delta/kernel/Scan.java | 39 ++++++++ .../java/io/delta/kernel/ScanBuilder.java | 7 ++ .../internal/ExtractedVariantOptions.java | 32 +++++++ 
 .../kernel/internal/ScanBuilderImpl.java      | 32 ++++++-
 .../io/delta/kernel/internal/ScanImpl.java    |  6 +-
 .../kernel/internal/data/GenericRow.java      |  3 +-
 .../kernel/internal/data/ScanStateRow.java    | 19 +++-
 .../kernel/internal/util/VariantUtils.java    | 91 +++++++++++++++++++
 .../io/delta/kernel/types/StructField.java    | 15 +++
 .../DefaultExpressionEvaluator.java           | 26 ++++++
 .../expressions/ExpressionVisitor.java        |  4 +
 .../parquet/ParquetColumnReaders.java         |  2 +
 .../io/delta/kernel/defaults/ScanSuite.scala  | 58 +++++++++++-
 14 files changed, 330 insertions(+), 6 deletions(-)
 create mode 100644 kernel/kernel-api/src/main/java/io/delta/kernel/internal/ExtractedVariantOptions.java
 create mode 100644 kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java

diff --git a/build.sbt b/build.sbt
index a2f5b91eaa5..b12ba34d256 100644
--- a/build.sbt
+++ b/build.sbt
@@ -369,6 +369,8 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults"))
     crossSparkSettings(),
     libraryDependencies ++= Seq(
       "org.apache.hadoop" % "hadoop-client-runtime" % hadoopVersion,
+      // can we cross compile spark-variant?
+      // "org.apache.spark" %% "spark-variant" % SPARK_MASTER_VERSION % "provided",
       "com.fasterxml.jackson.core" % "jackson-databind" % "2.13.5",
       "org.apache.parquet" % "parquet-hadoop" % "1.12.3",
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java
index 4b776254c81..2189a1dacd9 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java
@@ -17,7 +17,9 @@ package io.delta.kernel;
 import java.io.IOException;
+import java.util.List;
 import java.util.Optional;
+import java.util.stream.IntStream;
 import io.delta.kernel.annotation.Evolving;
 import io.delta.kernel.client.TableClient;
@@ -27,6 +29,7 @@ import io.delta.kernel.types.StructType;
 import io.delta.kernel.utils.CloseableIterator;
+import io.delta.kernel.internal.ExtractedVariantOptions;
 import io.delta.kernel.internal.InternalScanFileUtils;
 import io.delta.kernel.internal.actions.DeletionVectorDescriptor;
 import io.delta.kernel.internal.data.ScanStateRow;
@@ -36,6 +39,7 @@ import io.delta.kernel.internal.util.ColumnMapping;
 import io.delta.kernel.internal.util.PartitionUtils;
 import io.delta.kernel.internal.util.Tuple2;
+import io.delta.kernel.internal.util.VariantUtils;
 /**
  * Represents a scan of a Delta table.
@@ -134,6 +138,7 @@ static CloseableIterator<FilteredColumnarBatch> transformPhysicalData(
         StructType physicalReadSchema = null;
         StructType logicalReadSchema = null;
         String tablePath = null;
+        List<ExtractedVariantOptions> extractedVariantOptions = null;
         RoaringBitmapArray currBitmap = null;
         DeletionVectorDescriptor currDV = null;
@@ -144,6 +149,7 @@ private void initIfRequired() {
             }
             physicalReadSchema = ScanStateRow.getPhysicalSchema(tableClient, scanState);
             logicalReadSchema = ScanStateRow.getLogicalSchema(tableClient, scanState);
+            extractedVariantOptions = ScanStateRow.getExtractedVariantFields(scanState);
             tablePath = ScanStateRow.getTableRoot(scanState);
             inited = true;
@@ -203,6 +209,19 @@ public FilteredColumnarBatch next() {
                 physicalReadSchema
             );
+            // Add extracted variant columns.
+            if (extractedVariantOptions.size() > 0) {
+                nextDataBatch = VariantUtils.withExtractedVariantFields(
+                    tableClient.getExpressionHandler(),
+                    nextDataBatch,
+                    extractedVariantOptions
+                );
+            }
+
+            // TODO: Implement default ColumnarBatch slice methods to make this more efficient.
+            // Remove added variant columns required for scan.
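+            // Columns tagged with the isInternallyAddedVariant metadata key were injected
+            // into the physical read schema only so variant_get could be evaluated; strip
+            // them so the returned batch matches the requested logical schema.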
+            nextDataBatch = removeInternallyAddedVariantCols(nextDataBatch, physicalReadSchema);
+
             // Change back to logical schema
             String columnMappingMode = ScanStateRow.getColumnMappingMode(scanState);
             switch (columnMappingMode) {
@@ -219,6 +238,26 @@
                 return new FilteredColumnarBatch(nextDataBatch, selectionVector);
             }
+
+            private ColumnarBatch removeInternallyAddedVariantCols(
+                    ColumnarBatch batch,
+                    StructType schema) {
+                int numToRemove = (int) schema.fields().stream()
+                    .filter(field -> field.isInternallyAddedVariant())
+                    .count();
+
+                // There is no guarantee that `ColumnarBatch.withDeletedColumnAt` doesn't reorder
+                // the schema so the added variant columns must be removed one by one.
+                for (int i = 0; i < numToRemove; i++) {
+                    Optional<Integer> idxToRemove = IntStream.range(0, schema.length())
+                        .filter(idx -> schema.at(idx).isInternallyAddedVariant())
+                        .boxed()
+                        .findFirst();
+
+                    batch = batch.withDeletedColumnAt(idxToRemove.get().intValue());
+                }
+                return batch;
+            }
         };
     }
 }
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/ScanBuilder.java b/kernel/kernel-api/src/main/java/io/delta/kernel/ScanBuilder.java
index deff60fdd16..8eea1f7c72c 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/ScanBuilder.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/ScanBuilder.java
@@ -19,6 +19,7 @@
 import io.delta.kernel.annotation.Evolving;
 import io.delta.kernel.client.TableClient;
 import io.delta.kernel.expressions.Predicate;
+import io.delta.kernel.types.DataType;
 import io.delta.kernel.types.StructType;
 /**
@@ -49,6 +50,12 @@ public interface ScanBuilder {
      */
     ScanBuilder withReadSchema(TableClient tableClient, StructType readSchema);
+    ScanBuilder withExtractedVariantField(
+        TableClient tableClient,
+        String path,
+        DataType type,
+        String extractedFieldName);
+
     /**
      * @return Build the {@link Scan instance}
      */
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ExtractedVariantOptions.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ExtractedVariantOptions.java
new file mode 100644
index 00000000000..62018ed5471
--- /dev/null
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ExtractedVariantOptions.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright (2023) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.delta.kernel.internal;
+
+import io.delta.kernel.expressions.Column;
+import io.delta.kernel.types.DataType;
+
+public class ExtractedVariantOptions {
+    public Column path;
+    public String fieldName;
+    public DataType type;
+
+    public ExtractedVariantOptions(Column path, DataType type, String fieldName) {
+        this.path = path;
+        this.fieldName = fieldName;
+        this.type = type;
+    }
+}
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java
index ffb88b629ea..9777539a9b1 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java
@@ -16,13 +16,16 @@
 package io.delta.kernel.internal;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Optional;
 import io.delta.kernel.Scan;
 import io.delta.kernel.ScanBuilder;
 import io.delta.kernel.client.TableClient;
+import io.delta.kernel.expressions.Column;
 import io.delta.kernel.expressions.Predicate;
-import io.delta.kernel.types.StructType;
+import io.delta.kernel.types.*;
 import io.delta.kernel.internal.actions.Metadata;
 import io.delta.kernel.internal.actions.Protocol;
@@ -44,6 +47,7 @@ public class ScanBuilderImpl
     private StructType readSchema;
     private Optional<Predicate> predicate;
+    private List<ExtractedVariantOptions> extractedVariantFields;
     public ScanBuilderImpl(
         Path dataPath,
@@ -60,6 +64,7 @@ public ScanBuilderImpl(
         this.tableClient = tableClient;
         this.readSchema = snapshotSchema;
         this.predicate = Optional.empty();
+        this.extractedVariantFields = new ArrayList<>();
     }
     @Override
@@ -78,6 +83,25 @@ public ScanBuilder withReadSchema(TableClient tableClient, StructType readSchema) {
         return this;
     }
+    @Override
+    public ScanBuilder withExtractedVariantField(
+            TableClient tableClient,
+            String path,
+            DataType type,
+            String extractedFieldName) {
+        String[] splitPath = splitVariantPath(path);
+        extractedVariantFields.add(new ExtractedVariantOptions(
+            new Column(splitPath), type, extractedFieldName));
+
+        // TODO: we're attaching the actual variant column name right now.
+        // Will this work with column mapping/is there a more robust way?
+        if (readSchema.indexOf(splitPath[0]) == -1) {
+            readSchema = readSchema.add(StructField.internallyAddedVariantSchema(splitPath[0]));
+        }
+
+        return this;
+    }
+
     @Override
     public Scan build() {
         return new ScanImpl(
@@ -87,6 +111,12 @@ public Scan build() {
             metadata,
             logReplay,
             predicate,
+            extractedVariantFields,
             dataPath);
     }
+
+    private String[] splitVariantPath(String path) {
+        // TODO: account for square brackets and array indices later.
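+        // A naive dot-split: "v.user.id" becomes ["v", "user", "id"]; field names
+        // that themselves contain dots would be split incorrectly here.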
+        return path.split("\\.");
+    }
 }
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java
index 229b7cb8aeb..ef3938b499e 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java
@@ -58,6 +58,7 @@ public class ScanImpl implements Scan {
     private final LogReplay logReplay;
     private final Path dataPath;
     private final Optional<Tuple2<Predicate, Predicate>> partitionAndDataFilters;
+    private final List<ExtractedVariantOptions> extractedVariantFields;
     private final Supplier<Map<String, StructField>> partitionColToStructFieldMap;
     private boolean accessedScanFiles;
@@ -68,6 +69,7 @@ public ScanImpl(
         Metadata metadata,
         LogReplay logReplay,
         Optional<Predicate> filter,
+        List<ExtractedVariantOptions> extractedVariantFields,
         Path dataPath) {
         this.snapshotSchema = snapshotSchema;
         this.readSchema = readSchema;
@@ -85,6 +87,7 @@ public ScanImpl(
                 field -> field.getName().toLowerCase(Locale.ROOT), identity()));
         };
+        this.extractedVariantFields = extractedVariantFields;
     }
     /**
@@ -157,7 +160,8 @@ public Row getScanState(TableClient tableClient) {
             readSchema.toJson(),
             physicalReadSchema.toJson(),
             physicalDataReadSchema.toJson(),
-            dataPath.toUri().toString());
+            dataPath.toUri().toString(),
+            extractedVariantFields);
     }
     @Override
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java
index c4d6aeaf8ca..e052c09ea7b 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java
@@ -141,7 +141,8 @@ public VariantValue getVariant(int ordinal) {
         return (VariantValue) getValue(ordinal);
     }
-    private Object getValue(int ordinal) {
+    // TODO: HACK to not have to serialize and deserialize the ExtractedVariantOptions list.
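+    // Widening the accessor to public lets ScanStateRow hand the options list back
+    // out of the row without a serialization round-trip.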
+    public Object getValue(int ordinal) {
         return ordinalToValue.get(ordinal);
     }
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java
index 3bff36b264f..9fe9af27107 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java
@@ -24,6 +24,7 @@
 import io.delta.kernel.data.Row;
 import io.delta.kernel.types.*;
+import io.delta.kernel.internal.ExtractedVariantOptions;
 import io.delta.kernel.internal.actions.Metadata;
 import io.delta.kernel.internal.actions.Protocol;
 import io.delta.kernel.internal.util.ColumnMapping;
@@ -41,7 +42,12 @@ public class ScanStateRow extends GenericRow {
         .add("partitionColumns", new ArrayType(StringType.STRING, false))
         .add("minReaderVersion", IntegerType.INTEGER)
-        .add("minWriterVersion", IntegerType.INTEGER)
-        .add("tablePath", StringType.STRING);
+        .add("minWriterVersion", IntegerType.INTEGER)
+        .add("tablePath", StringType.STRING)
+        .add("extractedVariantFields", new ArrayType(new StructType()
+            .add("path", StringType.STRING, false)
+            .add("fieldName", StringType.STRING, false)
+            .add("type", StringType.STRING, false), false)
+        );
     private static final Map<String, Integer> COL_NAME_TO_ORDINAL =
         IntStream.range(0, SCHEMA.length())
@@ -54,7 +60,8 @@ public static ScanStateRow of(
         String readSchemaLogicalJson,
         String readSchemaPhysicalJson,
         String readPhysicalDataSchemaJson,
-        String tablePath) {
+        String tablePath,
+        List<ExtractedVariantOptions> extractedVariantOptions) {
         HashMap<Integer, Object> valueMap = new HashMap<>();
         valueMap.put(COL_NAME_TO_ORDINAL.get("configuration"), metadata.getConfigurationMapValue());
         valueMap.put(COL_NAME_TO_ORDINAL.get("logicalSchemaString"), readSchemaLogicalJson);
@@ -65,6 +72,8 @@ public static ScanStateRow of(
         valueMap.put(COL_NAME_TO_ORDINAL.get("minReaderVersion"), protocol.getMinReaderVersion());
         valueMap.put(COL_NAME_TO_ORDINAL.get("minWriterVersion"), protocol.getMinWriterVersion());
         valueMap.put(COL_NAME_TO_ORDINAL.get("tablePath"), tablePath);
+        // TODO: HACK to not have to serialize and deserialize the ExtractedVariantOptions list.
+        valueMap.put(COL_NAME_TO_ORDINAL.get("extractedVariantFields"), extractedVariantOptions);
         return new ScanStateRow(valueMap);
     }
@@ -147,4 +156,10 @@ public static String getColumnMappingMode(Row scanState) {
     public static String getTableRoot(Row scanState) {
         return scanState.getString(COL_NAME_TO_ORDINAL.get("tablePath"));
     }
+
+    // TODO: HACK to not have to serialize and deserialize the ExtractedVariantOptions list.
+    public static List<ExtractedVariantOptions> getExtractedVariantFields(Row scanState) {
+        return (List<ExtractedVariantOptions>) ((ScanStateRow) scanState).getValue(
+            COL_NAME_TO_ORDINAL.get("extractedVariantFields"));
+    }
 }
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java
new file mode 100644
index 00000000000..ff61f7da5bc
--- /dev/null
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright (2024) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.delta.kernel.internal.util;
+
+import java.util.Arrays;
+import java.util.List;
+
+import io.delta.kernel.client.ExpressionHandler;
+import io.delta.kernel.data.ColumnVector;
+import io.delta.kernel.data.ColumnarBatch;
+import io.delta.kernel.expressions.*;
+import io.delta.kernel.types.*;
+
+import io.delta.kernel.internal.ExtractedVariantOptions;
+
+/**
+ * Utilities for evaluating extracted variant fields while transforming scan data.
+ */
+public class VariantUtils {
+
+    /**
+     * Evaluates a variant_get expression for each requested extraction and appends
+     * the result to {@code dataBatch} as a new top-level column.
+     */
+    public static ColumnarBatch withExtractedVariantFields(
+        ExpressionHandler expressionHandler,
+        ColumnarBatch dataBatch,
+        List<ExtractedVariantOptions> extractedVariantFields) {
+        for (ExtractedVariantOptions opts : extractedVariantFields) {
+            // TODO: how does this work with column mapping? We're searching for the fieldName in
+            // the delta schema
+            String varColName = opts.path.getNames()[0];
+            int colIdx = dataBatch.getSchema().indexOf(varColName);
+            if (colIdx == -1) {
+                System.out.println("TOP LEVEL VARIANT IS NOT FOUND IN SCHEMA");
+                assert false;
+            }
+
+            ExpressionEvaluator evaluator = expressionHandler.getEvaluator(
+                getVariantGetExprSchema(varColName),
+                new ScalarExpression(
+                    "variant_get",
+                    Arrays.asList(
+                        new Column(varColName),
+                        Literal.ofString(String.join(".", opts.path.getNames())),
+                        // TODO: Does "toString" work on more complex types?
+                        Literal.ofString(opts.type.toString()))
+                ),
+                opts.type
+            );
+
+            dataBatch = dataBatch.withNewColumn(
+                dataBatch.getSchema().length(),
+                // TODO: set this to the right datatype.
+                new StructField(opts.fieldName, StringType.STRING, true),
+                // TODO: we don't have to pass in the whole data batch, we only need the
+                // variant column.
+                extractVariantField(evaluator, dataBatch, opts)
+            );
+        }
+
+        return dataBatch;
+    }
+
+    private static ColumnVector extractVariantField(
+        ExpressionEvaluator evaluator,
+        ColumnarBatch batchWithVariantCol,
+        ExtractedVariantOptions options) {
+        return evaluator.eval(batchWithVariantCol);
+    }
+
+    private static StructType getVariantGetExprSchema(String variantColName) {
+        return new StructType()
+            .add(new StructField(variantColName, VariantType.VARIANT, true))
+            .add(new StructField("col_1", StringType.STRING, false))
+            .add(new StructField("col_2", StringType.STRING, false));
+    }
+}
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java
index bba4e950324..82b66a0c13d 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java
@@ -37,6 +37,8 @@ public class StructField {
      */
     private static String IS_METADATA_COLUMN_KEY = "isMetadataColumn";
+    private static String IS_INTERNALLY_ADDED_VARIANT = "isInternallyAddedVariant";
+
     /**
      * The name of a row index metadata column. When present this column must be populated with
      * row index of each row when reading from parquet.
@@ -48,6 +50,14 @@ public class StructField {
         false,
         FieldMetadata.builder().putBoolean(IS_METADATA_COLUMN_KEY, true).build());
 
+    public static StructField internallyAddedVariantSchema(String fieldName) {
+        return new StructField(
+            fieldName,
+            VariantType.VARIANT,
+            true,
+            FieldMetadata.builder().putBoolean(IS_INTERNALLY_ADDED_VARIANT, true).build());
+    }
+
     ////////////////////////////////////////////////////////////////////////////////
     // Instance Fields / Methods
 
@@ -109,6 +119,11 @@ public boolean isMetadataColumn() {
             (boolean) metadata.get(IS_METADATA_COLUMN_KEY);
     }
 
+    public boolean isInternallyAddedVariant() {
+        return metadata.contains(IS_INTERNALLY_ADDED_VARIANT) &&
+            (boolean) metadata.get(IS_INTERNALLY_ADDED_VARIANT);
+    }
+
     public boolean isDataColumn() {
         return !isMetadataColumn();
     }
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
index f801b67b34d..a3d485b43ec 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
@@ -15,6 +15,7 @@
  */
 package io.delta.kernel.defaults.internal.expressions;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
 import java.util.stream.Collectors;
@@ -278,6 +279,21 @@ ExpressionTransformResult visitCoalesce(ScalarExpression coalesce) {
         );
     }
 
+    @Override
+    ExpressionTransformResult visitVariantGet(ScalarExpression variantGet) {
+        Expression transformedVariantInput = visit(childAt(variantGet, 0)).expression;
+        Expression transformedPath = visit(childAt(variantGet, 1)).expression;
+        Expression transformedType = visit(childAt(variantGet, 2)).expression;
+
+        // TODO: actually finish this. Do validations and cast anything if necessary.
+        return new ExpressionTransformResult(
+            new ScalarExpression(
+                "VARIANT_GET",
+                Arrays.asList(transformedVariantInput, transformedPath, transformedType)),
+            // TODO: output type is hardcoded to string for now.
+            StringType.STRING);
+    }
+
     private Predicate validateIsPredicate(
             Expression baseExpression, ExpressionTransformResult result) {
@@ -558,6 +574,16 @@ ColumnVector visitCoalesce(ScalarExpression coalesce) {
         );
     }
 
+    @Override
+    ColumnVector visitVariantGet(ScalarExpression variantGet) {
+        ColumnVector variantResult = visit(childAt(variantGet, 0));
+        ColumnVector pathResult = visit(childAt(variantGet, 1));
+        ColumnVector dataTypeResult = visit(childAt(variantGet, 2));
+
+        // TODO: actually implement variant_get rather than returning the path.
+        return pathResult;
+    }
+
     /**
      * Utility method to evaluate inputs to the binary input expression. Also validates the
      * evaluated expression result {@link ColumnVector}s are of the same size.
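
A note for readers: `ExtractedVariantOptions` is referenced throughout these hunks but defined elsewhere in this patch series. The following is a minimal sketch consistent with how `ScanStateRow` and `VariantUtils` use it (`opts.path`, `opts.type`, `opts.fieldName`); the field names and layout here are inferred from those call sites, not authoritative:

package io.delta.kernel.internal;

import io.delta.kernel.expressions.Column;
import io.delta.kernel.types.DataType;

/** Describes one variant field to extract during a scan (inferred sketch). */
public class ExtractedVariantOptions {
    public final Column path;      // path into the variant column, e.g. new Column(new String[] {"v", "a"})
    public final DataType type;    // requested output type of the extracted field
    public final String fieldName; // name of the column appended to each scanned batch

    public ExtractedVariantOptions(Column path, DataType type, String fieldName) {
        this.path = path;
        this.type = type;
        this.fieldName = fieldName;
    }
}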
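
To make the expression plumbing above concrete, here is a rough sketch of how a connector could evaluate the new scalar expression through the existing `ExpressionHandler` API, mirroring the expression that `VariantUtils` builds. It assumes a batch whose schema has a single variant column `v`; note that the default evaluator currently returns the path vector rather than performing real extraction, per the TODO above:

import java.util.Arrays;

import io.delta.kernel.client.ExpressionHandler;
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.expressions.*;
import io.delta.kernel.types.*;

public class VariantGetExample {
    // Evaluates variant_get(v, 'v.a', 'string') over a batch with one variant column `v`.
    static ColumnVector extractVDotAAsString(
            ExpressionHandler expressionHandler, ColumnarBatch dataBatch) {
        StructType inputSchema = new StructType()
            .add(new StructField("v", VariantType.VARIANT, true));
        ScalarExpression variantGet = new ScalarExpression(
            "variant_get",
            Arrays.asList(
                new Column("v"),
                Literal.ofString("v.a"),      // dotted path into the variant value
                Literal.ofString("string"))); // requested result type, as a string
        ExpressionEvaluator evaluator =
            expressionHandler.getEvaluator(inputSchema, variantGet, StringType.STRING);
        return evaluator.eval(dataBatch);     // one extracted value per input row
    }
}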
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
index bd219f55fda..2393f8dc9e8 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
@@ -59,6 +59,8 @@ abstract class ExpressionVisitor {
 
     abstract R visitCoalesce(ScalarExpression ifNull);
 
+    abstract R visitVariantGet(ScalarExpression variantGet);
+
     final R visit(Expression expression) {
         if (expression instanceof PartitionValueExpression) {
             return visitPartitionValue((PartitionValueExpression) expression);
@@ -105,6 +107,8 @@ private R visitScalarExpression(ScalarExpression expression) {
                 return visitIsNull(new Predicate(name, children));
             case "COALESCE":
                 return visitCoalesce(expression);
+            case "VARIANT_GET":
+                return visitVariantGet(expression);
             default:
                 throw new UnsupportedOperationException(
                     String.format("Scalar expression `%s` is not supported.", name));
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java
index 05308eebc1e..6d726b65c91 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java
@@ -83,6 +83,8 @@ public static Converter createConverter(
         } else if (typeFromClient instanceof TimestampNTZType) {
             return createTimestampNtzConverter(initialBatchSize, typeFromFile);
         } else if (typeFromClient instanceof VariantType) {
+            // TODO: when shredding happens, we need to pass in the `typeFromFile` to read the
+            // shredded paths.
             return new VariantColumnReader(initialBatchSize);
         }
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala
index e08adaadbde..06be771c557 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala
@@ -43,7 +43,7 @@ import io.delta.kernel.{Snapshot, Table}
 import io.delta.kernel.internal.util.InternalUtils
 import io.delta.kernel.internal.InternalScanFileUtils
 import io.delta.kernel.defaults.client.{DefaultJsonHandler, DefaultParquetHandler, DefaultTableClient}
-import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestUtils}
+import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestRow, TestUtils}
 
 class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with SQLHelper {
 
@@ -1541,6 +1541,62 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with
       tableClient = tableClient)
     }
   }
+
+  import io.delta.kernel.Scan
+  import io.delta.kernel.internal.util.Utils.toCloseableIterator
+  import io.delta.kernel.internal.data.ScanStateRow
+  import io.delta.kernel.internal.InternalScanFileUtils
+  import io.delta.kernel.internal.util.Utils.singletonCloseableIterator
+  test("extract variant") {
+    withTable("test_table") {
+      spark.range(0, 10).selectExpr("parse_json(cast(id as string)) as v").write
+        .format("delta")
+        .mode("overwrite")
+        .saveAsTable("test_table")
+      val path = spark.sql("describe table extended `test_table`")
+        .where("col_name = 'Location'")
+        .collect()(0)
+        .getString(1)
+        .replace("file:", "")
+
+      val kernelSchema = tableSchema(path)
+
+      val snapshot = latestSnapshot(path)
+      val scan = snapshot
+        .getScanBuilder(defaultTableClient)
+        .withReadSchema(defaultTableClient, new StructType())
+        .withExtractedVariantField(defaultTableClient, "v", STRING, "extractedField")
+        .build()
+      val scanState = scan.getScanState(defaultTableClient)
+      val physicalReadSchema = ScanStateRow.getPhysicalDataReadSchema(defaultTableClient, scanState)
+      val scanFilesIter = scan.getScanFiles(defaultTableClient)
+      while (scanFilesIter.hasNext()) {
+        val scanFilesBatch = scanFilesIter.next()
+        val scanFileRows = scanFilesBatch.getRows()
+        while (scanFileRows.hasNext()) {
+          val scanFileRow = scanFileRows.next()
+          val fileStatus = InternalScanFileUtils.getAddFileStatus(scanFileRow)
+
+          val physicalDataIter = defaultTableClient.getParquetHandler.readParquetFiles(
+            singletonCloseableIterator(fileStatus),
+            physicalReadSchema,
+            Optional.empty())
+
+          val transformedRowsIter = Scan.transformPhysicalData(
+            defaultTableClient,
+            scanState,
+            scanFileRow,
+            physicalDataIter
+          )
+
+          val transformedRows = transformedRowsIter
+            .asScala.toSeq.map(_.getRows).flatMap(_.toSeq).map(TestRow(_))
+
+          transformedRows.foreach(println) // TODO: assert expected values once variant_get is implemented.
+        }
+      }
+    }
+  }
 }
 
 object ScanSuite {
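
For context on what the test above drives end to end: the extracted-field options stored on the scan state are expected to be picked back up inside `Scan.transformPhysicalData` and applied to every physical batch. A sketch of that wiring under those assumptions follows; the exact location of this hook inside the transform is not shown in this hunk, so treat it as illustrative only:

import java.util.List;

import io.delta.kernel.client.TableClient;
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.data.Row;
import io.delta.kernel.internal.ExtractedVariantOptions;
import io.delta.kernel.internal.data.ScanStateRow;
import io.delta.kernel.internal.util.VariantUtils;

public class TransformHookSketch {
    // Appends the requested extracted variant fields to a physical data batch.
    static ColumnarBatch applyExtractedVariantFields(
            TableClient tableClient, Row scanState, ColumnarBatch dataBatch) {
        List<ExtractedVariantOptions> fields =
            ScanStateRow.getExtractedVariantFields(scanState);
        if (fields.isEmpty()) {
            return dataBatch;
        }
        return VariantUtils.withExtractedVariantFields(
            tableClient.getExpressionHandler(), dataBatch, fields);
    }
}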