15
15
*/
16
16
package io .delta .kernel .defaults .internal .expressions ;
17
17
18
+ import java .util .Arrays ;
18
19
import java .util .List ;
19
20
import java .util .Optional ;
20
21
import java .util .stream .Collectors ;
33
34
import static io .delta .kernel .internal .util .ExpressionUtils .getUnaryChild ;
34
35
import static io .delta .kernel .internal .util .Preconditions .checkArgument ;
35
36
36
- import io .delta .kernel .defaults .internal .data .vector .DefaultBooleanVector ;
37
- import io .delta .kernel .defaults .internal .data .vector .DefaultConstantVector ;
37
+ import io .delta .kernel .defaults .internal .data .vector .*;
38
38
import static io .delta .kernel .defaults .internal .expressions .DefaultExpressionUtils .booleanWrapperVector ;
39
39
import static io .delta .kernel .defaults .internal .expressions .DefaultExpressionUtils .childAt ;
40
40
import static io .delta .kernel .defaults .internal .expressions .DefaultExpressionUtils .compare ;
48
48
*/
49
49
public class DefaultExpressionEvaluator implements ExpressionEvaluator {
50
50
private final Expression expression ;
51
+ private final StructType inputSchema ;
51
52
52
53
/**
53
54
* Create a {@link DefaultExpressionEvaluator} instance bound to the given expression and
@@ -68,12 +69,14 @@ public DefaultExpressionEvaluator(
68
69
"Can not create an expression handler returns result of type %s" , outputType );
69
70
throw DeltaErrors .unsupportedExpression (expression , Optional .of (reason ));
70
71
}
72
+ // TODO(richardc-db): Hack to avoid needing to pass the schema into the expression.
73
+ this .inputSchema = inputSchema ;
71
74
this .expression = transformResult .expression ;
72
75
}
73
76
74
77
@ Override
75
78
public ColumnVector eval (ColumnarBatch input ) {
76
- return new ExpressionEvalVisitor (input ).visit (expression );
79
+ return new ExpressionEvalVisitor (input , inputSchema ).visit (expression );
77
80
}
78
81
79
82
@ Override
@@ -278,6 +281,21 @@ ExpressionTransformResult visitCoalesce(ScalarExpression coalesce) {
278
281
);
279
282
}
280
283
284
+ @ Override
285
+ ExpressionTransformResult visitVariantCoalesce (ScalarExpression variantCoalesce ) {
286
+ checkArgument (
287
+ variantCoalesce .getChildren ().size () == 1 ,
288
+ "Expected one input to 'variant_coalesce but received %s" ,
289
+ variantCoalesce .getChildren ().size ()
290
+ );
291
+ Expression transformedVariantInput = visit (childAt (variantCoalesce , 0 )).expression ;
292
+ return new ExpressionTransformResult (
293
+ new ScalarExpression (
294
+ "VARIANT_COALESCE" ,
295
+ Arrays .asList (transformedVariantInput )),
296
+ VariantType .VARIANT );
297
+ }
298
+
281
299
private Predicate validateIsPredicate (
282
300
Expression baseExpression ,
283
301
ExpressionTransformResult result ) {
@@ -318,9 +336,11 @@ private Expression transformBinaryComparator(Predicate predicate) {
318
336
*/
319
337
private static class ExpressionEvalVisitor extends ExpressionVisitor <ColumnVector > {
320
338
private final ColumnarBatch input ;
339
+ private final StructType inputSchema ;
321
340
322
- ExpressionEvalVisitor (ColumnarBatch input ) {
341
+ ExpressionEvalVisitor (ColumnarBatch input , StructType inputSchema ) {
323
342
this .input = input ;
343
+ this .inputSchema = inputSchema ;
324
344
}
325
345
326
346
/*
@@ -558,6 +578,119 @@ ColumnVector visitCoalesce(ScalarExpression coalesce) {
558
578
);
559
579
}
560
580
581
+ @ Override
582
+ ColumnVector visitVariantCoalesce (ScalarExpression variantCoalesce ) {
583
+ return variantCoalesceImpl (
584
+ visit (childAt (variantCoalesce , 0 )),
585
+ inputSchema .at (0 ).getDataType ()
586
+ );
587
+ }
588
+
589
+ private ColumnVector variantCoalesceImpl (ColumnVector inputVec , DataType dt ) {
590
+ if (dt instanceof StructType ) {
591
+ StructType structType = (StructType ) dt ;
592
+ DefaultStructVector structVec = (DefaultStructVector ) inputVec ;
593
+ ColumnVector [] structColVecs = new ColumnVector [structType .length ()];
594
+ for (int i = 0 ; i < structType .length (); i ++) {
595
+ if (structType .at (i ).getDataType () instanceof ArrayType ||
596
+ structType .at (i ).getDataType () instanceof StructType ||
597
+ structType .at (i ).getDataType () instanceof MapType ||
598
+ structType .at (i ).getDataType () instanceof VariantType ) {
599
+ structColVecs [i ] = variantCoalesceImpl (
600
+ structVec .getChild (i ),
601
+ structType .at (i ).getDataType ()
602
+ );
603
+ } else {
604
+ structColVecs [i ] = structVec .getChild (i );
605
+ }
606
+ }
607
+ return new DefaultStructVector (
608
+ structVec .getSize (),
609
+ structType ,
610
+ structVec .getNullability (),
611
+ structColVecs
612
+ );
613
+ }
614
+
615
+ if (dt instanceof ArrayType ) {
616
+ ArrayType arrType = (ArrayType ) dt ;
617
+ DefaultArrayVector arrVec = (DefaultArrayVector ) inputVec ;
618
+
619
+ if (arrType .getElementType () instanceof ArrayType ||
620
+ arrType .getElementType () instanceof StructType ||
621
+ arrType .getElementType () instanceof MapType ||
622
+ arrType .getElementType () instanceof VariantType ) {
623
+ ColumnVector elementVec = variantCoalesceImpl (
624
+ arrVec .getElementVector (),
625
+ arrType .getElementType ()
626
+ );
627
+
628
+ return new DefaultArrayVector (
629
+ arrVec .getSize (),
630
+ arrType ,
631
+ arrVec .getNullability (),
632
+ arrVec .getOffsets (),
633
+ elementVec
634
+ );
635
+ }
636
+ return arrVec ;
637
+ }
638
+
639
+ if (dt instanceof MapType ) {
640
+ MapType mapType = (MapType ) dt ;
641
+ DefaultMapVector mapVec = (DefaultMapVector ) inputVec ;
642
+
643
+ ColumnVector valueVec = mapVec .getValueVector ();
644
+ if (mapType .getValueType () instanceof ArrayType ||
645
+ mapType .getValueType () instanceof StructType ||
646
+ mapType .getValueType () instanceof MapType ||
647
+ mapType .getValueType () instanceof VariantType ) {
648
+ valueVec = variantCoalesceImpl (
649
+ mapVec .getValueVector (),
650
+ mapType .getValueType ()
651
+ );
652
+ }
653
+ ColumnVector keyVec = mapVec .getKeyVector ();
654
+ if (mapType .getKeyType () instanceof ArrayType ||
655
+ mapType .getKeyType () instanceof StructType ||
656
+ mapType .getKeyType () instanceof MapType ||
657
+ mapType .getKeyType () instanceof VariantType ) {
658
+ keyVec = variantCoalesceImpl (
659
+ mapVec .getKeyVector (),
660
+ mapType .getKeyType ()
661
+ );
662
+ }
663
+ return new DefaultMapVector (
664
+ mapVec .getSize (),
665
+ mapType ,
666
+ mapVec .getNullability (),
667
+ mapVec .getOffsets (),
668
+ keyVec ,
669
+ valueVec
670
+ );
671
+ }
672
+
673
+ DefaultStructVector structBackingVariant = (DefaultStructVector ) inputVec ;
674
+ checkArgument (
675
+ structBackingVariant .getChild (0 ).getDataType () instanceof BinaryType ,
676
+ "Expected struct field 0 backing variant to be binary type. Actual: %s" ,
677
+ structBackingVariant .getChild (0 ).getDataType ()
678
+ );
679
+ checkArgument (
680
+ structBackingVariant .getChild (1 ).getDataType () instanceof BinaryType ,
681
+ "Expected struct field 1 backing variant to be binary type. Actual: %s" ,
682
+ structBackingVariant .getChild (1 ).getDataType ()
683
+ );
684
+
685
+ return new DefaultVariantVector (
686
+ structBackingVariant .getSize (),
687
+ VariantType .VARIANT ,
688
+ structBackingVariant .getNullability (),
689
+ structBackingVariant .getChild (0 ),
690
+ structBackingVariant .getChild (1 )
691
+ );
692
+ }
693
+
561
694
/**
562
695
* Utility method to evaluate inputs to the binary input expression. Also validates the
563
696
* evaluated expression result {@link ColumnVector}s are of the same size.
0 commit comments