15
15
*/
16
16
package io .delta .kernel .defaults .internal .expressions ;
17
17
18
+ import java .util .Arrays ;
18
19
import java .util .List ;
19
20
import java .util .Optional ;
20
21
import java .util .stream .Collectors ;
32
33
import static io .delta .kernel .internal .util .ExpressionUtils .getRight ;
33
34
import static io .delta .kernel .internal .util .ExpressionUtils .getUnaryChild ;
34
35
import static io .delta .kernel .internal .util .Preconditions .checkArgument ;
35
- import io .delta .kernel .defaults .internal .data .vector .DefaultBooleanVector ;
36
- import io .delta .kernel .defaults .internal .data .vector .DefaultConstantVector ;
36
+ import io .delta .kernel .defaults .internal .data .vector .*;
37
37
import static io .delta .kernel .defaults .internal .DefaultEngineErrors .unsupportedExpressionException ;
38
38
import static io .delta .kernel .defaults .internal .expressions .DefaultExpressionUtils .*;
39
39
import static io .delta .kernel .defaults .internal .expressions .DefaultExpressionUtils .booleanWrapperVector ;
47
47
*/
48
48
public class DefaultExpressionEvaluator implements ExpressionEvaluator {
49
49
private final Expression expression ;
50
+ private final StructType inputSchema ;
50
51
51
52
/**
52
53
* Create a {@link DefaultExpressionEvaluator} instance bound to the given expression and
@@ -67,12 +68,14 @@ public DefaultExpressionEvaluator(
67
68
"Expression %s does not match expected output type %s" , expression , outputType );
68
69
throw unsupportedExpressionException (expression , reason );
69
70
}
71
+ // TODO(richardc-db): Hack to avoid needing to pass the schema into the expression.
72
+ this .inputSchema = inputSchema ;
70
73
this .expression = transformResult .expression ;
71
74
}
72
75
73
76
@ Override
74
77
public ColumnVector eval (ColumnarBatch input ) {
75
- return new ExpressionEvalVisitor (input ).visit (expression );
78
+ return new ExpressionEvalVisitor (input , inputSchema ).visit (expression );
76
79
}
77
80
78
81
@ Override
@@ -291,6 +294,19 @@ ExpressionTransformResult visitLike(final Predicate like) {
291
294
children .stream ().map (e -> e .outputType ).collect (toList ()));
292
295
293
296
return new ExpressionTransformResult (transformedExpression , BooleanType .BOOLEAN );
297
+
298
+ ExpressionTransformResult visitVariantCoalesce (ScalarExpression variantCoalesce ) {
299
+ checkArgument (
300
+ variantCoalesce .getChildren ().size () == 1 ,
301
+ "Expected one input to 'variant_coalesce but received %s" ,
302
+ variantCoalesce .getChildren ().size ()
303
+ );
304
+ Expression transformedVariantInput = visit (childAt (variantCoalesce , 0 )).expression ;
305
+ return new ExpressionTransformResult (
306
+ new ScalarExpression (
307
+ "VARIANT_COALESCE" ,
308
+ Arrays .asList (transformedVariantInput )),
309
+ VariantType .VARIANT );
294
310
}
295
311
296
312
private Predicate validateIsPredicate (
@@ -333,9 +349,11 @@ private Expression transformBinaryComparator(Predicate predicate) {
333
349
*/
334
350
private static class ExpressionEvalVisitor extends ExpressionVisitor <ColumnVector > {
335
351
private final ColumnarBatch input ;
352
+ private final StructType inputSchema ;
336
353
337
/**
 * Creates a visitor that keeps the given batch and schema for use by its visit methods.
 *
 * @param input batch stored in {@code this.input}
 * @param inputSchema schema stored in {@code this.inputSchema}
 */
ExpressionEvalVisitor(ColumnarBatch input, StructType inputSchema) {
  this.inputSchema = inputSchema;
  this.input = input;
}
340
358
341
359
/*
@@ -575,6 +593,118 @@ ColumnVector visitLike(final Predicate like) {
575
593
.collect (toList ()));
576
594
}
577
595
596
+ ColumnVector visitVariantCoalesce (ScalarExpression variantCoalesce ) {
597
+ return variantCoalesceImpl (
598
+ visit (childAt (variantCoalesce , 0 )),
599
+ inputSchema .at (0 ).getDataType ()
600
+ );
601
+ }
602
+
603
+ private ColumnVector variantCoalesceImpl (ColumnVector inputVec , DataType dt ) {
604
+ if (dt instanceof StructType ) {
605
+ StructType structType = (StructType ) dt ;
606
+ DefaultStructVector structVec = (DefaultStructVector ) inputVec ;
607
+ ColumnVector [] structColVecs = new ColumnVector [structType .length ()];
608
+ for (int i = 0 ; i < structType .length (); i ++) {
609
+ if (structType .at (i ).getDataType () instanceof ArrayType ||
610
+ structType .at (i ).getDataType () instanceof StructType ||
611
+ structType .at (i ).getDataType () instanceof MapType ||
612
+ structType .at (i ).getDataType () instanceof VariantType ) {
613
+ structColVecs [i ] = variantCoalesceImpl (
614
+ structVec .getChild (i ),
615
+ structType .at (i ).getDataType ()
616
+ );
617
+ } else {
618
+ structColVecs [i ] = structVec .getChild (i );
619
+ }
620
+ }
621
+ return new DefaultStructVector (
622
+ structVec .getSize (),
623
+ structType ,
624
+ structVec .getNullability (),
625
+ structColVecs
626
+ );
627
+ }
628
+
629
+ if (dt instanceof ArrayType ) {
630
+ ArrayType arrType = (ArrayType ) dt ;
631
+ DefaultArrayVector arrVec = (DefaultArrayVector ) inputVec ;
632
+
633
+ if (arrType .getElementType () instanceof ArrayType ||
634
+ arrType .getElementType () instanceof StructType ||
635
+ arrType .getElementType () instanceof MapType ||
636
+ arrType .getElementType () instanceof VariantType ) {
637
+ ColumnVector elementVec = variantCoalesceImpl (
638
+ arrVec .getElementVector (),
639
+ arrType .getElementType ()
640
+ );
641
+
642
+ return new DefaultArrayVector (
643
+ arrVec .getSize (),
644
+ arrType ,
645
+ arrVec .getNullability (),
646
+ arrVec .getOffsets (),
647
+ elementVec
648
+ );
649
+ }
650
+ return arrVec ;
651
+ }
652
+
653
+ if (dt instanceof MapType ) {
654
+ MapType mapType = (MapType ) dt ;
655
+ DefaultMapVector mapVec = (DefaultMapVector ) inputVec ;
656
+
657
+ ColumnVector valueVec = mapVec .getValueVector ();
658
+ if (mapType .getValueType () instanceof ArrayType ||
659
+ mapType .getValueType () instanceof StructType ||
660
+ mapType .getValueType () instanceof MapType ||
661
+ mapType .getValueType () instanceof VariantType ) {
662
+ valueVec = variantCoalesceImpl (
663
+ mapVec .getValueVector (),
664
+ mapType .getValueType ()
665
+ );
666
+ }
667
+ ColumnVector keyVec = mapVec .getKeyVector ();
668
+ if (mapType .getKeyType () instanceof ArrayType ||
669
+ mapType .getKeyType () instanceof StructType ||
670
+ mapType .getKeyType () instanceof MapType ||
671
+ mapType .getKeyType () instanceof VariantType ) {
672
+ keyVec = variantCoalesceImpl (
673
+ mapVec .getKeyVector (),
674
+ mapType .getKeyType ()
675
+ );
676
+ }
677
+ return new DefaultMapVector (
678
+ mapVec .getSize (),
679
+ mapType ,
680
+ mapVec .getNullability (),
681
+ mapVec .getOffsets (),
682
+ keyVec ,
683
+ valueVec
684
+ );
685
+ }
686
+
687
+ DefaultStructVector structBackingVariant = (DefaultStructVector ) inputVec ;
688
+ checkArgument (
689
+ structBackingVariant .getChild (0 ).getDataType () instanceof BinaryType ,
690
+ "Expected struct field 0 backing variant to be binary type. Actual: %s" ,
691
+ structBackingVariant .getChild (0 ).getDataType ()
692
+ );
693
+ checkArgument (
694
+ structBackingVariant .getChild (1 ).getDataType () instanceof BinaryType ,
695
+ "Expected struct field 1 backing variant to be binary type. Actual: %s" ,
696
+ structBackingVariant .getChild (1 ).getDataType ()
697
+ );
698
+
699
+ return new DefaultVariantVector (
700
+ structBackingVariant .getSize (),
701
+ VariantType .VARIANT ,
702
+ structBackingVariant .getNullability (),
703
+ structBackingVariant .getChild (0 ),
704
+ structBackingVariant .getChild (1 )
705
+ );
706
+ }
707
+
578
708
/**
579
709
* Utility method to evaluate inputs to the binary input expression. Also validates the
580
710
* evaluated expression result {@link ColumnVector}s are of the same size.
0 commit comments