Skip to content

Commit 13309df

Browse files
authored
feat(core): handle struct-based UDT literals (#613)
- Introduced `UserDefinedAnyLiteral` and `UserDefinedStructLiteral` to represent user-defined literals - Added `extension_types.yaml` to `DefaultExtensionCatalog` - Added POJOs for type parameter definitions in simple extensions - `Type` now has `typeParameters` method BREAKING CHANGE: `UserDefinedLiteral` immutable abstract class becomes a `UserDefinedAnyLiteral` and `UserDefinedStructLiteral` BREAKING CHANGE: Removed `ExpressionCreator#userDefinedLiteral` method BREAKING CHANGE: Removed all visitor methods operating on `Expression.UserDefinedLiteral`
1 parent b7ff877 commit 13309df

File tree

25 files changed

+889
-68
lines changed

25 files changed

+889
-68
lines changed

core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E {
151151
return visitFallback(expr, context);
152152
}
153153

154+
@Override
155+
public O visit(Expression.UserDefinedAnyLiteral expr, C context) throws E {
156+
return visitFallback(expr, context);
157+
}
158+
159+
@Override
160+
public O visit(Expression.UserDefinedStructLiteral expr, C context) throws E {
161+
return visitFallback(expr, context);
162+
}
163+
154164
@Override
155165
public O visit(Expression.NestedStruct expr, C context) throws E {
156166
return visitFallback(expr, context);

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -693,21 +693,94 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
693693
}
694694
}
695695

