Skip to content

Commit 03e38da

Browse files
committed
test: improve tests for UDT roundtrip
1 parent 559ed5b commit 03e38da

File tree

1 file changed

+212
-0
lines changed

1 file changed

+212
-0
lines changed

core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,70 @@
11
package io.substrait.type.proto;
22

3+
import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
36
import com.google.protobuf.Any;
47
import io.substrait.TestBase;
58
import io.substrait.expression.Expression;
69
import io.substrait.expression.ExpressionCreator;
10+
import io.substrait.expression.proto.ExpressionProtoConverter;
11+
import io.substrait.expression.proto.ProtoExpressionConverter;
712
import io.substrait.extension.DefaultExtensionCatalog;
13+
import io.substrait.extension.ExtensionCollector;
14+
import io.substrait.extension.SimpleExtension;
15+
import io.substrait.relation.ProtoRelConverter;
16+
import io.substrait.relation.RelProtoConverter;
817
import java.math.BigDecimal;
18+
import java.util.Collections;
919
import org.junit.jupiter.api.Test;
1020

1121
public class LiteralRoundtripTest extends TestBase {
1222

23+
private static final String NESTED_TYPES_URN = "extension:io.substrait:test_nested_types";
24+
25+
private static final String NESTED_TYPES_YAML =
26+
"---\n"
27+
+ "urn: "
28+
+ NESTED_TYPES_URN
29+
+ "\n"
30+
+ "types:\n"
31+
+ " - name: point\n"
32+
+ " structure:\n"
33+
+ " latitude: i32\n"
34+
+ " longitude: i32\n"
35+
+ " - name: triangle\n"
36+
+ " structure:\n"
37+
+ " p1: point\n"
38+
+ " p2: point\n"
39+
+ " p3: point\n";
40+
41+
private static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS =
42+
SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML);
43+
44+
private static final ExtensionCollector NESTED_TYPES_FUNCTION_COLLECTOR =
45+
new ExtensionCollector();
46+
private static final RelProtoConverter NESTED_TYPES_REL_PROTO_CONVERTER =
47+
new RelProtoConverter(NESTED_TYPES_FUNCTION_COLLECTOR);
48+
private static final ProtoRelConverter NESTED_TYPES_PROTO_REL_CONVERTER =
49+
new ProtoRelConverter(NESTED_TYPES_FUNCTION_COLLECTOR, NESTED_TYPES_EXTENSIONS);
50+
private static final ExpressionProtoConverter NESTED_TYPES_EXPRESSION_TO_PROTO =
51+
new ExpressionProtoConverter(
52+
NESTED_TYPES_FUNCTION_COLLECTOR, NESTED_TYPES_REL_PROTO_CONVERTER);
53+
private static final ProtoExpressionConverter NESTED_TYPES_PROTO_TO_EXPRESSION =
54+
new ProtoExpressionConverter(
55+
NESTED_TYPES_FUNCTION_COLLECTOR,
56+
NESTED_TYPES_EXTENSIONS,
57+
EMPTY_TYPE,
58+
NESTED_TYPES_PROTO_REL_CONVERTER);
59+
1360
@Test
1461
void decimal() {
1562
io.substrait.expression.Expression.DecimalLiteral val =
1663
ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2);
1764
verifyRoundTrip(val);
1865
}
1966

