11package io .substrait .type .proto ;
22
3+ import static io .substrait .expression .proto .ProtoExpressionConverter .EMPTY_TYPE ;
4+ import static org .junit .jupiter .api .Assertions .assertEquals ;
5+
36import com .google .protobuf .Any ;
47import io .substrait .TestBase ;
58import io .substrait .expression .Expression ;
69import io .substrait .expression .ExpressionCreator ;
10+ import io .substrait .expression .proto .ExpressionProtoConverter ;
11+ import io .substrait .expression .proto .ProtoExpressionConverter ;
712import io .substrait .extension .DefaultExtensionCatalog ;
13+ import io .substrait .extension .ExtensionCollector ;
14+ import io .substrait .extension .SimpleExtension ;
15+ import io .substrait .relation .ProtoRelConverter ;
16+ import io .substrait .relation .RelProtoConverter ;
817import java .math .BigDecimal ;
18+ import java .util .Collections ;
919import org .junit .jupiter .api .Test ;
1020
1121public class LiteralRoundtripTest extends TestBase {
1222
23+ private static final String NESTED_TYPES_URN = "extension:io.substrait:test_nested_types" ;
24+
25+ private static final String NESTED_TYPES_YAML =
26+ "---\n "
27+ + "urn: "
28+ + NESTED_TYPES_URN
29+ + "\n "
30+ + "types:\n "
31+ + " - name: point\n "
32+ + " structure:\n "
33+ + " latitude: i32\n "
34+ + " longitude: i32\n "
35+ + " - name: triangle\n "
36+ + " structure:\n "
37+ + " p1: point\n "
38+ + " p2: point\n "
39+ + " p3: point\n " ;
40+
41+ private static final SimpleExtension .ExtensionCollection NESTED_TYPES_EXTENSIONS =
42+ SimpleExtension .load ("nested_types.yaml" , NESTED_TYPES_YAML );
43+
44+ private static final ExtensionCollector NESTED_TYPES_FUNCTION_COLLECTOR =
45+ new ExtensionCollector ();
46+ private static final RelProtoConverter NESTED_TYPES_REL_PROTO_CONVERTER =
47+ new RelProtoConverter (NESTED_TYPES_FUNCTION_COLLECTOR );
48+ private static final ProtoRelConverter NESTED_TYPES_PROTO_REL_CONVERTER =
49+ new ProtoRelConverter (NESTED_TYPES_FUNCTION_COLLECTOR , NESTED_TYPES_EXTENSIONS );
50+ private static final ExpressionProtoConverter NESTED_TYPES_EXPRESSION_TO_PROTO =
51+ new ExpressionProtoConverter (
52+ NESTED_TYPES_FUNCTION_COLLECTOR , NESTED_TYPES_REL_PROTO_CONVERTER );
53+ private static final ProtoExpressionConverter NESTED_TYPES_PROTO_TO_EXPRESSION =
54+ new ProtoExpressionConverter (
55+ NESTED_TYPES_FUNCTION_COLLECTOR ,
56+ NESTED_TYPES_EXTENSIONS ,
57+ EMPTY_TYPE ,
58+ NESTED_TYPES_PROTO_REL_CONVERTER );
59+
1360 @ Test
1461 void decimal () {
1562 io .substrait .expression .Expression .DecimalLiteral val =
1663 ExpressionCreator .decimal (false , BigDecimal .TEN , 10 , 2 );
1764 verifyRoundTrip (val );
1865 }
1966
67+ /** Verifies round-trip conversion of a simple user-defined type using Any representation. */
2068 @ Test
2169 void userDefinedLiteralWithAnyRepresentation () {
2270 // Create a struct literal inline representing a point with latitude=42, longitude=100
@@ -40,6 +88,7 @@ void userDefinedLiteralWithAnyRepresentation() {
4088 verifyRoundTrip (val );
4189 }
4290
91+ /** Verifies round-trip conversion of a simple user-defined type using Struct representation. */
4392 @ Test
4493 void userDefinedLiteralWithStructRepresentation () {
4594 java .util .List <Expression .Literal > fields =
@@ -55,4 +104,167 @@ void userDefinedLiteralWithStructRepresentation() {
55104
56105 verifyRoundTrip (val );
57106 }
107+
108+ /**
109+ * Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three
110+ * point UDTs. Both outer and nested types use Struct representation.
111+ */
112+ @ Test
113+ void nestedUserDefinedLiteralWithStructRepresentation () {
114+ Expression .UserDefinedStruct p1 =
115+ ExpressionCreator .userDefinedLiteralStruct (
116+ false ,
117+ NESTED_TYPES_URN ,
118+ "point" ,
119+ Collections .emptyList (),
120+ java .util .Arrays .asList (
121+ ExpressionCreator .i32 (false , 0 ), ExpressionCreator .i32 (false , 0 )));
122+
123+ Expression .UserDefinedStruct p2 =
124+ ExpressionCreator .userDefinedLiteralStruct (
125+ false ,
126+ NESTED_TYPES_URN ,
127+ "point" ,
128+ Collections .emptyList (),
129+ java .util .Arrays .asList (
130+ ExpressionCreator .i32 (false , 10 ), ExpressionCreator .i32 (false , 0 )));
131+
132+ Expression .UserDefinedStruct p3 =
133+ ExpressionCreator .userDefinedLiteralStruct (
134+ false ,
135+ NESTED_TYPES_URN ,
136+ "point" ,
137+ Collections .emptyList (),
138+ java .util .Arrays .asList (
139+ ExpressionCreator .i32 (false , 5 ), ExpressionCreator .i32 (false , 10 )));
140+
141+ Expression .UserDefinedStruct triangle =
142+ ExpressionCreator .userDefinedLiteralStruct (
143+ false ,
144+ NESTED_TYPES_URN ,
145+ "triangle" ,
146+ Collections .emptyList (),
147+ java .util .Arrays .asList (p1 , p2 , p3 ));
148+
149+ io .substrait .proto .Expression protoExpression =
150+ NESTED_TYPES_EXPRESSION_TO_PROTO .toProto (triangle );
151+ Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION .from (protoExpression );
152+ assertEquals (triangle , result );
153+ }
154+
155+ /**
156+ * Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three
157+ * point UDTs. Both outer and nested types use Any representation.
158+ */
159+ @ Test
160+ void nestedUserDefinedLiteralWithAnyRepresentation () {
161+
162+ // Create three point UDTs using Any representation
163+ io .substrait .proto .Expression .Literal .Struct p1Struct =
164+ io .substrait .proto .Expression .Literal .Struct .newBuilder ()
165+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (0 ))
166+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (0 ))
167+ .build ();
168+ Any p1Any =
169+ Any .pack (io .substrait .proto .Expression .Literal .newBuilder ().setStruct (p1Struct ).build ());
170+ Expression .UserDefinedAny p1 =
171+ ExpressionCreator .userDefinedLiteralAny (
172+ false , NESTED_TYPES_URN , "point" , Collections .emptyList (), p1Any );
173+
174+ io .substrait .proto .Expression .Literal .Struct p2Struct =
175+ io .substrait .proto .Expression .Literal .Struct .newBuilder ()
176+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (10 ))
177+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (0 ))
178+ .build ();
179+ Any p2Any =
180+ Any .pack (io .substrait .proto .Expression .Literal .newBuilder ().setStruct (p2Struct ).build ());
181+ Expression .UserDefinedAny p2 =
182+ ExpressionCreator .userDefinedLiteralAny (
183+ false , NESTED_TYPES_URN , "point" , Collections .emptyList (), p2Any );
184+
185+ io .substrait .proto .Expression .Literal .Struct p3Struct =
186+ io .substrait .proto .Expression .Literal .Struct .newBuilder ()
187+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (5 ))
188+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (10 ))
189+ .build ();
190+ Any p3Any =
191+ Any .pack (io .substrait .proto .Expression .Literal .newBuilder ().setStruct (p3Struct ).build ());
192+ Expression .UserDefinedAny p3 =
193+ ExpressionCreator .userDefinedLiteralAny (
194+ false , NESTED_TYPES_URN , "point" , Collections .emptyList (), p3Any );
195+
196+ // Create a "triangle" struct containing three point UDTs
197+ io .substrait .proto .Expression .Literal .Struct triangleStruct =
198+ io .substrait .proto .Expression .Literal .Struct .newBuilder ()
199+ .addFields (NESTED_TYPES_EXPRESSION_TO_PROTO .toProto (p1 ).getLiteral ())
200+ .addFields (NESTED_TYPES_EXPRESSION_TO_PROTO .toProto (p2 ).getLiteral ())
201+ .addFields (NESTED_TYPES_EXPRESSION_TO_PROTO .toProto (p3 ).getLiteral ())
202+ .build ();
203+ Any triangleAny =
204+ Any .pack (
205+ io .substrait .proto .Expression .Literal .newBuilder ().setStruct (triangleStruct ).build ());
206+
207+ Expression .UserDefinedAny triangle =
208+ ExpressionCreator .userDefinedLiteralAny (
209+ false , NESTED_TYPES_URN , "triangle" , Collections .emptyList (), triangleAny );
210+
211+ io .substrait .proto .Expression protoExpression =
212+ NESTED_TYPES_EXPRESSION_TO_PROTO .toProto (triangle );
213+ Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION .from (protoExpression );
214+ assertEquals (triangle , result );
215+ }
216+
217+ /**
218+ * Verifies round-trip conversion of nested user-defined types with mixed representations. The
219+ * triangle UDT uses Struct representation while the nested point UDTs use Any representation.
220+ */
221+ @ Test
222+ void mixedRepresentationNestedUserDefinedLiteral () {
223+ io .substrait .proto .Expression .Literal .Struct p1Struct =
224+ io .substrait .proto .Expression .Literal .Struct .newBuilder ()
225+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (0 ))
226+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (0 ))
227+ .build ();
228+ Any p1Any =
229+ Any .pack (io .substrait .proto .Expression .Literal .newBuilder ().setStruct (p1Struct ).build ());
230+ Expression .UserDefinedAny p1 =
231+ ExpressionCreator .userDefinedLiteralAny (
232+ false , NESTED_TYPES_URN , "point" , Collections .emptyList (), p1Any );
233+
234+ io .substrait .proto .Expression .Literal .Struct p2Struct =
235+ io .substrait .proto .Expression .Literal .Struct .newBuilder ()
236+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (10 ))
237+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (0 ))
238+ .build ();
239+ Any p2Any =
240+ Any .pack (io .substrait .proto .Expression .Literal .newBuilder ().setStruct (p2Struct ).build ());
241+ Expression .UserDefinedAny p2 =
242+ ExpressionCreator .userDefinedLiteralAny (
243+ false , NESTED_TYPES_URN , "point" , Collections .emptyList (), p2Any );
244+
245+ io .substrait .proto .Expression .Literal .Struct p3Struct =
246+ io .substrait .proto .Expression .Literal .Struct .newBuilder ()
247+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (5 ))
248+ .addFields (io .substrait .proto .Expression .Literal .newBuilder ().setI32 (10 ))
249+ .build ();
250+ Any p3Any =
251+ Any .pack (io .substrait .proto .Expression .Literal .newBuilder ().setStruct (p3Struct ).build ());
252+ Expression .UserDefinedAny p3 =
253+ ExpressionCreator .userDefinedLiteralAny (
254+ false , NESTED_TYPES_URN , "point" , Collections .emptyList (), p3Any );
255+
256+ // Create a "triangle" UDT using Struct representation, but with Any-encoded point fields
257+ Expression .UserDefinedStruct triangle =
258+ ExpressionCreator .userDefinedLiteralStruct (
259+ false ,
260+ NESTED_TYPES_URN ,
261+ "triangle" ,
262+ Collections .emptyList (),
263+ java .util .Arrays .asList (p1 , p2 , p3 ));
264+
265+ io .substrait .proto .Expression protoExpression =
266+ NESTED_TYPES_EXPRESSION_TO_PROTO .toProto (triangle );
267+ Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION .from (protoExpression );
268+ assertEquals (triangle , result );
269+ }
58270}
0 commit comments