Skip to content

Commit c453d69

Browse files
authored
feat: add literal builders (#77)
adds literal builders for binary, char and date types
1 parent acfd62a commit c453d69

File tree

3 files changed

+68
-1
lines changed

3 files changed

+68
-1
lines changed

src/substrait/builders/extended_expression.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import date
12
import itertools
23
import substrait.gen.proto.algebra_pb2 as stalg
34
import substrait.gen.proto.type_pb2 as stp
@@ -41,6 +42,36 @@ def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.E
4142
literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE)
4243
elif kind == "string":
4344
literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE)
45+
elif kind == "binary":
46+
literal = stalg.Expression.Literal(binary=value, nullable=type.binary.nullability == stp.Type.NULLABILITY_NULLABLE)
47+
elif kind == "date":
48+
date_value = (value - date(1970,1,1)).days if isinstance(value, date) else value
49+
literal = stalg.Expression.Literal(date=date_value, nullable=type.date.nullability == stp.Type.NULLABILITY_NULLABLE)
50+
# TODO
51+
# IntervalYearToMonth interval_year_to_month = 19;
52+
# IntervalDayToSecond interval_day_to_second = 20;
53+
# IntervalCompound interval_compound = 36;
54+
elif kind == "fixed_char":
55+
literal = stalg.Expression.Literal(fixed_char=value, nullable=type.fixed_char.nullability == stp.Type.NULLABILITY_NULLABLE)
56+
elif kind == "varchar":
57+
literal = stalg.Expression.Literal(
58+
var_char=stalg.Expression.Literal.VarChar(value=value, length=type.varchar.length),
59+
nullable=type.varchar.nullability == stp.Type.NULLABILITY_NULLABLE
60+
)
61+
elif kind == "fixed_binary":
62+
literal = stalg.Expression.Literal(fixed_binary=value, nullable=type.fixed_binary.nullability == stp.Type.NULLABILITY_NULLABLE)
63+
# TODO
64+
# Decimal decimal = 24;
65+
# PrecisionTime precision_time = 37; // Time in precision units past midnight.
66+
# PrecisionTimestamp precision_timestamp = 34;
67+
# PrecisionTimestamp precision_timestamp_tz = 35;
68+
# Struct struct = 25;
69+
# Map map = 26;
70+
# bytes uuid = 28;
71+
# Type null = 29; // a typed null literal
72+
# List list = 30;
73+
# Type.List empty_list = 31;
74+
# Type.Map empty_map = 32;
4475
else:
4576
raise Exception(f"Unknown literal type - {type}")
4677

src/substrait/builders/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def fixed_char(length: int, nullable=True) -> stt.Type:
4747
return stt.Type(fixed_char=stt.Type.FixedChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED))
4848

4949
def var_char(length: int, nullable=True) -> stt.Type:
50-
return stt.Type(var_char=stt.Type.VarChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED))
50+
return stt.Type(varchar=stt.Type.VarChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED))
5151

5252
def fixed_binary(length: int, nullable=True) -> stt.Type:
5353
return stt.Type(fixed_binary=stt.Type.FixedBinary(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED))
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from datetime import date
2+
import substrait.gen.proto.algebra_pb2 as stalg
3+
import substrait.gen.proto.type_pb2 as stt
4+
import substrait.gen.proto.extended_expression_pb2 as stee
5+
from substrait.builders.extended_expression import literal
6+
from substrait.builders import type as sttb
7+
8+
def extract_literal(builder):
9+
return builder(None, None).referred_expr[0].expression.literal
10+
11+
def test_boolean():
12+
assert extract_literal(literal(True, sttb.boolean())) == stalg.Expression.Literal(boolean=True, nullable=True)
13+
assert extract_literal(literal(False, sttb.boolean())) == stalg.Expression.Literal(boolean=False, nullable=True)
14+
15+
def test_integer():
16+
assert extract_literal(literal(100, sttb.i16())) == stalg.Expression.Literal(i16=100, nullable=True)
17+
18+
def test_string():
19+
assert extract_literal(literal("Hello", sttb.string())) == stalg.Expression.Literal(string="Hello", nullable=True)
20+
21+
def test_binary():
22+
assert extract_literal(literal(b"Hello", sttb.binary())) == stalg.Expression.Literal(binary=b"Hello", nullable=True)
23+
24+
def test_date():
25+
assert extract_literal(literal(1000, sttb.date())) == stalg.Expression.Literal(date=1000, nullable=True)
26+
assert extract_literal(literal(date(1970, 1, 11), sttb.date())) == stalg.Expression.Literal(date=10, nullable=True)
27+
28+
def test_fixed_char():
29+
assert extract_literal(literal("Hello", sttb.fixed_char(length=5))) == stalg.Expression.Literal(fixed_char="Hello", nullable=True)
30+
31+
def test_var_char():
32+
assert extract_literal(literal("Hello", sttb.var_char(length=5))) \
33+
== stalg.Expression.Literal(var_char=stalg.Expression.Literal.VarChar(value="Hello", length=5), nullable=True)
34+
35+
def test_fixed_binary():
36+
assert extract_literal(literal(b"Hello", sttb.fixed_binary(length=5))) == stalg.Expression.Literal(fixed_binary=b"Hello", nullable=True)

0 commit comments

Comments
 (0)