diff --git a/build.sbt b/build.sbt index c63954d9040..30cd0c938ae 100644 --- a/build.sbt +++ b/build.sbt @@ -38,6 +38,8 @@ val LATEST_RELEASED_SPARK_VERSION = "3.5.0" val SPARK_MASTER_VERSION = "4.0.0-SNAPSHOT" val sparkVersion = settingKey[String]("Spark version") spark / sparkVersion := getSparkVersion() +kernelDefaults / sparkVersion := getSparkVersion() +goldenTables / sparkVersion := getSparkVersion() // Dependent library versions val defaultSparkVersion = LATEST_RELEASED_SPARK_VERSION @@ -126,6 +128,25 @@ lazy val commonSettings = Seq( unidocSourceFilePatterns := Nil, ) +/** + * Java-/Scala-/Uni-Doc settings aren't working yet against Spark Master. + * 1) delta-spark on Spark Master uses JDK 17. delta-iceberg uses JDK 8 or 11. For some reason, + *    generating delta-spark unidoc compiles delta-iceberg. + * 2) delta-spark unidoc fails to compile. Spark 3.5 is on its classpath, likely due to the + *    iceberg issue above. + */ +def crossSparkProjectSettings(): Seq[Setting[_]] = getSparkVersion() match { + case LATEST_RELEASED_SPARK_VERSION => Seq( + // Java-/Scala-/Uni-Doc Settings + scalacOptions ++= Seq( + "-P:genjavadoc:strictVisibility=true" // hide package private types and methods in javadoc + ), + unidocSourceFilePatterns := Seq(SourceFilePattern("io/delta/tables/", "io/delta/exceptions/")) + ) + + case SPARK_MASTER_VERSION => Seq() +} + /** * Note: we cannot access sparkVersion.value here, since that can only be used within a task or * setting macro. @@ -140,12 +161,6 @@ def crossSparkSettings(): Seq[Setting[_]] = getSparkVersion() match { Compile / unmanagedSourceDirectories += (Compile / baseDirectory).value / "src" / "main" / "scala-spark-3.5", Test / unmanagedSourceDirectories += (Test / baseDirectory).value / "src" / "test" / "scala-spark-3.5", Antlr4 / antlr4Version := "4.9.3", - - // Java-/Scala-/Uni-Doc Settings - scalacOptions ++= Seq( - "-P:genjavadoc:strictVisibility=true" // hide package private types and methods in javadoc - ), - unidocSourceFilePatterns := Seq(SourceFilePattern("io/delta/tables/", "io/delta/exceptions/")) ) case SPARK_MASTER_VERSION => Seq( @@ -170,13 +185,6 @@ def crossSparkSettings(): Seq[Setting[_]] = getSparkVersion() match { "--add-opens=java.base/sun.security.action=ALL-UNNAMED", "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED" ) - - // Java-/Scala-/Uni-Doc Settings - // This isn't working yet against Spark Master. - // 1) delta-spark on Spark Master uses JDK 17. delta-iceberg uses JDK 8 or 11. For some reason, - // generating delta-spark unidoc compiles delta-iceberg - // 2) delta-spark unidoc fails to compile. spark 3.5 is on its classpath. likely due to iceberg - // issue above.
) } @@ -190,6 +198,7 @@ lazy val spark = (project in file("spark")) sparkMimaSettings, releaseSettings, crossSparkSettings(), + crossSparkProjectSettings(), libraryDependencies ++= Seq( // Adding test classifier seems to break transitive resolution of the core dependencies "org.apache.spark" %% "spark-hive" % sparkVersion.value % "provided", @@ -358,6 +367,7 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults")) scalaStyleSettings, javaOnlyReleaseSettings, Test / javaOptions ++= Seq("-ea"), + crossSparkSettings(), libraryDependencies ++= Seq( "org.apache.hadoop" % "hadoop-client-runtime" % hadoopVersion, "com.fasterxml.jackson.core" % "jackson-databind" % "2.13.5", @@ -374,10 +384,10 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults")) "org.openjdk.jmh" % "jmh-core" % "1.37" % "test", "org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" % "test", - "org.apache.spark" %% "spark-hive" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-core" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-catalyst" % defaultSparkVersion % "test" classifier "tests", + "org.apache.spark" %% "spark-hive" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-core" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-catalyst" % sparkVersion.value % "test" classifier "tests", ), javaCheckstyleSettings("kernel/dev/checkstyle.xml"), // Unidoc settings @@ -1072,14 +1082,15 @@ lazy val goldenTables = (project in file("connectors/golden-tables")) name := "golden-tables", commonSettings, skipReleaseSettings, + crossSparkSettings(), libraryDependencies ++= Seq( // Test Dependencies "org.scalatest" %% "scalatest" % scalaTestVersion % "test", "commons-io" % "commons-io" % "2.8.0" % "test", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test", - "org.apache.spark" %% "spark-catalyst" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-core" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test" classifier "tests" + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "test", + "org.apache.spark" %% "spark-catalyst" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-core" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "test" classifier "tests" ) ) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java index 4b776254c81..c41a9781ee4 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java @@ -36,6 +36,7 @@ import io.delta.kernel.internal.util.ColumnMapping; import io.delta.kernel.internal.util.PartitionUtils; import io.delta.kernel.internal.util.Tuple2; +import io.delta.kernel.internal.util.VariantUtils; /** * Represents a scan of a Delta table. @@ -194,6 +195,13 @@ public FilteredColumnarBatch next() { nextDataBatch = nextDataBatch.withDeletedColumnAt(rowIndexOrdinal); } + // Transform physical variant columns (struct of binaries) into logical variant + // columns. 
+ nextDataBatch = VariantUtils.withVariantColumns( + tableClient.getExpressionHandler(), + nextDataBatch + ); + // Add partition columns nextDataBatch = PartitionUtils.withPartitionColumns( diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java index 55c639a0161..2117794462c 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java @@ -175,6 +175,14 @@ default ArrayValue getArray(int rowId) { throw new UnsupportedOperationException("Invalid value request for data type"); } + /** + * Return the variant value located at {@code rowId}. Returns null if the slot for {@code rowId} + * is null. + */ + default VariantValue getVariant(int rowId) { + throw new UnsupportedOperationException("Invalid value request for data type"); + } + /** * Get the child vector associated with the given ordinal. This method is applicable only to the * {@code struct} type columns. diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java index adcacbc0f4c..560f3113d7b 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java @@ -117,4 +117,10 @@ public interface Row { * Throws error if the column at given ordinal is not of map type, */ MapValue getMap(int ordinal); + + /** + * Return the variant value of the column located at the given ordinal. + * Throws error if the column at given ordinal is not of variant type. + */ + VariantValue getVariant(int ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java new file mode 100644 index 00000000000..abf57d54502 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java @@ -0,0 +1,25 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.data; + +/** + * Abstraction to represent a single Variant value in a {@link ColumnVector}.
+ */ +public interface VariantValue { + byte[] getValue(); + + byte[] getMetadata(); +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java index 13339dfcb44..20822255546 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java @@ -50,8 +50,9 @@ public static void validateReadSupportedTable(Protocol protocol, Metadata metada break; case "deletionVectors": // fall through case "timestampNtz": // fall through + case "variantType-dev": // fall through case "vacuumProtocolCheck": // fall through case "v2Checkpoint": break; default: throw DeltaErrors.unsupportedReadFeature(3, readerFeature); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java index ae4793fa479..74e76f979b0 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java @@ -21,6 +21,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.StructType; /** @@ -111,5 +112,10 @@ public MapValue getMap(int ordinal) { return getChild(ordinal).getMap(rowId); } + @Override + public VariantValue getVariant(int ordinal) { + return getChild(ordinal).getVariant(rowId); + } + protected abstract ColumnVector getChild(int ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java index 01a12bb84de..c4d6aeaf8ca 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java @@ -23,6 +23,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.*; /** @@ -134,6 +135,12 @@ public MapValue getMap(int ordinal) { return (MapValue) getValue(ordinal); } + @Override + public VariantValue getVariant(int ordinal) { + throwIfUnsafeAccess(ordinal, VariantType.class, "variant"); + return (VariantValue) getValue(ordinal); + } + private Object getValue(int ordinal) { return ordinalToValue.get(ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java new file mode 100644 index 00000000000..64ec7e670fd --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java @@ -0,0 +1,70 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.delta.kernel.internal.util; + +import java.util.Arrays; + +import io.delta.kernel.client.ExpressionHandler; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.ColumnarBatch; +import io.delta.kernel.expressions.*; +import io.delta.kernel.types.*; + +public class VariantUtils { + public static ColumnarBatch withVariantColumns( + ExpressionHandler expressionHandler, + ColumnarBatch dataBatch) { + for (int i = 0; i < dataBatch.getSchema().length(); i++) { + StructField field = dataBatch.getSchema().at(i); + if (!(field.getDataType() instanceof StructType) && + !(field.getDataType() instanceof ArrayType) && + !(field.getDataType() instanceof MapType) && + (field.getDataType() != VariantType.VARIANT || + dataBatch.getColumnVector(i).getDataType() == VariantType.VARIANT)) { + continue; + } + + ExpressionEvaluator evaluator = expressionHandler.getEvaluator( + // Field here is variant type if it's actually a variant. + // TODO: probably better to pass in the schema as an argument + // so the schema is enforced at the expression level. Need to pass in a literal + // schema + new StructType().add(field), + new ScalarExpression( + "variant_coalesce", + Arrays.asList(new Column(field.getName())) + ), + VariantType.VARIANT + ); + + // TODO: don't need to pass in the entire batch. + ColumnVector variantCol = evaluator.eval(dataBatch); + // TODO: find a more efficient way to do this. + dataBatch = + dataBatch.withDeletedColumnAt(i).withNewColumn(i, field, variantCol); + } + return dataBatch; + } + + private static ColumnVector[] getColumnBatchVectors(ColumnarBatch batch) { + ColumnVector[] res = new ColumnVector[batch.getSchema().length()]; + for (int i = 0; i < batch.getSchema().length(); i++) { + res[i] = batch.getColumnVector(i); + } + return res; + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 7a44836d29a..2859480c3e0 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -202,6 +202,8 @@ private static Object getValueAsObject( return toJavaList(columnVector.getArray(rowId)); } else if (dataType instanceof MapType) { return toJavaMap(columnVector.getMap(rowId)); + } else if (dataType instanceof VariantType) { + return columnVector.getVariant(rowId); } else { throw new UnsupportedOperationException("unsupported data type"); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java index 33affe3d8aa..6141d84a139 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java @@ -64,6 +64,7 @@ public static List<DataType> getAllPrimitiveTypes() { put("timestamp_ntz", TimestampNTZType.TIMESTAMP_NTZ); put("binary", BinaryType.BINARY); put("string", StringType.STRING); + put("variant", VariantType.VARIANT); } }); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java new file mode 100644 index 00000000000..71a84cdb718 --- /dev/null +++
b/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java @@ -0,0 +1,31 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.types; + +import io.delta.kernel.annotation.Evolving; + +/** + * A logical variant type. + * @since 4.0.0 + */ +@Evolving +public class VariantType extends BasePrimitiveType { + public static final VariantType VARIANT = new VariantType(); + + private VariantType() { + super("variant"); + } +} diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala index e4a3136bebf..833ebf80ea4 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala @@ -759,7 +759,7 @@ class SnapshotManagerSuite extends AnyFunSuite with MockFileSystemClientUtils { // corrupt incomplete multi-part checkpoint val corruptedCheckpointStatuses = FileNames.checkpointFileWithParts(logPath, 10, 5).asScala .map(p => FileStatus.of(p.toString, 10, 10)) - .take(4) + .take(4).toSeq val deltas = deltaFileStatuses(10L to 13L) testExpectedError[RuntimeException]( corruptedCheckpointStatuses ++ deltas, @@ -808,7 +808,7 @@ class SnapshotManagerSuite extends AnyFunSuite with MockFileSystemClientUtils { // _last_checkpoint refers to incomplete multi-part checkpoint val corruptedCheckpointStatuses = FileNames.checkpointFileWithParts(logPath, 20, 5).asScala .map(p => FileStatus.of(p.toString, 10, 10)) - .take(4) + .take(4).toSeq testExpectedError[RuntimeException]( files = corruptedCheckpointStatuses ++ deltaFileStatuses(10L to 20L) ++ singularCheckpointFileStatuses(Seq(10L)), diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java index bbada452c0a..8f60b3848ce 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java @@ -33,6 +33,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.*; import io.delta.kernel.internal.util.InternalUtils; @@ -128,6 +129,11 @@ public MapValue getMap(int ordinal) { return (MapValue) parsedValues[ordinal]; } + @Override + public VariantValue getVariant(int ordinal) { + throw new UnsupportedOperationException("not yet implemented"); + } + private static void throwIfTypeMismatch(String expType, boolean hasExpType, JsonNode jsonNode) { if (!hasExpType) { throw new RuntimeException( diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java 
b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java new file mode 100644 index 00000000000..c4c283e9609 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java @@ -0,0 +1,63 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.data.value; + +import java.util.Arrays; + +import io.delta.kernel.data.VariantValue; + +/** + * Default implementation of a Delta kernel VariantValue. + */ +public class DefaultVariantValue implements VariantValue { + private final byte[] value; + private final byte[] metadata; + + public DefaultVariantValue(byte[] value, byte[] metadata) { + this.value = value; + this.metadata = metadata; + } + + @Override + public byte[] getValue() { + return value; + } + + @Override + public byte[] getMetadata() { + return metadata; + } + + @Override + public String toString() { + return "VariantValue{value=" + Arrays.toString(value) + + ", metadata=" + Arrays.toString(metadata) + '}'; + } + + /** + * Compare two variants by their raw bytes. True variant equality is more complex than a byte + * comparison, and we haven't supported it in the user surface yet. This method is only + * intended for tests.
+ */ + @Override + public boolean equals(Object other) { + if (other instanceof DefaultVariantValue) { + return Arrays.equals(value, ((DefaultVariantValue) other).getValue()) && + Arrays.equals(metadata, ((DefaultVariantValue) other).getMetadata()); + } else { + return false; + } + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java index 8196b93a178..7d8a9976813 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java @@ -22,6 +22,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -78,6 +79,10 @@ public boolean isNullAt(int rowId) { return nullability.get()[rowId]; } + public Optional<boolean[]> getNullability() { + return nullability; + } + @Override public boolean getBoolean(int rowId) { throw unsupportedDataAccessException("boolean"); @@ -138,6 +143,11 @@ public ArrayValue getArray(int rowId) { throw unsupportedDataAccessException("array"); } + @Override + public VariantValue getVariant(int rowId) { + throw unsupportedDataAccessException("variant"); + } + // TODO no need to override these here; update default implementations in `ColumnVector` // to have a more informative exception message protected UnsupportedOperationException unsupportedDataAccessException(String accessType) { diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java index bdadea61100..e7f5c5a25c1 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java @@ -88,4 +88,12 @@ public ColumnVector getElements() { } }; } + + public ColumnVector getElementVector() { + return elementVector; + } + + public int[] getOffsets() { + return offsets; + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java index aad449d1eb0..925a11746db 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java @@ -172,6 +172,13 @@ public ColumnVector getChild(int ordinal) { (rowId) -> (Row) rowIdToValueAccessor.apply(rowId)); } + @Override + public VariantValue getVariant(int rowId) { + assertValidRowId(rowId); + throwIfUnsafeAccess(VariantType.class, "variant"); + return (VariantValue) rowIdToValueAccessor.apply(rowId); + } + private void throwIfUnsafeAccess( Class<? extends DataType> expDataType, String accessType) { if (!expDataType.isAssignableFrom(dataType.getClass())) { String msg = String.format( diff --git
a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java index ee0a0d4feba..403bcd56582 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java @@ -98,4 +98,16 @@ public ColumnVector getValues() { } }; } + + public ColumnVector getKeyVector() { + return keyVector; + } + + public ColumnVector getValueVector() { + return valueVector; + } + + public int[] getOffsets() { + return offsets; + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java index a0aac0f12b1..1656d572f8a 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java @@ -23,6 +23,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; @@ -155,6 +156,12 @@ public ArrayValue getArray(int rowId) { return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal); } + @Override + public VariantValue getVariant(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getVariant(columnOrdinal); + } + @Override public ColumnVector getChild(int childOrdinal) { StructType structType = (StructType) dataType; diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java new file mode 100644 index 00000000000..a12c426fda3 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java @@ -0,0 +1,93 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.data.vector; + +import java.util.Optional; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.VariantValue; +import io.delta.kernel.types.DataType; + +import static io.delta.kernel.internal.util.Preconditions.checkArgument; +import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue; + +/** + * {@link io.delta.kernel.data.ColumnVector} implementation for variant type data. 
+ */ +public class DefaultVariantVector + extends AbstractColumnVector { + private final ColumnVector valueVector; + private final ColumnVector metadataVector; + + + /** + * Create an instance of {@link io.delta.kernel.data.ColumnVector} for variant type. + * + * @param size number of elements in the vector. + * @param type {@code variant} datatype definition. + * @param nullability Optional array of nullability values for each element in the vector. + * All values in the vector are considered non-null when parameter is + * empty. + * @param value The child binary column vector representing each variant's values. + * @param metadata The child binary column vector representing each variant's metadata. + */ + public DefaultVariantVector( + int size, + DataType type, + Optional<boolean[]> nullability, + ColumnVector value, + ColumnVector metadata) { + super(size, type, nullability); + this.valueVector = requireNonNull(value, "value is null"); + this.metadataVector = requireNonNull(metadata, "metadata is null"); + } + + /** + * Get the variant value at the given {@code rowId}. Returns null if the slot for + * {@code rowId} is null. + * + * @param rowId + * @return + */ + @Override + public VariantValue getVariant(int rowId) { + checkValidRowId(rowId); + if (isNullAt(rowId)) { + return null; + } + + return new DefaultVariantValue( + valueVector.getBinary(rowId), metadataVector.getBinary(rowId)); + } + + /** + * Get the child column vector at the given {@code ordinal}. Variants should only have two + * child vectors, one for value and one for metadata. + * + * @param ordinal + * @return + */ + @Override + public ColumnVector getChild(int ordinal) { + checkArgument(ordinal >= 0 && ordinal < 2, "Invalid ordinal " + ordinal); + if (ordinal == 0) { + return valueVector; + } else { + return metadataVector; + } + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java index 49c1fe00a2b..f256c6300c9 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java @@ -20,6 +20,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -137,6 +138,12 @@ public ArrayValue getArray(int rowId) { return underlyingVector.getArray(offset + rowId); } + @Override + public VariantValue getVariant(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getVariant(offset + rowId); + } + @Override public ColumnVector getChild(int ordinal) { return new DefaultViewVector(underlyingVector.getChild(ordinal), offset, offset + size); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index f801b67b34d..b0a864e7fd3 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++
b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -15,6 +15,7 @@ */ package io.delta.kernel.defaults.internal.expressions; +import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -33,8 +34,7 @@ import static io.delta.kernel.internal.util.ExpressionUtils.getUnaryChild; import static io.delta.kernel.internal.util.Preconditions.checkArgument; -import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; -import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; +import io.delta.kernel.defaults.internal.data.vector.*; import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector; import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.childAt; import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.compare; @@ -48,6 +48,7 @@ */ public class DefaultExpressionEvaluator implements ExpressionEvaluator { private final Expression expression; + private final StructType inputSchema; /** * Create a {@link DefaultExpressionEvaluator} instance bound to the given expression and @@ -68,12 +69,14 @@ public DefaultExpressionEvaluator( "Can not create an expression handler returns result of type %s", outputType); throw DeltaErrors.unsupportedExpression(expression, Optional.of(reason)); } + // TODO(richardc-db): Hack to avoid needing to pass the schema into the expression. + this.inputSchema = inputSchema; this.expression = transformResult.expression; } @Override public ColumnVector eval(ColumnarBatch input) { - return new ExpressionEvalVisitor(input).visit(expression); + return new ExpressionEvalVisitor(input, inputSchema).visit(expression); } @Override @@ -278,6 +281,21 @@ ExpressionTransformResult visitCoalesce(ScalarExpression coalesce) { ); } + @Override + ExpressionTransformResult visitVariantCoalesce(ScalarExpression variantCoalesce) { + checkArgument( + variantCoalesce.getChildren().size() == 1, + "Expected one input to 'variant_coalesce' but received %s", + variantCoalesce.getChildren().size() + ); + Expression transformedVariantInput = visit(childAt(variantCoalesce, 0)).expression; + return new ExpressionTransformResult( + new ScalarExpression( + "VARIANT_COALESCE", + Arrays.asList(transformedVariantInput)), + VariantType.VARIANT); + } + private Predicate validateIsPredicate( Expression baseExpression, ExpressionTransformResult result) { @@ -318,9 +336,11 @@ private Expression transformBinaryComparator(Predicate predicate) { */ private static class ExpressionEvalVisitor extends ExpressionVisitor<ColumnVector> { private final ColumnarBatch input; + private final StructType inputSchema; - ExpressionEvalVisitor(ColumnarBatch input) { + ExpressionEvalVisitor(ColumnarBatch input, StructType inputSchema) { this.input = input; + this.inputSchema = inputSchema; } /* @@ -558,6 +578,119 @@ ColumnVector visitCoalesce(ScalarExpression coalesce) { ); } + @Override + ColumnVector visitVariantCoalesce(ScalarExpression variantCoalesce) { + return variantCoalesceImpl( + visit(childAt(variantCoalesce, 0)), + inputSchema.at(0).getDataType() + ); + } + + private ColumnVector variantCoalesceImpl(ColumnVector inputVec, DataType dt) { + if (dt instanceof StructType) { + StructType structType = (StructType) dt; + DefaultStructVector structVec = (DefaultStructVector) inputVec; + ColumnVector[] structColVecs = new ColumnVector[structType.length()]; + for (int i = 0; i <
structType.length(); i++) { + if (structType.at(i).getDataType() instanceof ArrayType || + structType.at(i).getDataType() instanceof StructType || + structType.at(i).getDataType() instanceof MapType || + structType.at(i).getDataType() instanceof VariantType) { + structColVecs[i] = variantCoalesceImpl( + structVec.getChild(i), + structType.at(i).getDataType() + ); + } else { + structColVecs[i] = structVec.getChild(i); + } + } + return new DefaultStructVector( + structVec.getSize(), + structType, + structVec.getNullability(), + structColVecs + ); + } + + if (dt instanceof ArrayType) { + ArrayType arrType = (ArrayType) dt; + DefaultArrayVector arrVec = (DefaultArrayVector) inputVec; + + if (arrType.getElementType() instanceof ArrayType || + arrType.getElementType() instanceof StructType || + arrType.getElementType() instanceof MapType || + arrType.getElementType() instanceof VariantType) { + ColumnVector elementVec = variantCoalesceImpl( + arrVec.getElementVector(), + arrType.getElementType() + ); + + return new DefaultArrayVector( + arrVec.getSize(), + arrType, + arrVec.getNullability(), + arrVec.getOffsets(), + elementVec + ); + } + return arrVec; + } + + if (dt instanceof MapType) { + MapType mapType = (MapType) dt; + DefaultMapVector mapVec = (DefaultMapVector) inputVec; + + ColumnVector valueVec = mapVec.getValueVector(); + if (mapType.getValueType() instanceof ArrayType || + mapType.getValueType() instanceof StructType || + mapType.getValueType() instanceof MapType || + mapType.getValueType() instanceof VariantType) { + valueVec = variantCoalesceImpl( + mapVec.getValueVector(), + mapType.getValueType() + ); + } + ColumnVector keyVec = mapVec.getKeyVector(); + if (mapType.getKeyType() instanceof ArrayType || + mapType.getKeyType() instanceof StructType || + mapType.getKeyType() instanceof MapType || + mapType.getKeyType() instanceof VariantType) { + keyVec = variantCoalesceImpl( + mapVec.getKeyVector(), + mapType.getKeyType() + ); + } + return new DefaultMapVector( + mapVec.getSize(), + mapType, + mapVec.getNullability(), + mapVec.getOffsets(), + keyVec, + valueVec + ); + } + + DefaultStructVector structBackingVariant = (DefaultStructVector) inputVec; + checkArgument( + structBackingVariant.getChild(0).getDataType() instanceof BinaryType, + "Expected struct field 0 backing variant to be binary type. Actual: %s", + structBackingVariant.getChild(0).getDataType() + ); + checkArgument( + structBackingVariant.getChild(1).getDataType() instanceof BinaryType, + "Expected struct field 1 backing variant to be binary type. Actual: %s", + structBackingVariant.getChild(1).getDataType() + ); + + return new DefaultVariantVector( + structBackingVariant.getSize(), + VariantType.VARIANT, + structBackingVariant.getNullability(), + structBackingVariant.getChild(0), + structBackingVariant.getChild(1) + ); + } + /** * Utility method to evaluate inputs to the binary input expression. Also validates the * evaluated expression result {@link ColumnVector}s are of the same size. 
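For orientation, here is a minimal, self-contained sketch of how the variant_coalesce path above is exercised end to end: a physical batch whose variant columns arrive from Parquet as struct<value: binary, metadata: binary> is rewritten into logical variant vectors, then read back through the new getVariant accessor. The example class and method names, the tableClient/physicalBatch inputs, and the assumption that column 0 is the variant column are all illustrative; the sketch uses only APIs introduced or referenced in this diff.

import io.delta.kernel.client.TableClient;
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.data.VariantValue;
import io.delta.kernel.internal.util.VariantUtils;

public final class VariantCoalesceExample {
    // Rewrites any physical variant columns in `physicalBatch` into logical
    // variant vectors, then inspects row 0 of column 0.
    public static void printFirstVariant(TableClient tableClient, ColumnarBatch physicalBatch) {
        ColumnarBatch logicalBatch = VariantUtils.withVariantColumns(
            tableClient.getExpressionHandler(), physicalBatch);

        // Assumes column 0 is declared as VariantType in the read schema.
        VariantValue v = logicalBatch.getColumnVector(0).getVariant(0);
        if (v != null) { // a null slot yields null rather than a VariantValue
            System.out.println(v.getValue().length + " value bytes, "
                + v.getMetadata().length + " metadata bytes");
        }
    }
}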
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index bd219f55fda..d715f8ff918 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -59,6 +59,8 @@ abstract class ExpressionVisitor<R> { abstract R visitCoalesce(ScalarExpression ifNull); + abstract R visitVariantCoalesce(ScalarExpression variantCoalesce); + final R visit(Expression expression) { if (expression instanceof PartitionValueExpression) { return visitPartitionValue((PartitionValueExpression) expression); @@ -105,6 +107,8 @@ private R visitScalarExpression(ScalarExpression expression) { return visitIsNull(new Predicate(name, children)); case "COALESCE": return visitCoalesce(expression); + case "VARIANT_COALESCE": + return visitVariantCoalesce(expression); default: throw new UnsupportedOperationException( String.format("Scalar expression `%s` is not supported.", name)); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java index dce7c5244c3..03416ddae7a 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java @@ -82,6 +82,13 @@ public static Converter createConverter( return createTimestampConverter(initialBatchSize, typeFromFile); } else if (typeFromClient instanceof TimestampNTZType) { return createTimestampNtzConverter(initialBatchSize, typeFromFile); + } else if (typeFromClient instanceof VariantType) { + return new RowColumnReader( + initialBatchSize, + new StructType() + .add("value", BinaryType.BINARY, false) + .add("metadata", BinaryType.BINARY, false), + (GroupType) typeFromFile); } throw new UnsupportedOperationException(typeFromClient + " is not supported"); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java index f3712b18fb7..0d71ffe9f3b 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java @@ -242,6 +242,8 @@ private static Type toParquetType( type = toParquetMapType((MapType) dataType, name, repetition); } else if (dataType instanceof StructType) { type = toParquetStructType((StructType) dataType, name, repetition); + } else if (dataType instanceof VariantType) { + type = toParquetVariantType(name, repetition); } else { throw new UnsupportedOperationException( "Writing given type data to Parquet is not supported: " + dataType); @@ -303,6 +305,13 @@ private static Type toParquetStructType(StructType structType, String name, return new GroupType(repetition, name, fields); } + private static Type toParquetVariantType(String name, Repetition repetition) { + return Types.buildGroup(repetition) + .addField(toParquetType(BinaryType.BINARY, "value", REQUIRED, Optional.empty())) +
.addField(toParquetType(BinaryType.BINARY, "metadata", REQUIRED, Optional.empty())) + .named(name); + } + /** * Recursively checks whether the given data type has any Parquet field ids in it. */ diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java index 1474bd2db67..5b76dc88687 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java @@ -26,6 +26,7 @@ import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.StructType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -165,5 +166,10 @@ public MapValue getMap(int ordinal) { throw new UnsupportedOperationException( "map type unsupported for TestColumnBatchBuilder; use scala test utilities"); } + + @Override + public VariantValue getVariant(int ordinal) { + return (VariantValue) values.get(ordinal); + } } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala index a071338226d..d57d8ccc7a0 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala @@ -303,12 +303,12 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils { Seq(TestRow(2), TestRow(2), TestRow(2)), TestRow("2", "2", TestRow(2, 2L)), "2" - ) :: Nil) + ) :: Nil).toSeq checkTable( path = path, expectedAnswer = expectedAnswer, - readCols = readCols + readCols = readCols.toSeq ) } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayMetricsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayMetricsSuite.scala index 51193d3ac90..f82eefe7b65 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayMetricsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayMetricsSuite.scala @@ -342,7 +342,7 @@ trait FileReadMetrics { self: Object => } } - def getVersionsRead: Seq[Long] = versionsRead + def getVersionsRead: Seq[Long] = versionsRead.toSeq def resetMetrics(): Unit = { versionsRead.clear() diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala index 4a88d8d1400..6e6f76700e7 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala @@ -21,18 +21,21 @@ import java.time.{Instant, OffsetDateTime} import java.time.temporal.ChronoUnit import java.util.Optional +import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ import io.delta.golden.GoldenTableUtils.goldenTablePath import org.apache.hadoop.conf.Configuration -import org.apache.spark.sql.{Row => SparkRow} +import org.apache.spark.sql.{DataFrame, Row => SparkRow} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.delta.{DeltaConfigs, DeltaLog} import 
org.apache.spark.sql.types.{IntegerType => SparkIntegerType, StructField => SparkStructField, StructType => SparkStructType} import org.scalatest.funsuite.AnyFunSuite +import io.delta.kernel.Scan import io.delta.kernel.client.{JsonHandler, ParquetHandler, TableClient} import io.delta.kernel.data.{ColumnarBatch, ColumnVector, FilteredColumnarBatch, Row} +import io.delta.kernel.defaults.utils.TestRow import io.delta.kernel.expressions.{AlwaysFalse, AlwaysTrue, And, Column, Or, Predicate, ScalarExpression} import io.delta.kernel.expressions.Literal._ import io.delta.kernel.types.StructType @@ -40,8 +43,11 @@ import io.delta.kernel.types.StringType.STRING import io.delta.kernel.types.IntegerType.INTEGER import io.delta.kernel.utils.{CloseableIterator, FileStatus} import io.delta.kernel.{Scan, Snapshot, Table} -import io.delta.kernel.internal.util.InternalUtils import io.delta.kernel.internal.{InternalScanFileUtils, ScanImpl} +import io.delta.kernel.internal.data.ScanStateRow +import io.delta.kernel.internal.util.InternalUtils +import io.delta.kernel.internal.util.Utils.singletonCloseableIterator +import io.delta.kernel.internal.util.Utils.toCloseableIterator import io.delta.kernel.defaults.client.{DefaultJsonHandler, DefaultParquetHandler, DefaultTableClient} import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestUtils} @@ -1588,6 +1594,99 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with ) } } + + private def testReadWithVariant(testName: String)(df: => DataFrame): Unit = { + test(testName) { + withTable("test_table") { + df.write + .format("delta") + .mode("overwrite") + .saveAsTable("test_table") + val path = spark.sql("describe table extended `test_table`") + .where("col_name = 'Location'") + .collect()(0) + .getString(1) + .replace("file:", "") + + val kernelSchema = tableSchema(path) + + val snapshot = latestSnapshot(path) + val scan = snapshot.getScanBuilder(defaultTableClient).build() + val scanState = scan.getScanState(defaultTableClient) + val physicalReadSchema = + ScanStateRow.getPhysicalDataReadSchema(defaultTableClient, scanState) + val scanFilesIter = scan.getScanFiles(defaultTableClient) + + val readRows = ArrayBuffer[Row]() + while (scanFilesIter.hasNext()) { + val scanFilesBatch = scanFilesIter.next() + val scanFileRows = scanFilesBatch.getRows() + while (scanFileRows.hasNext()) { + val scanFileRow = scanFileRows.next() + val fileStatus = InternalScanFileUtils.getAddFileStatus(scanFileRow) + + val physicalDataIter = defaultTableClient.getParquetHandler.readParquetFiles( + singletonCloseableIterator(fileStatus), + physicalReadSchema, + Optional.empty()) + + val transformedRowsIter = Scan.transformPhysicalData( + defaultTableClient, + scanState, + scanFileRow, + physicalDataIter + ) + + val transformedRows = transformedRowsIter.asScala.toSeq.map(_.getRows).flatMap(_.toSeq) + readRows.appendAll(transformedRows) + } + } + + checkAnswer(readRows.toSeq, df.collect().map(TestRow(_))) + } + } + } + + testReadWithVariant("basic variant") { + spark.range(0, 1, 1, 1).selectExpr( + "parse_json(cast(id as string)) as basic_v", + "named_struct('v', parse_json(cast(id as string))) as struct_v", + "named_struct('v', array(parse_json(cast(id as string)))) as struct_array_v", + "named_struct('v', map('key', parse_json(cast(id as string)))) as struct_map_v", + "named_struct('top', named_struct('v', parse_json(cast(id as string)))) as struct_struct_v", + """array( + parse_json(cast(id as string)), + parse_json(cast(id as string)), + parse_json(cast(id 
as string)) + ) as array_v""", + """array( + named_struct('v', parse_json(cast(id as string))), + named_struct('v', parse_json(cast(id as string))), + named_struct('v', parse_json(cast(id as string))) + ) as array_struct_v""", + """array( + map('v', parse_json(cast(id as string))), + map('k1', parse_json(cast(id as string)), 'k2', parse_json(cast(id as string))), + map('v', parse_json(cast(id as string))) + ) as array_map_v""", + "map('test', parse_json(cast(id as string))) as map_value_v", + "map('test', named_struct('v', parse_json(cast(id as string)))) as map_struct_v", + "map(parse_json(cast(id as string)), parse_json(cast(id as string))) as map_key_v" + ) + } + + testReadWithVariant("basic null variant") { + spark.range(0, 10, 1, 1).selectExpr( + "cast(null as variant) basic_v", + "named_struct('v', cast(null as variant)) as struct_v", + """array( + parse_json(cast(id as string)), + parse_json(cast(id as string)), + null + ) as array_v""", + "map('test', cast(null as variant)) as map_value_v" + ) + } } object ScanSuite { diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala index ae307f91ea7..44889806732 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileReaderSuite.scala @@ -16,6 +16,9 @@ package io.delta.kernel.defaults.internal.parquet import java.math.BigDecimal + +import org.apache.spark.sql.DataFrame + import io.delta.golden.GoldenTableUtils.goldenTableFile import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestRow} import io.delta.kernel.test.VectorTestUtils diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index 661d286a3c9..cf87326ff34 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -16,9 +16,12 @@ package io.delta.kernel.defaults.utils import scala.collection.JavaConverters._ +import scala.collection.mutable.{Seq => MutableSeq} import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.{Row => SparkRow} +import org.apache.spark.unsafe.types.VariantVal import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} +import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue import io.delta.kernel.types._ import java.sql.Timestamp @@ -44,7 +47,7 @@ import java.time.{Instant, LocalDate, LocalDateTime, ZoneOffset} * - ArrayType --> Seq[Any] * - MapType --> Map[Any, Any] * - StructType --> TestRow - * + * - VariantType --> VariantVal * For complex types array and map, the inner elements types should align with this mapping. 
*/ class TestRow(val values: Array[Any]) { @@ -108,9 +111,10 @@ object TestRow { case _: ArrayType => arrayValueToScalaSeq(row.getArray(i)) case _: MapType => mapValueToScalaMap(row.getMap(i)) case _: StructType => TestRow(row.getStruct(i)) + case _: VariantType => row.getVariant(i) case _ => throw new UnsupportedOperationException("unrecognized data type") } - }) + }.toSeq) } def apply(row: SparkRow): TestRow = { @@ -133,13 +137,16 @@ object TestRow { case _: sparktypes.BinaryType => obj.asInstanceOf[Array[Byte]] case _: sparktypes.DecimalType => obj.asInstanceOf[java.math.BigDecimal] case arrayType: sparktypes.ArrayType => - obj.asInstanceOf[Seq[Any]] + obj.asInstanceOf[MutableSeq[Any]] .map(decodeCellValue(arrayType.elementType, _)) case mapType: sparktypes.MapType => obj.asInstanceOf[Map[Any, Any]].map { case (k, v) => decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v) } case _: sparktypes.StructType => TestRow(obj.asInstanceOf[SparkRow]) + case _: sparktypes.VariantType => + val sparkVariant = obj.asInstanceOf[VariantVal] + new DefaultVariantValue(sparkVariant.getValue(), sparkVariant.getMetadata()) case _ => throw new UnsupportedOperationException("unrecognized data type") } } @@ -173,6 +180,9 @@ object TestRow { decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v) } case _: sparktypes.StructType => TestRow(row.getStruct(i)) + case _: sparktypes.VariantType => + val sparkVariant = row.getAs[VariantVal](i) + new DefaultVariantValue(sparkVariant.getValue(), sparkVariant.getMetadata()) case _ => throw new UnsupportedOperationException("unrecognized data type") } }) @@ -204,6 +214,7 @@ object TestRow { TestRow.fromSeq(Seq.range(0, dataType.length()).map { ordinal => getAsTestObject(vector.getChild(ordinal), rowId) }) + case _: VariantType => vector.getVariant(rowId) case _ => throw new UnsupportedOperationException("unrecognized data type") } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index 7a11471fff6..d891c56af30 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -73,7 +73,7 @@ trait TestUtils extends Assertions with SQLHelper { while (iter.hasNext) { result.append(iter.next()) } - result + result.toSeq } finally { iter.close() } @@ -117,6 +117,17 @@ trait TestUtils extends Assertions with SQLHelper { lazy val classLoader: ClassLoader = ResourceLoader.getClass.getClassLoader } + /** + * Drops tables `tableNames` after calling `f`.
+ */ + def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + def withGoldenTable(tableName: String)(testFunc: String => Unit): Unit = { val tablePath = GoldenTableUtils.goldenTablePath(tableName) testFunc(tablePath) @@ -153,7 +164,7 @@ trait TestUtils extends Assertions with SQLHelper { // for all primitive types Seq(new Column((basePath :+ field.getName).asJava.toArray(new Array[String](0)))); case _ => Seq.empty - } + }.toSeq } def collectScanFileRows(scan: Scan, tableClient: TableClient = defaultTableClient): Seq[Row] = { @@ -231,7 +242,7 @@ trait TestUtils extends Assertions with SQLHelper { } } } - result + result.toSeq } /** @@ -638,7 +649,8 @@ trait TestUtils extends Assertions with SQLHelper { toSparkType(field.getDataType), field.isNullable ) - }) + }.toSeq) + case VariantType.VARIANT => sparktypes.DataTypes.VariantType } } diff --git a/spark/src/main/scala-spark-3.5/shims/VariantShim.scala b/spark/src/main/scala-spark-3.5/shims/VariantShim.scala new file mode 100644 index 00000000000..d802146eba1 --- /dev/null +++ b/spark/src/main/scala-spark-3.5/shims/VariantShim.scala @@ -0,0 +1,26 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +object VariantShim { + + /** + * Spark's variant type is implemented for Spark 4.0 and is not implemented in Spark 3.5. Thus, + * any Spark 3.5 DataType cannot be a variant type. + */ + def isTypeVariant(dt: DataType): Boolean = false +} diff --git a/spark/src/main/scala-spark-master/shims/VariantShim.scala b/spark/src/main/scala-spark-master/shims/VariantShim.scala new file mode 100644 index 00000000000..63285b58466 --- /dev/null +++ b/spark/src/main/scala-spark-master/shims/VariantShim.scala @@ -0,0 +1,23 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +object VariantShim { + + /** Spark's variant type is only implemented in Spark 4.0 and above. 
+   */
+  def isTypeVariant(dt: DataType): Boolean = dt.isInstanceOf[VariantType]
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/TableFeature.scala b/spark/src/main/scala/org/apache/spark/sql/delta/TableFeature.scala
index 47b0b6ec8ec..3560d2d771d 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/TableFeature.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/TableFeature.scala
@@ -361,7 +361,8 @@ object TableFeature {
       // Identity columns are under development and only available in testing.
       IdentityColumnsTableFeature,
       // managed-commits are under development and only available in testing.
-      ManagedCommitTableFeature)
+      ManagedCommitTableFeature,
+      VariantTypeTableFeature)
   }
   val featureMap = features.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap
   require(features.size == featureMap.size, "Lowercase feature names must not duplicate.")
@@ -509,6 +510,14 @@ object IdentityColumnsTableFeature
   }
 }

+object VariantTypeTableFeature extends ReaderWriterFeature(name = "variantType-dev")
+  with FeatureAutomaticallyEnabledByMetadata {
+  override def metadataRequiresFeatureToBeEnabled(
+      metadata: Metadata, spark: SparkSession): Boolean = {
+    SchemaUtils.checkForVariantTypeColumnsRecursively(metadata.schema)
+  }
+}
+
 object TimestampNTZTableFeature extends ReaderWriterFeature(name = "timestampNtz")
   with FeatureAutomaticallyEnabledByMetadata {
   override def metadataRequiresFeatureToBeEnabled(
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala
index 038a77050d2..83af73ce9c2 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala
@@ -1264,6 +1264,13 @@ def normalizeColumnNamesInDataType(
     unsupportedDataTypes.toSeq
   }

+  /**
+   * Returns true if the table schema contains a VariantType column, at any nesting level.
+   */
+  def checkForVariantTypeColumnsRecursively(schema: StructType): Boolean = {
+    SchemaUtils.typeExistsRecursively(schema)(VariantShim.isTypeVariant(_))
+  }
+
   /**
    * Find TimestampNTZ columns in the table schema.
    */
diff --git a/spark/src/test/scala-spark-3.5/shims/DeltaVariantSparkOnlyTests.scala b/spark/src/test/scala-spark-3.5/shims/DeltaVariantSparkOnlyTests.scala
new file mode 100644
index 00000000000..a7edc237bca
--- /dev/null
+++ b/spark/src/test/scala-spark-3.5/shims/DeltaVariantSparkOnlyTests.scala
@@ -0,0 +1,19 @@
+/*
+ * Copyright (2024) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta
+
+trait DeltaVariantSparkOnlyTests { self: DeltaVariantSuite => }
diff --git a/spark/src/test/scala-spark-master/shims/DeltaVariantSparkOnlyTests.scala b/spark/src/test/scala-spark-master/shims/DeltaVariantSparkOnlyTests.scala
new file mode 100644
index 00000000000..09ddabbee48
--- /dev/null
+++ b/spark/src/test/scala-spark-master/shims/DeltaVariantSparkOnlyTests.scala
@@ -0,0 +1,109 @@
+/*
+ * Copyright (2024) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta
+
+import org.apache.spark.SparkThrowable
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.delta.actions.{Protocol, TableFeatureProtocolUtils}
+import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
+import org.apache.spark.sql.types.StructType
+
+trait DeltaVariantSparkOnlyTests
+  extends QueryTest
+  with DeltaSQLCommandTest { self: DeltaVariantSuite =>
+  private def getProtocolForTable(table: String): Protocol = {
+    val deltaLog = DeltaLog.forTable(spark, TableIdentifier(table))
+    deltaLog.unsafeVolatileSnapshot.protocol
+  }
+
+  test("create a new table with Variant; the required protocol and feature should be picked") {
+    withTable("tbl") {
+      sql("CREATE TABLE tbl(s STRING, v VARIANT) USING DELTA")
+      sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))")
+      assert(spark.table("tbl").selectExpr("v::int").head == Row(99))
+      assert(
+        getProtocolForTable("tbl") ==
+        VariantTypeTableFeature.minProtocolVersion.withFeature(VariantTypeTableFeature)
+      )
+    }
+  }
+
+  test("creating a table without Variant should use the usual minimum protocol") {
+    withTable("tbl") {
+      sql("CREATE TABLE tbl(s STRING, i INTEGER) USING DELTA")
+      assert(getProtocolForTable("tbl") == Protocol(1, 2))
+
+      val deltaLog = DeltaLog.forTable(spark, TableIdentifier("tbl"))
+      assert(
+        !deltaLog.unsafeVolatileSnapshot.protocol.isFeatureSupported(VariantTypeTableFeature),
+        "Table tbl contains the VariantTypeTableFeature descriptor when it's not supposed to"
+      )
+    }
+  }
+
+  test("adding a new Variant column should upgrade to the correct protocol versions") {
+    withTable("tbl") {
+      sql("CREATE TABLE tbl(s STRING) USING delta")
+      assert(getProtocolForTable("tbl") == Protocol(1, 2))
+
+      // Should throw an error since the variant feature is not yet supported by the table.
+      val e = intercept[SparkThrowable] {
+        sql("ALTER TABLE tbl ADD COLUMN v VARIANT")
+      }
+      // Capture the existing protocol here.
+      // We check the error message later in this test because we need the expected
+      // protocol and feature set for the comparison.
+      val deltaLog = DeltaLog.forTable(spark, TableIdentifier("tbl"))
+      val currentProtocol = deltaLog.unsafeVolatileSnapshot.protocol
+      val currentFeatures = currentProtocol.implicitlyAndExplicitlySupportedFeatures
+        .map(_.name)
+        .toSeq
+        .sorted
+        .mkString(", ")
+
+      // Explicitly add the variant table feature.
+      sql(
+        "ALTER TABLE tbl " +
+        "SET TBLPROPERTIES('delta.feature.variantType-dev' = 'supported')"
+      )
+
+      sql("ALTER TABLE tbl ADD COLUMN v VARIANT")
+
+      // Check the error message thrown before the feature was added.
+      checkError(
+        e,
+        errorClass = "DELTA_FEATURES_REQUIRE_MANUAL_ENABLEMENT",
+        parameters = Map(
+          "unsupportedFeatures" -> VariantTypeTableFeature.name,
+          "supportedFeatures" -> currentFeatures
+        )
+      )
+
+      sql("INSERT INTO tbl (SELECT 'foo', parse_json(cast(id + 99 as string)) FROM range(1))")
+      assert(spark.table("tbl").selectExpr("v::int").head == Row(99))
+
+      assert(
+        getProtocolForTable("tbl") ==
+        VariantTypeTableFeature.minProtocolVersion
+          .withFeature(VariantTypeTableFeature)
+          .withFeature(InvariantsTableFeature)
+          .withFeature(AppendOnlyTableFeature)
+      )
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala
new file mode 100644
index 00000000000..6a95026d335
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaVariantSuite.scala
@@ -0,0 +1,19 @@
+/*
+ * Copyright (2024) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta
+
+class DeltaVariantSuite extends DeltaVariantSparkOnlyTests
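A minimal sketch of how the pieces above compose, assuming compilation against Spark master
(the schema below is hypothetical): `SchemaUtils.checkForVariantTypeColumnsRecursively` drives
`VariantTypeTableFeature.metadataRequiresFeatureToBeEnabled`, so a schema containing a VARIANT
column at any nesting level auto-enables the `variantType-dev` feature.

    import org.apache.spark.sql.types._
    import org.apache.spark.sql.delta.schema.SchemaUtils

    // Hypothetical schema with a variant column nested inside a struct.
    val schema = new StructType()
      .add("s", StringType)
      .add("nested", new StructType().add("v", VariantType))

    // typeExistsRecursively walks nested structs, arrays, and maps, so the
    // nested variant column is found when compiled against Spark master.
    assert(SchemaUtils.checkForVariantTypeColumnsRecursively(schema))

On Spark 3.5 the same call resolves to the scala-spark-3.5 shim, whose isTypeVariant is
constantly false, so the feature is never required there; this is what keeps Variant support
cleanly gated per Spark version.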