Skip to content

Commit ddc75dc

Browse files
committed
switch to variant coalesce expression
1 parent 45f0b3a commit ddc75dc

File tree

12 files changed

+354
-184
lines changed

12 files changed

+354
-184
lines changed

kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import io.delta.kernel.internal.util.ColumnMapping;
3737
import io.delta.kernel.internal.util.PartitionUtils;
3838
import io.delta.kernel.internal.util.Tuple2;
39+
import io.delta.kernel.internal.util.VariantUtils;
3940

4041
/**
4142
* Represents a scan of a Delta table.
@@ -194,6 +195,13 @@ public FilteredColumnarBatch next() {
194195
nextDataBatch = nextDataBatch.withDeletedColumnAt(rowIndexOrdinal);
195196
}
196197

198+
// Transform physical variant columns (struct of binaries) into logical variant
199+
// columns.
200+
nextDataBatch = VariantUtils.withVariantColumns(
201+
tableClient.getExpressionHandler(),
202+
nextDataBatch
203+
);
204+
197205
// Add partition columns
198206
nextDataBatch =
199207
PartitionUtils.withPartitionColumns(
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (2024) The Delta Lake Project Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.delta.kernel.internal.util;
18+
19+
import java.util.Arrays;
20+
21+
import io.delta.kernel.client.ExpressionHandler;
22+
import io.delta.kernel.data.ColumnVector;
23+
import io.delta.kernel.data.ColumnarBatch;
24+
import io.delta.kernel.expressions.*;
25+
import io.delta.kernel.types.*;
26+
27+
public class VariantUtils {
28+
public static ColumnarBatch withVariantColumns(
29+
ExpressionHandler expressionHandler,
30+
ColumnarBatch dataBatch) {
31+
for (int i = 0; i < dataBatch.getSchema().length(); i++) {
32+
StructField field = dataBatch.getSchema().at(i);
33+
if (!(field.getDataType() instanceof StructType) &&
34+
!(field.getDataType() instanceof ArrayType) &&
35+
!(field.getDataType() instanceof MapType) &&
36+
(field.getDataType() != VariantType.VARIANT ||
37+
dataBatch.getColumnVector(i).getDataType() == VariantType.VARIANT)) {
38+
continue;
39+
}
40+
41+
ExpressionEvaluator evaluator = expressionHandler.getEvaluator(
42+
// Field here is variant type if its actually a variant.
43+
// TODO: probably better to pass in the schema as an argument
44+
// so the schema is enforced at the expression level. Need to pass in a literal
45+
// schema
46+
new StructType().add(field),
47+
new ScalarExpression(
48+
"variant_coalesce",
49+
Arrays.asList(new Column(field.getName()))
50+
),
51+
VariantType.VARIANT
52+
);
53+
54+
// TODO: don't need to pass in the entire batch.
55+
ColumnVector variantCol = evaluator.eval(dataBatch);
56+
// TODO: make a more efficient way to do this.
57+
dataBatch =
58+
dataBatch.withDeletedColumnAt(i).withNewColumn(i, field, variantCol);
59+
}
60+
return dataBatch;
61+
}
62+
63+
private static ColumnVector[] getColumnBatchVectors(ColumnarBatch batch) {
64+
ColumnVector[] res = new ColumnVector[batch.getSchema().length()];
65+
for (int i = 0; i < batch.getSchema().length(); i++) {
66+
res[i] = batch.getColumnVector(i);
67+
}
68+
return res;
69+
}
70+
}

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ public boolean isNullAt(int rowId) {
7979
return nullability.get()[rowId];
8080
}
8181

82+
/**
 * Returns the {@code nullability} field of this vector as-is (no copy is made). When present,
 * the array is the one {@code isNullAt} indexes by row id.
 * NOTE(review): the meaning of an empty Optional is not visible here -- confirm with isNullAt.
 */
public Optional<boolean[]> getNullability() {
83+
return nullability;
84+
}
85+
8286
@Override
8387
public boolean getBoolean(int rowId) {
8488
throw unsupportedDataAccessException("boolean");

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,12 @@ public ColumnVector getElements() {
8888
}
8989
};
9090
}
91+
92+
/**
 * Returns the {@code elementVector} field backing this array vector.
 * NOTE(review): presumably the elements of all rows laid out back to back -- confirm.
 */
public ColumnVector getElementVector() {
93+
return elementVector;
94+
}
95+
96+
/**
 * Returns the {@code offsets} array of this array vector directly, without copying, so any
 * mutation by the caller is visible to this vector.
 */
public int[] getOffsets() {
97+
return offsets;
98+
}
9199
}

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,16 @@ public ColumnVector getValues() {
9898
}
9999
};
100100
}
101+
102+
/**
 * Returns the {@code keyVector} field backing this map vector.
 * NOTE(review): presumably the keys of all rows laid out back to back -- confirm.
 */