696+
/**
697+
* Base interface for user-defined literals.
698+
*
699+
* <p>User-defined literals can be encoded in one of two ways as per the Substrait spec:
700+
*
701+
* <ul>
702+
* <li>As {@code google.protobuf.Any} - see {@link UserDefinedAnyLiteral}
703+
* <li>As {@code Literal.Struct} - see {@link UserDefinedStructLiteral}
704+
* </ul>
705+
*/
706+
interface UserDefinedLiteral extends Literal {
707+
String urn();
708+
709+
String name();
710+
711+
List<io.substrait.type.Type.Parameter> typeParameters();
712+
}
713+
714+
/**
715+
* User-defined literal with value encoded as {@link com.google.protobuf.Any}.
716+
*
717+
* <p>This encoding allows for arbitrary binary data to be stored in the literal value.
718+
*/
696719
@Value.Immutable
697-
abstract class UserDefinedLiteral implements Literal {
698-
public abstract ByteString value();
720+
abstract class UserDefinedAnyLiteral implements UserDefinedLiteral {
721+
@Override
722+
public abstract String urn();
723+
724+
@Override
725+
public abstract String name();
699726

727+
@Override
728+
public abstract List<io.substrait.type.Type.Parameter> typeParameters();
729+
730+
public abstract com.google.protobuf.Any value();
731+
732+
@Override
733+
public Type.UserDefined getType() {
734+
return Type.UserDefined.builder()
735+
.nullable(nullable())
736+
.urn(urn())
737+
.name(name())
738+
.typeParameters(typeParameters())
739+
.build();
740+
}
741+
742+
public static ImmutableExpression.UserDefinedAnyLiteral.Builder builder() {
743+
return ImmutableExpression.UserDefinedAnyLiteral.builder();
744+
}
745+
746+
@Override
747+
public <R, C extends VisitationContext, E extends Throwable> R accept(
748+
ExpressionVisitor<R, C, E> visitor, C context) throws E {
749+
return visitor.visit(this, context);
750+
}
751+
}
752+
753+
/**
754+
* User-defined literal with value encoded as {@link
755+
* io.substrait.proto.Expression.Literal.Struct}.
756+
*
757+
* <p>This encoding uses a structured list of fields to represent the literal value.
758+
*/
759+
@Value.Immutable
760+
abstract class UserDefinedStructLiteral implements UserDefinedLiteral {
761+
@Override
700762
public abstract String urn();
701763

764+
@Override
702765
public abstract String name();
703766

704767
@Override
705-
public Type getType() {
706-
return Type.withNullability(nullable()).userDefined(urn(), name());
768+
public abstract List<io.substrait.type.Type.Parameter> typeParameters();
769+
770+
public abstract List<Literal> fields();
771+
772+
@Override
773+
public Type.UserDefined getType() {
774+
return Type.UserDefined.builder()
775+
.nullable(nullable())
776+
.urn(urn())
777+
.name(name())
778+
.typeParameters(typeParameters())
779+
.build();
707780
}
708781

709-
public static ImmutableExpression.UserDefinedLiteral.Builder builder() {
710-
return ImmutableExpression.UserDefinedLiteral.builder();
782+
public static ImmutableExpression.UserDefinedStructLiteral.Builder builder() {
783+
return ImmutableExpression.UserDefinedStructLiteral.builder();
711784
}
712785

713786
@Override

core/src/main/java/io/substrait/expression/ExpressionCreator.java

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,51 @@ public static Expression.NestedStruct nestedStruct(boolean nullable, Expression.
306306
return Expression.NestedStruct.builder().nullable(nullable).addFields(fields).build();
307307
}
308308

309-
public static Expression.UserDefinedLiteral userDefinedLiteral(
310-
boolean nullable, String urn, String name, Any value) {
311-
return Expression.UserDefinedLiteral.builder()
309+
/**
310+
* Create a UserDefinedAnyLiteral with google.protobuf.Any representation.
311+
*
312+
* @param nullable whether the literal is nullable
313+
* @param urn the URN of the user-defined type
314+
* @param name the name of the user-defined type
315+
* @param typeParameters the type parameters for the user-defined type (can be an empty list)
316+
* @param value the value, encoded as google.protobuf.Any
317+
*/
318+
public static Expression.UserDefinedAnyLiteral userDefinedLiteralAny(
319+
boolean nullable,
320+
String urn,
321+
String name,
322+
java.util.List<io.substrait.type.Type.Parameter> typeParameters,
323+
Any value) {
324+
return Expression.UserDefinedAnyLiteral.builder()
325+
.nullable(nullable)
326+
.urn(urn)
327+
.name(name)
328+
.addAllTypeParameters(typeParameters)
329+
.value(value)
330+
.build();
331+
}
332+
333+
/**
334+
* Create a UserDefinedStructLiteral with Struct representation.
335+
*
336+
* @param nullable whether the literal is nullable
337+
* @param urn the URN of the user-defined type
338+
* @param name the name of the user-defined type
339+
* @param typeParameters the type parameters for the user-defined type (can be an empty list)
340+
* @param fields the fields, as a list of Literal values
341+
*/
342+
public static Expression.UserDefinedStructLiteral userDefinedLiteralStruct(
343+
boolean nullable,
344+
String urn,
345+
String name,
346+
java.util.List<io.substrait.type.Type.Parameter> typeParameters,
347+
java.util.List<Expression.Literal> fields) {
348+
return Expression.UserDefinedStructLiteral.builder()
312349
.nullable(nullable)
313350
.urn(urn)
314351
.name(name)
315-
.value(value.toByteString())
352+
.addAllTypeParameters(typeParameters)
353+
.addAllFields(fields)
316354
.build();
317355
}
318356

core/src/main/java/io/substrait/expression/ExpressionVisitor.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,14 +312,24 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
312312
R visit(Expression.NestedStruct expr, C context) throws E;
313313

314314
/**
315-
* Visit a user-defined literal.
315+
* Visit a user-defined any literal.
316316
*
317317
* @param expr the user-defined literal
318318
* @param context visitation context
319319
* @return visit result
320320
* @throws E on visit failure
321321
*/
322-
R visit(Expression.UserDefinedLiteral expr, C context) throws E;
322+
R visit(Expression.UserDefinedAnyLiteral expr, C context) throws E;
323+
324+
/**
325+
* Visit a user-defined struct literal.
326+
*
327+
* @param expr the user-defined literal
328+
* @param context visitation context
329+
* @return visit result
330+
* @throws E on visit failure
331+
*/
332+
R visit(Expression.UserDefinedStructLiteral expr, C context) throws E;
323333

324334
/**
325335
* Visit a switch expression.

core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package io.substrait.expression.proto;
22

3-
import com.google.protobuf.Any;
4-
import com.google.protobuf.InvalidProtocolBufferException;
53
import io.substrait.expression.ExpressionVisitor;
64
import io.substrait.expression.FieldReference;
75
import io.substrait.expression.FunctionArg;
@@ -377,21 +375,51 @@ public Expression visit(
377375

378376
@Override
379377
public Expression visit(
380-
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {
378+
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
379+
EmptyVisitationContext context) {
381380
int typeReference =
382381
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
383382
return lit(
384383
bldr -> {
385-
try {
386-
bldr.setNullable(expr.nullable())
387-
.setUserDefined(
388-
Expression.Literal.UserDefined.newBuilder()
389-
.setTypeReference(typeReference)
390-
.setValue(Any.parseFrom(expr.value())))
391-
.build();
392-
} catch (InvalidProtocolBufferException e) {
393-
throw new IllegalStateException(e);
394-
}
384+
Expression.Literal.UserDefined.Builder userDefinedBuilder =
385+
Expression.Literal.UserDefined.newBuilder()
386+
.setTypeReference(typeReference)
387+
.addAllTypeParameters(
388+
expr.typeParameters().stream()
389+
.map(typeProtoConverter::toProto)
390+
.collect(java.util.stream.Collectors.toList()))
391+
.setValue(expr.value());
392+
393+
bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
394+
});
395+
}
396+
397+
@Override
398+
public Expression visit(
399+
io.substrait.expression.Expression.UserDefinedStructLiteral expr,
400+
EmptyVisitationContext context) {
401+
int typeReference =
402+
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
403+
return lit(
404+
bldr -> {
405+
Expression.Literal.Struct structLiteral =
406+
Expression.Literal.Struct.newBuilder()
407+
.addAllFields(
408+
expr.fields().stream()
409+
.map(this::toLiteral)
410+
.collect(java.util.stream.Collectors.toList()))
411+
.build();
412+
413+
Expression.Literal.UserDefined.Builder userDefinedBuilder =
414+
Expression.Literal.UserDefined.newBuilder()
415+
.setTypeReference(typeReference)
416+
.addAllTypeParameters(
417+
expr.typeParameters().stream()
418+
.map(typeProtoConverter::toProto)
419+
.collect(java.util.stream.Collectors.toList()))
420+
.setStruct(structLiteral);
421+
422+
bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
395423
});
396424
}
397425

core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
506506
{
507507
io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral =
508508
literal.getUserDefined();
509+
509510
SimpleExtension.Type type =
510511
lookup.getType(userDefinedLiteral.getTypeReference(), extensions);
511-
return ExpressionCreator.userDefinedLiteral(
512-
literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue());
512+
String urn = type.urn();
513+
String name = type.name();
514+
List<io.substrait.type.Type.Parameter> typeParameters =
515+
userDefinedLiteral.getTypeParametersList().stream()
516+
.map(protoTypeConverter::from)
517+
.collect(Collectors.toList());
518+
519+
switch (userDefinedLiteral.getValCase()) {
520+
case VALUE:
521+
return ExpressionCreator.userDefinedLiteralAny(
522+
literal.getNullable(), urn, name, typeParameters, userDefinedLiteral.getValue());
523+
case STRUCT:
524+
return ExpressionCreator.userDefinedLiteralStruct(
525+
literal.getNullable(),
526+
urn,
527+
name,
528+
typeParameters,
529+
userDefinedLiteral.getStruct().getFieldsList().stream()
530+
.map(this::from)
531+
.collect(Collectors.toList()));
532+
case VAL_NOT_SET:
533+
throw new IllegalStateException(
534+
"UserDefined literal has no value (neither 'value' nor 'struct' is set)");
535+
default:
536+
throw new IllegalStateException(
537+
"Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase());
538+
}
513539
}
514540
default:
515541
throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());

core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public class DefaultExtensionCatalog {
2222
"extension:io.substrait:functions_rounding_decimal";
2323
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";
2424
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";
25+
public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types";
2526

2627
public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION =
2728
loadDefaultCollection();
@@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
4445
.map(c -> String.format("/functions_%s.yaml", c))
4546
.collect(Collectors.toList());
4647

48+
defaultFiles.add("/extension_types.yaml");
49+
4750
return SimpleExtension.load(defaultFiles);
4851
}
4952
}

core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,13 @@ public Optional<Expression> visit(Expression.NestedStruct expr, EmptyVisitationC
214214

215215
@Override
216216
public Optional<Expression> visit(
217-
Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E {
217+
Expression.UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws E {
218+
return visitLiteral(expr);
219+
}
220+
221+
@Override
222+
public Optional<Expression> visit(
223+
Expression.UserDefinedStructLiteral expr, EmptyVisitationContext context) throws E {
218224
return visitLiteral(expr);
219225
}
220226

0 commit comments

Comments
 (0)