Skip to content

Commit 8845bbf

Browse files
authored
feat: support Nested Structs (#579)
1 parent eecf80e commit 8845bbf

File tree

23 files changed

+487
-84
lines changed

23 files changed

+487
-84
lines changed

core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ public O visit(Expression.StructLiteral expr, C context) throws E {
151151
return visitFallback(expr, context);
152152
}
153153

154+
@Override
155+
public O visit(Expression.NestedStruct expr, C context) throws E {
156+
return visitFallback(expr, context);
157+
}
158+
154159
@Override
155160
public O visit(Expression.Switch expr, C context) throws E {
156161
return visitFallback(expr, context);

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ default boolean nullable() {
3232
}
3333
}
3434

35+
interface Nested extends Expression {
36+
@Value.Default
37+
default boolean nullable() {
38+
return false;
39+
}
40+
}
41+
3542
<R, C extends VisitationContext, E extends Throwable> R accept(
3643
ExpressionVisitor<R, C, E> visitor, C context) throws E;
3744

@@ -662,6 +669,30 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
662669
}
663670
}
664671

672+
@Value.Immutable
673+
abstract class NestedStruct implements Nested {
674+
public abstract List<Expression> fields();
675+
676+
@Override
677+
public Type getType() {
678+
return Type.withNullability(nullable())
679+
.struct(
680+
fields().stream()
681+
.map(Expression::getType)
682+
.collect(java.util.stream.Collectors.toList()));
683+
}
684+
685+
public static ImmutableExpression.NestedStruct.Builder builder() {
686+
return ImmutableExpression.NestedStruct.builder();
687+
}
688+
689+
@Override
690+
public <R, C extends VisitationContext, E extends Throwable> R accept(
691+
ExpressionVisitor<R, C, E> visitor, C context) throws E {
692+
return visitor.visit(this, context);
693+
}
694+
}
695+
665696
@Value.Immutable
666697
abstract class UserDefinedLiteral implements Literal {
667698
public abstract ByteString value();

core/src/main/java/io/substrait/expression/ExpressionCreator.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,15 @@ public static Expression.StructLiteral struct(
286286
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
287287
}
288288

289+
public static Expression.NestedStruct nestedStruct(
290+
boolean nullable, Iterable<Expression> fields) {
291+
return Expression.NestedStruct.builder().nullable(nullable).addAllFields(fields).build();
292+
}
293+
294+
public static Expression.NestedStruct nestedStruct(boolean nullable, Expression... fields) {
295+
return Expression.NestedStruct.builder().nullable(nullable).addFields(fields).build();
296+
}
297+
289298
public static Expression.UserDefinedLiteral userDefinedLiteral(
290299
boolean nullable, String urn, String name, Any value) {
291300
return Expression.UserDefinedLiteral.builder()

core/src/main/java/io/substrait/expression/ExpressionVisitor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
6262

6363
R visit(Expression.StructLiteral expr, C context) throws E;
6464

65+
R visit(Expression.NestedStruct expr, C context) throws E;
66+
6567
R visit(Expression.UserDefinedLiteral expr, C context) throws E;
6668

6769
R visit(Expression.Switch expr, C context) throws E;

core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ private Expression lit(Consumer<Expression.Literal.Builder> consumer) {
7575
return Expression.newBuilder().setLiteral(builder).build();
7676
}
7777

78+
private Expression nested(Consumer<Expression.Nested.Builder> consumer) {
79+
Expression.Nested.Builder builder = Expression.Nested.newBuilder();
80+
consumer.accept(builder);
81+
return Expression.newBuilder().setNested(builder).build();
82+
}
83+
7884
@Override
7985
public Expression visit(
8086
io.substrait.expression.Expression.BoolLiteral expr, EmptyVisitationContext context) {
@@ -357,6 +363,18 @@ public Expression visit(
357363
});
358364
}
359365

366+
@Override
367+
public Expression visit(
368+
io.substrait.expression.Expression.NestedStruct expr, EmptyVisitationContext context) {
369+
return nested(
370+
bldr -> {
371+
List<Expression> values =
372+
expr.fields().stream().map(this::toProto).collect(Collectors.toList());
373+
bldr.setStruct(Expression.Nested.Struct.newBuilder().addAllFields(values))
374+
.setNullable(expr.nullable());
375+
});
376+
}
377+
360378
@Override
361379
public Expression visit(
362380
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {

core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,18 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
502502
}
503503
}
504504

505+
public Expression.StructLiteral from(io.substrait.proto.Expression.Literal.Struct struct) {
506+
return Expression.StructLiteral.builder()
507+
.fields(struct.getFieldsList().stream().map(this::from).collect(Collectors.toList()))
508+
.build();
509+
}
510+
511+
public Expression.NestedStruct from(io.substrait.proto.Expression.Nested.Struct struct) {
512+
return Expression.NestedStruct.builder()
513+
.fields(struct.getFieldsList().stream().map(this::from).collect(Collectors.toList()))
514+
.build();
515+
}
516+
505517
private static List<FunctionArg> fromFunctionArgumentList(
506518
int argumentsCount,
507519
FunctionArg.ProtoFrom argVisitor,

core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,15 @@ public Optional<Expression> visit(Expression.StructLiteral expr, EmptyVisitation
203203
return visitLiteral(expr);
204204
}
205205

206+
@Override
207+
public Optional<Expression> visit(Expression.NestedStruct expr, EmptyVisitationContext context)
208+
throws E {
209+
Optional<List<Expression>> expressions = visitExprList(expr.fields(), context);
210+
return expressions.map(
211+
expressionList ->
212+
Expression.NestedStruct.builder().from(expr).fields(expressionList).build());
213+
}
214+
206215
@Override
207216
public Optional<Expression> visit(
208217
Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E {

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ public Rel from(io.substrait.proto.Rel rel) {
188188
protected Rel newRead(ReadRel rel) {
189189
if (rel.hasVirtualTable()) {
190190
ReadRel.VirtualTable virtualTable = rel.getVirtualTable();
191-
if (virtualTable.getValuesCount() == 0) {
191+
if (virtualTable.getValuesCount() == 0 && virtualTable.getExpressionsCount() == 0) {
192192
return newEmptyScan(rel);
193193
} else {
194194
return newVirtualTable(rel);
@@ -596,20 +596,50 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) {
596596
return builder.build();
597597
}
598598

599+
/**
600+
* Converts StructLiteral instances to NestedStruct for VirtualTableScan. This is a convenience
601+
* method for migrating from the legacy StructLiteral-based VirtualTable API to the new
602+
* NestedStruct-based API.
603+
*
604+
* @param nullable whether the resulting NestedStruct instances should be nullable
605+
* @param structs the StructLiteral instances to convert
606+
* @return a list of NestedStruct instances with the same field structure
607+
*/
608+
private static List<Expression.NestedStruct> nestedStruct(
609+
boolean nullable, Expression.StructLiteral... structs) {
610+
List<Expression.NestedStruct> nestedStructs = new ArrayList<>();
611+
for (Expression.StructLiteral struct : structs) {
612+
nestedStructs.add(
613+
Expression.NestedStruct.builder()
614+
.nullable(nullable)
615+
.addAllFields(struct.fields())
616+
.build());
617+
}
618+
return nestedStructs;
619+
}
620+
599621
protected VirtualTableScan newVirtualTable(ReadRel rel) {
600622
ReadRel.VirtualTable virtualTable = rel.getVirtualTable();
623+
// If both values and expressions are set, raise an error
624+
if (virtualTable.getValuesCount() > 0 && virtualTable.getExpressionsCount() > 0) {
625+
throw new IllegalArgumentException(
626+
"VirtualTable cannot have both values and expressions set");
627+
}
601628
NamedStruct virtualTableSchema = newNamedStruct(rel);
602629
ProtoExpressionConverter converter =
603630
new ProtoExpressionConverter(lookup, extensions, virtualTableSchema.struct(), this);
604-
List<Expression.StructLiteral> structLiterals = new ArrayList<>(virtualTable.getValuesCount());
631+
632+
List<Expression.NestedStruct> expressions =
633+
new ArrayList<>(virtualTable.getValuesCount() + virtualTable.getExpressionsCount());
634+
635+
// We cannot have a null row in VirtualTable, therefore we set the nullability to false
636+
// nullability is also not supported at the Expression.Nested.Struct level
605637
for (io.substrait.proto.Expression.Literal.Struct struct : virtualTable.getValuesList()) {
606-
structLiterals.add(
607-
Expression.StructLiteral.builder()
608-
.fields(
609-
struct.getFieldsList().stream()
610-
.map(converter::from)
611-
.collect(java.util.stream.Collectors.toList()))
612-
.build());
638+
expressions.addAll(nestedStruct(false, converter.from(struct)));
639+
}
640+
641+
for (io.substrait.proto.Expression.Nested.Struct expr : virtualTable.getExpressionsList()) {
642+
expressions.add(converter.from(expr));
613643
}
614644

615645
ImmutableVirtualTableScan.Builder builder =
@@ -619,7 +649,7 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) {
619649
rel.hasBestEffortFilter() ? converter.from(rel.getBestEffortFilter()) : null))
620650
.filter(Optional.ofNullable(rel.hasFilter() ? converter.from(rel.getFilter()) : null))
621651
.initialSchema(NamedStruct.fromProto(rel.getBaseSchema(), protoTypeConverter))
622-
.rows(structLiterals);
652+
.rows(expressions);
623653

624654
builder
625655
.commonExtension(optionalAdvancedExtension(rel.getCommon()))

core/src/main/java/io/substrait/relation/RelProtoConverter.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import io.substrait.relation.physical.TargetType;
5757
import io.substrait.type.proto.TypeProtoConverter;
5858
import io.substrait.util.EmptyVisitationContext;
59+
import java.util.ArrayList;
5960
import java.util.Collection;
6061
import java.util.List;
6162
import java.util.stream.Collectors;
@@ -749,17 +750,21 @@ public Rel visit(Cross cross, EmptyVisitationContext context) throws RuntimeExce
749750
@Override
750751
public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context)
751752
throws RuntimeException {
753+
List<io.substrait.proto.Expression.Nested.Struct> structs = new ArrayList<>();
754+
for (Expression.NestedStruct row : virtualTableScan.getRows()) {
755+
structs.add(
756+
io.substrait.proto.Expression.Nested.Struct.newBuilder()
757+
.addAllFields(
758+
row.fields().stream()
759+
.map(this::toProto)
760+
.collect(java.util.stream.Collectors.toList()))
761+
.build());
762+
}
763+
752764
ReadRel.Builder builder =
753765
ReadRel.newBuilder()
754766
.setCommon(common(virtualTableScan))
755-
.setVirtualTable(
756-
ReadRel.VirtualTable.newBuilder()
757-
.addAllValues(
758-
virtualTableScan.getRows().stream()
759-
.map(this::toProto)
760-
.map(t -> t.getLiteral().getStruct())
761-
.collect(Collectors.toList()))
762-
.build())
767+
.setVirtualTable(ReadRel.VirtualTable.newBuilder().addAllExpressions(structs).build())
763768
.setBaseSchema(virtualTableScan.getInitialSchema().toProto(typeProtoConverter));
764769

765770
virtualTableScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f)));

core/src/main/java/io/substrait/relation/VirtualTableScan.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
import io.substrait.type.TypeVisitor;
66
import io.substrait.util.VisitationContext;
77
import java.util.List;
8+
import java.util.Objects;
89
import org.immutables.value.Value;
910

1011
@Value.Immutable
1112
public abstract class VirtualTableScan extends AbstractReadRel {
1213

13-
public abstract List<Expression.StructLiteral> getRows();
14+
public abstract List<Expression.NestedStruct> getRows();
1415

1516
/**
16-
*
17+
* Checks the following invariants when construction a VirtualTableScan
1718
*
1819
* <ul>
1920
* <li>non-empty rowset
@@ -29,15 +30,28 @@ protected void check() {
2930

3031
assert names.size()
3132
== NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct());
32-
List<Expression.StructLiteral> rows = getRows();
33-
34-
assert rows.size() > 0
35-
&& names.stream().noneMatch(s -> s == null)
36-
&& rows.stream().noneMatch(r -> r == null)
33+
List<Expression.NestedStruct> rows = getRows();
34+
35+
// At the PROTOBUF layer, the Nested.Struct message does not carry nullability information.
36+
// Nullability is attached to the Nested message, which can contain a Nested.Struct.
37+
// The NestedStruct POJO flattens the Nested and Nested.Struct messages together, allowing the
38+
// nullability of a NestedStruct to be set directly.
39+
//
40+
// HOWEVER, the VirtualTable message contains a list of Nested.Struct messages, and as such
41+
// the nullability cannot be set at the protobuf layer. To avoid users attaching meaningless
42+
// nullability information in the POJOs, we restrict the nullability of NestedStructs to false
43+
// when used in VirtualTableScans.
44+
for (Expression.NestedStruct row : rows) {
45+
assert !row.nullable();
46+
}
47+
48+
assert !rows.isEmpty()
49+
&& names.stream().noneMatch(Objects::isNull)
50+
&& rows.stream().noneMatch(Objects::isNull)
3751
&& rows.stream()
3852
.allMatch(r -> NamedFieldCountingTypeVisitor.countNames(r.getType()) == names.size());
3953

40-
for (Expression.StructLiteral row : rows) {
54+
for (Expression.NestedStruct row : rows) {
4155
validateRowConformsToSchema(row);
4256
}
4357
}
@@ -48,10 +62,10 @@ protected void check() {
4862
* @param row the row to validate
4963
* @throws AssertionError if the row does not conform to the schema
5064
*/
51-
private void validateRowConformsToSchema(Expression.StructLiteral row) {
65+
private void validateRowConformsToSchema(Expression.NestedStruct row) {
5266
Type.Struct schemaStruct = getInitialSchema().struct();
5367
List<Type> schemaFieldTypes = schemaStruct.fields();
54-
List<Expression.Literal> rowFields = row.fields();
68+
List<Expression> rowFields = row.fields();
5569

5670
assert rowFields.size() == schemaFieldTypes.size()
5771
: String.format(

0 commit comments

Comments
 (0)