|
| 1 | +import substrait.gen.proto.type_pb2 as stt |
| 2 | +import substrait.gen.proto.plan_pb2 as stp |
| 3 | +import substrait.gen.proto.algebra_pb2 as stalg |
| 4 | +from substrait.builders.type import boolean, i64 |
| 5 | +from substrait.builders.plan import read_named_table |
| 6 | +from substrait.extension_registry import ExtensionRegistry |
| 7 | +import substrait.dataframe as sdf |
| 8 | + |
| 9 | + |
| 10 | +registry = ExtensionRegistry(load_default_extensions=False) |
| 11 | + |
| 12 | +struct = stt.Type.Struct( |
| 13 | + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED |
| 14 | +) |
| 15 | + |
| 16 | +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) |
| 17 | + |
| 18 | + |
| 19 | +def test_project(): |
| 20 | + df = sdf.DataFrame(read_named_table("table", named_struct)) |
| 21 | + |
| 22 | + actual = df.select(id=sdf.col("id")).to_substrait(registry) |
| 23 | + |
| 24 | + expected = stp.Plan( |
| 25 | + relations=[ |
| 26 | + stp.PlanRel( |
| 27 | + root=stalg.RelRoot( |
| 28 | + input=stalg.Rel( |
| 29 | + project=stalg.ProjectRel( |
| 30 | + common=stalg.RelCommon( |
| 31 | + emit=stalg.RelCommon.Emit(output_mapping=[2]) |
| 32 | + ), |
| 33 | + input=df.to_substrait(None).relations[-1].root.input, |
| 34 | + expressions=[ |
| 35 | + stalg.Expression( |
| 36 | + selection=stalg.Expression.FieldReference( |
| 37 | + direct_reference=stalg.Expression.ReferenceSegment( |
| 38 | + struct_field=stalg.Expression.ReferenceSegment.StructField( |
| 39 | + field=0 |
| 40 | + ) |
| 41 | + ), |
| 42 | + root_reference=stalg.Expression.FieldReference.RootReference(), |
| 43 | + ) |
| 44 | + ) |
| 45 | + ], |
| 46 | + ) |
| 47 | + ), |
| 48 | + names=["id"], |
| 49 | + ) |
| 50 | + ) |
| 51 | + ] |
| 52 | + ) |
| 53 | + |
| 54 | + assert actual == expected |
0 commit comments