67+
/** Verifies round-trip conversion of a simple user-defined type using Any representation. */
2068
@Test
2169
void userDefinedLiteralWithAnyRepresentation() {
2270
// Create a struct literal inline representing a point with latitude=42, longitude=100
@@ -40,6 +88,7 @@ void userDefinedLiteralWithAnyRepresentation() {
4088
verifyRoundTrip(val);
4189
}
4290

91+
/** Verifies round-trip conversion of a simple user-defined type using Struct representation. */
4392
@Test
4493
void userDefinedLiteralWithStructRepresentation() {
4594
java.util.List<Expression.Literal> fields =
@@ -55,4 +104,167 @@ void userDefinedLiteralWithStructRepresentation() {
55104

56105
verifyRoundTrip(val);
57106
}
107+
108+
/**
109+
* Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three
110+
* point UDTs. Both outer and nested types use Struct representation.
111+
*/
112+
@Test
113+
void nestedUserDefinedLiteralWithStructRepresentation() {
114+
Expression.UserDefinedStruct p1 =
115+
ExpressionCreator.userDefinedLiteralStruct(
116+
false,
117+
NESTED_TYPES_URN,
118+
"point",
119+
Collections.emptyList(),
120+
java.util.Arrays.asList(
121+
ExpressionCreator.i32(false, 0), ExpressionCreator.i32(false, 0)));
122+
123+
Expression.UserDefinedStruct p2 =
124+
ExpressionCreator.userDefinedLiteralStruct(
125+
false,
126+
NESTED_TYPES_URN,
127+
"point",
128+
Collections.emptyList(),
129+
java.util.Arrays.asList(
130+
ExpressionCreator.i32(false, 10), ExpressionCreator.i32(false, 0)));
131+
132+
Expression.UserDefinedStruct p3 =
133+
ExpressionCreator.userDefinedLiteralStruct(
134+
false,
135+
NESTED_TYPES_URN,
136+
"point",
137+
Collections.emptyList(),
138+
java.util.Arrays.asList(
139+
ExpressionCreator.i32(false, 5), ExpressionCreator.i32(false, 10)));
140+
141+
Expression.UserDefinedStruct triangle =
142+
ExpressionCreator.userDefinedLiteralStruct(
143+
false,
144+
NESTED_TYPES_URN,
145+
"triangle",
146+
Collections.emptyList(),
147+
java.util.Arrays.asList(p1, p2, p3));
148+
149+
io.substrait.proto.Expression protoExpression =
150+
NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle);
151+
Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression);
152+
assertEquals(triangle, result);
153+
}
154+
155+
/**
156+
* Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three
157+
* point UDTs. Both outer and nested types use Any representation.
158+
*/
159+
@Test
160+
void nestedUserDefinedLiteralWithAnyRepresentation() {
161+
162+
// Create three point UDTs using Any representation
163+
io.substrait.proto.Expression.Literal.Struct p1Struct =
164+
io.substrait.proto.Expression.Literal.Struct.newBuilder()
165+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0))
166+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0))
167+
.build();
168+
Any p1Any =
169+
Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p1Struct).build());
170+
Expression.UserDefinedAny p1 =
171+
ExpressionCreator.userDefinedLiteralAny(
172+
false, NESTED_TYPES_URN, "point", Collections.emptyList(), p1Any);
173+
174+
io.substrait.proto.Expression.Literal.Struct p2Struct =
175+
io.substrait.proto.Expression.Literal.Struct.newBuilder()
176+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10))
177+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0))
178+
.build();
179+
Any p2Any =
180+
Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p2Struct).build());
181+
Expression.UserDefinedAny p2 =
182+
ExpressionCreator.userDefinedLiteralAny(
183+
false, NESTED_TYPES_URN, "point", Collections.emptyList(), p2Any);
184+
185+
io.substrait.proto.Expression.Literal.Struct p3Struct =
186+
io.substrait.proto.Expression.Literal.Struct.newBuilder()
187+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(5))
188+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10))
189+
.build();
190+
Any p3Any =
191+
Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p3Struct).build());
192+
Expression.UserDefinedAny p3 =
193+
ExpressionCreator.userDefinedLiteralAny(
194+
false, NESTED_TYPES_URN, "point", Collections.emptyList(), p3Any);
195+
196+
// Create a "triangle" struct containing three point UDTs
197+
io.substrait.proto.Expression.Literal.Struct triangleStruct =
198+
io.substrait.proto.Expression.Literal.Struct.newBuilder()
199+
.addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p1).getLiteral())
200+
.addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p2).getLiteral())
201+
.addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p3).getLiteral())
202+
.build();
203+
Any triangleAny =
204+
Any.pack(
205+
io.substrait.proto.Expression.Literal.newBuilder().setStruct(triangleStruct).build());
206+
207+
Expression.UserDefinedAny triangle =
208+
ExpressionCreator.userDefinedLiteralAny(
209+
false, NESTED_TYPES_URN, "triangle", Collections.emptyList(), triangleAny);
210+
211+
io.substrait.proto.Expression protoExpression =
212+
NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle);
213+
Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression);
214+
assertEquals(triangle, result);
215+
}
216+
217+
/**
218+
* Verifies round-trip conversion of nested user-defined types with mixed representations. The
219+
* triangle UDT uses Struct representation while the nested point UDTs use Any representation.
220+
*/
221+
@Test
222+
void mixedRepresentationNestedUserDefinedLiteral() {
223+
io.substrait.proto.Expression.Literal.Struct p1Struct =
224+
io.substrait.proto.Expression.Literal.Struct.newBuilder()
225+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0))
226+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0))
227+
.build();
228+
Any p1Any =
229+
Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p1Struct).build());
230+
Expression.UserDefinedAny p1 =
231+
ExpressionCreator.userDefinedLiteralAny(
232+
false, NESTED_TYPES_URN, "point", Collections.emptyList(), p1Any);
233+
234+
io.substrait.proto.Expression.Literal.Struct p2Struct =
235+
io.substrait.proto.Expression.Literal.Struct.newBuilder()
236+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10))
237+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0))
238+
.build();
239+
Any p2Any =
240+
Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p2Struct).build());
241+
Expression.UserDefinedAny p2 =
242+
ExpressionCreator.userDefinedLiteralAny(
243+
false, NESTED_TYPES_URN, "point", Collections.emptyList(), p2Any);
244+
245+
io.substrait.proto.Expression.Literal.Struct p3Struct =
246+
io.substrait.proto.Expression.Literal.Struct.newBuilder()
247+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(5))
248+
.addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10))
249+
.build();
250+
Any p3Any =
251+
Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p3Struct).build());
252+
Expression.UserDefinedAny p3 =
253+
ExpressionCreator.userDefinedLiteralAny(
254+
false, NESTED_TYPES_URN, "point", Collections.emptyList(), p3Any);
255+
256+
// Create a "triangle" UDT using Struct representation, but with Any-encoded point fields
257+
Expression.UserDefinedStruct triangle =
258+
ExpressionCreator.userDefinedLiteralStruct(
259+
false,
260+
NESTED_TYPES_URN,
261+
"triangle",
262+
Collections.emptyList(),
263+
java.util.Arrays.asList(p1, p2, p3));
264+
265+
io.substrait.proto.Expression protoExpression =
266+
NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle);
267+
Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression);
268+
assertEquals(triangle, result);
269+
}
58270
}

0 commit comments

Comments
 (0)