Skip to content

Commit 33714ce

Browse files
committed
switch to variant coalesce expression
1 parent 3f221ac commit 33714ce

File tree

12 files changed

+353
-186
lines changed

12 files changed

+353
-186
lines changed

kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import io.delta.kernel.internal.util.ColumnMapping;
3737
import io.delta.kernel.internal.util.PartitionUtils;
3838
import io.delta.kernel.internal.util.Tuple2;
39+
import io.delta.kernel.internal.util.VariantUtils;
3940

4041
/**
4142
* Represents a scan of a Delta table.
@@ -202,6 +203,13 @@ public FilteredColumnarBatch next() {
202203
nextDataBatch = nextDataBatch.withDeletedColumnAt(rowIndexOrdinal);
203204
}
204205

206+
// Transform physical variant columns (struct of binaries) into logical variant
207+
// columns.
208+
nextDataBatch = VariantUtils.withVariantColumns(
209+
tableClient.getExpressionHandler(),
210+
nextDataBatch
211+
);
212+
205213
// Add partition columns
206214
nextDataBatch =
207215
PartitionUtils.withPartitionColumns(
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (2024) The Delta Lake Project Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.delta.kernel.internal.util;
18+
19+
import java.util.Arrays;
20+
21+
import io.delta.kernel.client.ExpressionHandler;
22+
import io.delta.kernel.data.ColumnVector;
23+
import io.delta.kernel.data.ColumnarBatch;
24+
import io.delta.kernel.expressions.*;
25+
import io.delta.kernel.types.*;
26+
27+
public class VariantUtils {
28+
public static ColumnarBatch withVariantColumns(
29+
ExpressionHandler expressionHandler,
30+
ColumnarBatch dataBatch) {
31+
for (int i = 0; i < dataBatch.getSchema().length(); i++) {
32+
StructField field = dataBatch.getSchema().at(i);
33+
if (!(field.getDataType() instanceof StructType) &&
34+
!(field.getDataType() instanceof ArrayType) &&
35+
!(field.getDataType() instanceof MapType) &&
36+
(field.getDataType() != VariantType.VARIANT ||
37+
dataBatch.getColumnVector(i).getDataType() == VariantType.VARIANT)) {
38+
continue;
39+
}
40+
41+
ExpressionEvaluator evaluator = expressionHandler.getEvaluator(
42+
// Field here is variant type if its actually a variant.
43+
// TODO: probably better to pass in the schema as an argument
44+
// so the schema is enforced at the expression level. Need to pass in a literal
45+
// schema
46+
new StructType().add(field),
47+
new ScalarExpression(
48+
"variant_coalesce",
49+
Arrays.asList(new Column(field.getName()))
50+
),
51+
VariantType.VARIANT
52+
);
53+
54+
// TODO: don't need to pass in the entire batch.
55+
ColumnVector variantCol = evaluator.eval(dataBatch);
56+
// TODO: make a more efficient way to do this.
57+
dataBatch =
58+
dataBatch.withDeletedColumnAt(i).withNewColumn(i, field, variantCol);
59+
}
60+
return dataBatch;
61+
}
62+
63+
private static ColumnVector[] getColumnBatchVectors(ColumnarBatch batch) {
64+
ColumnVector[] res = new ColumnVector[batch.getSchema().length()];
65+
for (int i = 0; i < batch.getSchema().length(); i++) {
66+
res[i] = batch.getColumnVector(i);
67+
}
68+
return res;
69+
}
70+
}

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ public boolean isNullAt(int rowId) {
7979
return nullability.get()[rowId];
8080
}
8181

82+
public Optional<boolean[]> getNullability() {
83+
return nullability;
84+
}
85+
8286
@Override
8387
public boolean getBoolean(int rowId) {
8488
throw unsupportedDataAccessException("boolean");

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,12 @@ public ColumnVector getElements() {
8888
}
8989
};
9090
}
91+
92+
public ColumnVector getElementVector() {
93+
return elementVector;
94+
}
95+
96+
public int[] getOffsets() {
97+
return offsets;
98+
}
9199
}

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,16 @@ public ColumnVector getValues() {
9898
}
9999
};
100100
}
101+
102+
public ColumnVector getKeyVector() {
103+
return keyVector;
104+
}
105+
106+
public ColumnVector getValueVector() {
107+
return valueVector;
108+
}
109+
110+
public int[] getOffsets() {
111+
return offsets;
112+
}
101113
}

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java

Lines changed: 134 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package io.delta.kernel.defaults.internal.expressions;
1717

18+
import java.util.Arrays;
1819
import java.util.List;
1920
import java.util.Optional;
2021
import java.util.stream.Collectors;
@@ -32,8 +33,7 @@
3233
import static io.delta.kernel.internal.util.ExpressionUtils.getRight;
3334
import static io.delta.kernel.internal.util.ExpressionUtils.getUnaryChild;
3435
import static io.delta.kernel.internal.util.Preconditions.checkArgument;
35-
import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector;
36-
import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector;
36+
import io.delta.kernel.defaults.internal.data.vector.*;
3737
import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException;
3838
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*;
3939
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector;
@@ -47,6 +47,7 @@
4747
*/
4848
public class DefaultExpressionEvaluator implements ExpressionEvaluator {
4949
private final Expression expression;
50+
private final StructType inputSchema;
5051

5152
/**
5253
* Create a {@link DefaultExpressionEvaluator} instance bound to the given expression and
@@ -67,12 +68,14 @@ public DefaultExpressionEvaluator(
6768
"Expression %s does not match expected output type %s", expression, outputType);
6869
throw unsupportedExpressionException(expression, reason);
6970
}
71+
// TODO(richardc-db): Hack to avoid needing to pass the schema into the expression.
72+
this.inputSchema = inputSchema;
7073
this.expression = transformResult.expression;
7174
}
7275

7376
@Override
7477
public ColumnVector eval(ColumnarBatch input) {
75-
return new ExpressionEvalVisitor(input).visit(expression);
78+
return new ExpressionEvalVisitor(input, inputSchema).visit(expression);
7679
}
7780

7881
@Override
@@ -291,6 +294,19 @@ ExpressionTransformResult visitLike(final Predicate like) {
291294
children.stream().map(e -> e.outputType).collect(toList()));
292295

293296
return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN);
297+
298+
ExpressionTransformResult visitVariantCoalesce(ScalarExpression variantCoalesce) {
299+
checkArgument(
300+
variantCoalesce.getChildren().size() == 1,
301+
"Expected one input to 'variant_coalesce but received %s",
302+
variantCoalesce.getChildren().size()
303+
);
304+
Expression transformedVariantInput = visit(childAt(variantCoalesce, 0)).expression;
305+
return new ExpressionTransformResult(
306+
new ScalarExpression(
307+
"VARIANT_COALESCE",
308+
Arrays.asList(transformedVariantInput)),
309+
VariantType.VARIANT);
294310
}
295311

296312
private Predicate validateIsPredicate(
@@ -333,9 +349,11 @@ private Expression transformBinaryComparator(Predicate predicate) {
333349
*/
334350
private static class ExpressionEvalVisitor extends ExpressionVisitor<ColumnVector> {
335351
private final ColumnarBatch input;
352+
private final StructType inputSchema;
336353

337-
ExpressionEvalVisitor(ColumnarBatch input) {
354+
ExpressionEvalVisitor(ColumnarBatch input, StructType inputSchema) {
338355
this.input = input;
356+
this.inputSchema = inputSchema;
339357
}
340358

341359
/*
@@ -575,6 +593,118 @@ ColumnVector visitLike(final Predicate like) {
575593
.collect(toList()));
576594
}
577595

596+
ColumnVector visitVariantCoalesce(ScalarExpression variantCoalesce) {
597+
return variantCoalesceImpl(
598+
visit(childAt(variantCoalesce, 0)),
599+
inputSchema.at(0).getDataType()
600+
);
601+
}
602+
603+
private ColumnVector variantCoalesceImpl(ColumnVector inputVec, DataType dt) {
604+
if (dt instanceof StructType) {
605+
StructType structType = (StructType) dt;
606+
DefaultStructVector structVec = (DefaultStructVector) inputVec;
607+
ColumnVector[] structColVecs = new ColumnVector[structType.length()];
608+
for (int i = 0; i < structType.length(); i++) {
609+
if (structType.at(i).getDataType() instanceof ArrayType ||
610+
structType.at(i).getDataType() instanceof StructType ||
611+
structType.at(i).getDataType() instanceof MapType ||
612+
structType.at(i).getDataType() instanceof VariantType) {
613+
structColVecs[i] = variantCoalesceImpl(
614+
structVec.getChild(i),
615+
structType.at(i).getDataType()
616+
);
617+
} else {
618+
structColVecs[i] = structVec.getChild(i);
619+
}
620+
}
621+
return new DefaultStructVector(
622+
structVec.getSize(),
623+
structType,
624+
structVec.getNullability(),
625+
structColVecs
626+
);
627+
}
628+
629+
if (dt instanceof ArrayType) {
630+
ArrayType arrType = (ArrayType) dt;
631+
DefaultArrayVector arrVec = (DefaultArrayVector) inputVec;
632+
633+
if (arrType.getElementType() instanceof ArrayType ||
634+
arrType.getElementType() instanceof StructType ||
635+
arrType.getElementType() instanceof MapType ||
636+
arrType.getElementType() instanceof VariantType) {
637+
ColumnVector elementVec = variantCoalesceImpl(
638+
arrVec.getElementVector(),
639+
arrType.getElementType()
640+
);
641+
642+
return new DefaultArrayVector(
643+
arrVec.getSize(),
644+
arrType,
645+
arrVec.getNullability(),
646+
arrVec.getOffsets(),
647+
elementVec
648+
);
649+
}
650+
return arrVec;
651+
}
652+
653+
if (dt instanceof MapType) {
654+
MapType mapType = (MapType) dt;
655+
DefaultMapVector mapVec = (DefaultMapVector) inputVec;
656+
657+
ColumnVector valueVec = mapVec.getValueVector();
658+
if (mapType.getValueType() instanceof ArrayType ||
659+
mapType.getValueType() instanceof StructType ||
660+
mapType.getValueType() instanceof MapType ||
661+
mapType.getValueType() instanceof VariantType) {
662+
valueVec = variantCoalesceImpl(
663+
mapVec.getValueVector(),
664+
mapType.getValueType()
665+
);
666+
}
667+
ColumnVector keyVec = mapVec.getKeyVector();
668+
if (mapType.getKeyType() instanceof ArrayType ||
669+
mapType.getKeyType() instanceof StructType ||
670+
mapType.getKeyType() instanceof MapType ||
671+
mapType.getKeyType() instanceof VariantType) {
672+
keyVec = variantCoalesceImpl(
673+
mapVec.getKeyVector(),
674+
mapType.getKeyType()
675+
);
676+
}
677+
return new DefaultMapVector(
678+
mapVec.getSize(),
679+
mapType,
680+
mapVec.getNullability(),
681+
mapVec.getOffsets(),
682+
keyVec,
683+
valueVec
684+
);
685+
}
686+
687+
DefaultStructVector structBackingVariant = (DefaultStructVector) inputVec;
688+
checkArgument(
689+
structBackingVariant.getChild(0).getDataType() instanceof BinaryType,
690+
"Expected struct field 0 backing variant to be binary type. Actual: %s",
691+
structBackingVariant.getChild(0).getDataType()
692+
);
693+
checkArgument(
694+
structBackingVariant.getChild(1).getDataType() instanceof BinaryType,
695+
"Expected struct field 1 backing variant to be binary type. Actual: %s",
696+
structBackingVariant.getChild(1).getDataType()
697+
);
698+
699+
return new DefaultVariantVector(
700+
structBackingVariant.getSize(),
701+
VariantType.VARIANT,
702+
structBackingVariant.getNullability(),
703+
structBackingVariant.getChild(0),
704+
structBackingVariant.getChild(1)
705+
);
706+
}
707+
578708
/**
579709
* Utility method to evaluate inputs to the binary input expression. Also validates the
580710
* evaluated expression result {@link ColumnVector}s are of the same size.

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ abstract class ExpressionVisitor<R> {
6161

6262
abstract R visitLike(Predicate predicate);
6363

64+
abstract R visitVariantCoalesce(ScalarExpression variantCoalesce);
65+
6466
final R visit(Expression expression) {
6567
if (expression instanceof PartitionValueExpression) {
6668
return visitPartitionValue((PartitionValueExpression) expression);
@@ -109,6 +111,8 @@ private R visitScalarExpression(ScalarExpression expression) {
109111
return visitCoalesce(expression);
110112
case "LIKE":
111113
return visitLike(new Predicate(name, children));
114+
case "VARIANT_COALESCE":
115+
return visitVariantCoalesce(expression);
112116
default:
113117
throw new UnsupportedOperationException(
114118
String.format("Scalar expression `%s` is not supported.", name));

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,12 @@ public static Converter createConverter(
8484
return createTimestampConverter(initialBatchSize, typeFromFile,
8585
TimestampNTZType.TIMESTAMP_NTZ);
8686
} else if (typeFromClient instanceof VariantType) {
87-
return new VariantColumnReader(initialBatchSize);
87+
return new RowColumnReader(
88+
initialBatchSize,
89+
new StructType()
90+
.add("value", BinaryType.BINARY, false)
91+
.add("metadata", BinaryType.BINARY, false),
92+
(GroupType) typeFromFile);
8893
}
8994

9095
throw new UnsupportedOperationException(typeFromClient + " is not supported");

0 commit comments

Comments
 (0)