public ColumnVector getKeyVector() {
103+
return keyVector;
104+
}
105+
106+
/**
 * Returns the {@code valueVector} field backing this map vector.
 * NOTE(review): presumably the values of all rows laid out back to back -- confirm.
 */
public ColumnVector getValueVector() {
107+
return valueVector;
108+
}
109+
110+
/**
 * Returns the {@code offsets} array of this map vector directly, without copying, so any
 * mutation by the caller is visible to this vector.
 */
public int[] getOffsets() {
111+
return offsets;
112+
}
101113
}

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package io.delta.kernel.defaults.internal.expressions;
1717

18+
import java.util.Arrays;
1819
import java.util.List;
1920
import java.util.Optional;
2021
import java.util.stream.Collectors;
@@ -33,8 +34,7 @@
3334
import static io.delta.kernel.internal.util.ExpressionUtils.getUnaryChild;
3435
import static io.delta.kernel.internal.util.Preconditions.checkArgument;
3536

36-
import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector;
37-
import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector;
37+
import io.delta.kernel.defaults.internal.data.vector.*;
3838
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector;
3939
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.childAt;
4040
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.compare;
@@ -48,6 +48,7 @@
4848
*/
4949
public class DefaultExpressionEvaluator implements ExpressionEvaluator {
5050
private final Expression expression;
51+
private final StructType inputSchema;
5152

5253
/**
5354
* Create a {@link DefaultExpressionEvaluator} instance bound to the given expression and
@@ -68,12 +69,14 @@ public DefaultExpressionEvaluator(
6869
"Can not create an expression handler returns result of type %s", outputType);
6970
throw DeltaErrors.unsupportedExpression(expression, Optional.of(reason));
7071
}
72+
// TODO(richardc-db): Hack to avoid needing to pass the schema into the expression.
73+
this.inputSchema = inputSchema;
7174
this.expression = transformResult.expression;
7275
}
7376

7477
@Override
7578
public ColumnVector eval(ColumnarBatch input) {
76-
return new ExpressionEvalVisitor(input).visit(expression);
79+
return new ExpressionEvalVisitor(input, inputSchema).visit(expression);
7780
}
7881

7982
@Override
@@ -278,6 +281,21 @@ ExpressionTransformResult visitCoalesce(ScalarExpression coalesce) {
278281
);
279282
}
280283

284+
@Override
285+
ExpressionTransformResult visitVariantCoalesce(ScalarExpression variantCoalesce) {
286+
checkArgument(
287+
variantCoalesce.getChildren().size() == 1,
288+
"Expected one input to 'variant_coalesce but received %s",
289+
variantCoalesce.getChildren().size()
290+
);
291+
Expression transformedVariantInput = visit(childAt(variantCoalesce, 0)).expression;
292+
return new ExpressionTransformResult(
293+
new ScalarExpression(
294+
"VARIANT_COALESCE",
295+
Arrays.asList(transformedVariantInput)),
296+
VariantType.VARIANT);
297+
}
298+
281299
private Predicate validateIsPredicate(
282300
Expression baseExpression,
283301
ExpressionTransformResult result) {
@@ -318,9 +336,11 @@ private Expression transformBinaryComparator(Predicate predicate) {
318336
*/
319337
private static class ExpressionEvalVisitor extends ExpressionVisitor<ColumnVector> {
320338
private final ColumnarBatch input;
339+
private final StructType inputSchema;
321340

322-
ExpressionEvalVisitor(ColumnarBatch input) {
341+
ExpressionEvalVisitor(ColumnarBatch input, StructType inputSchema) {
323342
this.input = input;
343+
this.inputSchema = inputSchema;
324344
}
325345

326346
/*
@@ -558,6 +578,119 @@ ColumnVector visitCoalesce(ScalarExpression coalesce) {
558578
);
559579
}
560580

581+
@Override
582+
ColumnVector visitVariantCoalesce(ScalarExpression variantCoalesce) {
583+
return variantCoalesceImpl(
584+
visit(childAt(variantCoalesce, 0)),
585+
inputSchema.at(0).getDataType()
586+
);
587+
}
588+
589+
private ColumnVector variantCoalesceImpl(ColumnVector inputVec, DataType dt) {
590+
if (dt instanceof StructType) {
591+
StructType structType = (StructType) dt;
592+
DefaultStructVector structVec = (DefaultStructVector) inputVec;
593+
ColumnVector[] structColVecs = new ColumnVector[structType.length()];
594+
for (int i = 0; i < structType.length(); i++) {
595+
if (structType.at(i).getDataType() instanceof ArrayType ||
596+
structType.at(i).getDataType() instanceof StructType ||
597+
structType.at(i).getDataType() instanceof MapType ||
598+
structType.at(i).getDataType() instanceof VariantType) {
599+
structColVecs[i] = variantCoalesceImpl(
600+
structVec.getChild(i),
601+
structType.at(i).getDataType()
602+
);
603+
} else {
604+
structColVecs[i] = structVec.getChild(i);
605+
}
606+
}
607+
return new DefaultStructVector(
608+
structVec.getSize(),
609+
structType,
610+
structVec.getNullability(),
611+
structColVecs
612+
);
613+
}
614+
615+
if (dt instanceof ArrayType) {
616+
ArrayType arrType = (ArrayType) dt;
617+
DefaultArrayVector arrVec = (DefaultArrayVector) inputVec;
618+
619+
if (arrType.getElementType() instanceof ArrayType ||
620+
arrType.getElementType() instanceof StructType ||
621+
arrType.getElementType() instanceof MapType ||
622+
arrType.getElementType() instanceof VariantType) {
623+
ColumnVector elementVec = variantCoalesceImpl(
624+
arrVec.getElementVector(),
625+
arrType.getElementType()
626+
);
627+
628+
return new DefaultArrayVector(
629+
arrVec.getSize(),
630+
arrType,
631+
arrVec.getNullability(),
632+
arrVec.getOffsets(),
633+
elementVec
634+
);
635+
}
636+
return arrVec;
637+
}
638+
639+
if (dt instanceof MapType) {
640+
MapType mapType = (MapType) dt;
641+
DefaultMapVector mapVec = (DefaultMapVector) inputVec;
642+
643+
ColumnVector valueVec = mapVec.getValueVector();
644+
if (mapType.getValueType() instanceof ArrayType ||
645+
mapType.getValueType() instanceof StructType ||
646+
mapType.getValueType() instanceof MapType ||
647+
mapType.getValueType() instanceof VariantType) {
648+
valueVec = variantCoalesceImpl(
649+
mapVec.getValueVector(),
650+
mapType.getValueType()
651+
);
652+
}
653+
ColumnVector keyVec = mapVec.getKeyVector();
654+
if (mapType.getKeyType() instanceof ArrayType ||
655+
mapType.getKeyType() instanceof StructType ||
656+
mapType.getKeyType() instanceof MapType ||
657+
mapType.getKeyType() instanceof VariantType) {
658+
keyVec = variantCoalesceImpl(
659+
mapVec.getKeyVector(),
660+
mapType.getKeyType()
661+
);
662+
}
663+
return new DefaultMapVector(
664+
mapVec.getSize(),
665+
mapType,
666+
mapVec.getNullability(),
667+
mapVec.getOffsets(),
668+
keyVec,
669+
valueVec
670+
);
671+
}
672+
673+
DefaultStructVector structBackingVariant = (DefaultStructVector) inputVec;
674+
checkArgument(
675+
structBackingVariant.getChild(0).getDataType() instanceof BinaryType,
676+
"Expected struct field 0 backing variant to be binary type. Actual: %s",
677+
structBackingVariant.getChild(0).getDataType()
678+
);
679+
checkArgument(
680+
structBackingVariant.getChild(1).getDataType() instanceof BinaryType,
681+
"Expected struct field 1 backing variant to be binary type. Actual: %s",
682+
structBackingVariant.getChild(1).getDataType()
683+
);
684+
685+
return new DefaultVariantVector(
686+
structBackingVariant.getSize(),
687+
VariantType.VARIANT,
688+
structBackingVariant.getNullability(),
689+
structBackingVariant.getChild(0),
690+
structBackingVariant.getChild(1)
691+
);
692+
}
693+
561694
/**
562695
* Utility method to evaluate inputs to the binary input expression. Also validates the
563696
* evaluated expression result {@link ColumnVector}s are of the same size.

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ abstract class ExpressionVisitor<R> {
5959

6060
abstract R visitCoalesce(ScalarExpression ifNull);
6161

62+
/** Visits a scalar expression dispatched under the name {@code VARIANT_COALESCE}. */
abstract R visitVariantCoalesce(ScalarExpression variantCoalesce);
63+
6264
final R visit(Expression expression) {
6365
if (expression instanceof PartitionValueExpression) {
6466
return visitPartitionValue((PartitionValueExpression) expression);
@@ -105,6 +107,8 @@ private R visitScalarExpression(ScalarExpression expression) {
105107
return visitIsNull(new Predicate(name, children));
106108
case "COALESCE":
107109
return visitCoalesce(expression);
110+
case "VARIANT_COALESCE":
111+
return visitVariantCoalesce(expression);
108112
default:
109113
throw new UnsupportedOperationException(
110114
String.format("Scalar expression `%s` is not supported.", name));

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ public static Converter createConverter(
8383
} else if (typeFromClient instanceof TimestampNTZType) {
8484
return createTimestampNtzConverter(initialBatchSize, typeFromFile);
8585
} else if (typeFromClient instanceof VariantType) {
86-
return new VariantColumnReader(initialBatchSize);
86+
return new RowColumnReader(
87+
initialBatchSize,
88+
new StructType()
89+
.add("value", BinaryType.BINARY, false)
90+
.add("metadata", BinaryType.BINARY, false),
91+
(GroupType) typeFromFile);
8792
}
8893

8994
throw new UnsupportedOperationException(typeFromClient + " is not supported");

0 commit comments

Comments
 (0)