Skip to content
Merged
10 changes: 10 additions & 0 deletions src/substrait/builders/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,16 @@ def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type:
)


def timestamp(nullable=True) -> stt.Type:
    """Build a ``timestamp`` type message.

    Args:
        nullable: When truthy the type is marked ``NULLABILITY_NULLABLE``,
            otherwise ``NULLABILITY_REQUIRED``.

    Returns:
        An ``stt.Type`` whose ``timestamp`` field is populated.
    """
    if nullable:
        nullability = stt.Type.NULLABILITY_NULLABLE
    else:
        nullability = stt.Type.NULLABILITY_REQUIRED
    return stt.Type(timestamp=stt.Type.Timestamp(nullability=nullability))


def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type:
return stt.Type(
struct=stt.Type.Struct(
Expand Down
114 changes: 108 additions & 6 deletions src/substrait/derivation_expression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional
from antlr4 import InputStream, CommonTokenStream

from antlr4 import CommonTokenStream, InputStream

from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.proto.type_pb2 import Type
Expand Down Expand Up @@ -65,22 +67,122 @@ def _evaluate(x, values: dict):
return Type(fp64=Type.FP64(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext):
return Type(bool=Type.Boolean(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.StringContext):
return Type(string=Type.String(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext):
return Type(timestamp=Type.Timestamp(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.DateContext):
return Type(date=Type.Date(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext):
return Type(interval_year=Type.IntervalYear(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.UuidContext):
return Type(uuid=Type.UUID(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext):
return Type(binary=Type.Binary(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimeContext):
return Type(time=Type.Time(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext):
return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability))
else:
raise Exception(f"Unknown scalar type {type(scalar_type)}")
elif parametrized_type:
nullability = (
Type.NULLABILITY_NULLABLE
if parametrized_type.isnull
else Type.NULLABILITY_REQUIRED
)
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
precision = _evaluate(parametrized_type.precision, values)
scale = _evaluate(parametrized_type.scale, values)
nullability = (
Type.NULLABILITY_NULLABLE
if parametrized_type.isnull
else Type.NULLABILITY_REQUIRED
)
return Type(
decimal=Type.Decimal(
precision=precision, scale=scale, nullability=nullability
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext):
length = _evaluate(parametrized_type.length, values)
return Type(
varchar=Type.VarChar(
length=length,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext):
length = _evaluate(parametrized_type.length, values)
return Type(
fixed_char=Type.FixedChar(
length=length,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext):
length = _evaluate(parametrized_type.length, values)
return Type(
fixed_binary=Type.FixedBinary(
length=length,
nullability=nullability,
)
)
elif isinstance(
parametrized_type, SubstraitTypeParser.PrecisionTimestampContext
):
precision = _evaluate(parametrized_type.precision, values)
return Type(
precision_timestamp=Type.PrecisionTimestamp(
precision=precision,
nullability=nullability,
)
)
elif isinstance(
parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext
):
precision = _evaluate(parametrized_type.precision, values)
return Type(
precision_timestamp_tz=Type.PrecisionTimestampTZ(
precision=precision,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext):
return Type(
interval_year=Type.IntervalYear(
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.StructContext):
types = list(
map(lambda x: _evaluate(x, values), parametrized_type.expr())
)
return Type(
struct=Type.Struct(
types=types,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.ListContext):
list_type = _evaluate(parametrized_type.expr(), values)
return Type(
list=Type.List(
type=list_type,
nullability=nullability,
)
)

elif isinstance(parametrized_type, SubstraitTypeParser.MapContext):
return Type(
map=Type.Map(
key=_evaluate(parametrized_type.key, values),
value=_evaluate(parametrized_type.value, values),
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext):
# NSTRUCT is not supported yet: the current parser rejects the documented
# example `evaluate("NSTRUCT<longitude: i32, latitude: i32>")` from
# https://substrait.io/types/type_classes/ with the error
# "line 1:17 extraneous input ':'", so the parser may need updating first.
raise NotImplementedError("Named structure type not implemented yet")
# elif isinstance(parametrized_type, SubstraitTypeParser.UserDefinedContext):

raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
elif any_type:
any_var = any_type.AnyVar()
Expand Down
141 changes: 116 additions & 25 deletions src/substrait/extension_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,24 @@ def normalize_substrait_type_names(typ: str) -> str:
raise Exception(f"Unrecognized substrait type {typ}")


def violates_integer_option(actual: int, option, parameters: dict, subset=False):
    """Report whether ``actual`` conflicts with an integer type option.

    Args:
        actual: The concrete integer value being checked.
        option: Parser context describing the constraint. A
            ``NumericLiteralContext`` supplies a literal number; a
            ``NumericParameterNameContext`` names an entry in ``parameters``.
        parameters: Mapping of parameter names to values. If the named
            parameter is not present yet it is bound to ``actual`` here
            (the dict is mutated in place).
        subset: When truthy, a violation is ``actual < option`` instead of
            ``actual != option``.

    Returns:
        ``True`` if ``actual`` violates the option, ``False`` otherwise.

    Raises:
        Exception: If ``option`` is neither of the two supported contexts.
    """
    if isinstance(option, SubstraitTypeParser.NumericLiteralContext):
        option_numeric = int(str(option.Number()))
    elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext):
        parameter_name = str(option.Identifier())
        # First sighting of a parameter binds it to the actual value, so a
        # later use of the same name is compared against this binding.
        if parameter_name not in parameters:
            parameters[parameter_name] = actual
        option_numeric = parameters[parameter_name]
    else:
        raise Exception(
            f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead"
        )

    if subset:
        return actual < option_numeric
    return actual != option_numeric


def types_equal(type1: Type, type2: Type, check_nullability=False):
Expand Down Expand Up @@ -112,6 +115,27 @@ def handle_parameter_cover(
return True


def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool:
if not check_nullability:
return True
# The ANTLR context stores a Token called ``isnull`` – it is
# present when the type is declared as nullable.
nullability = (
Type.Nullability.NULLABILITY_NULLABLE
if getattr(parameterized_type, "isnull", None) is not None
else Type.Nullability.NULLABILITY_REQUIRED
)
# if nullability == Type.Nullability.NULLABILITY_NULLABLE:
# return True # is still true even if the covered is required
# The protobuf message stores its own enum – we compare the two.
covered_nullability = getattr(
getattr(covered, kind), # e.g. covered.varchar
"nullability",
None,
)
return nullability == covered_nullability


def covers(
covered: Type,
covering: SubstraitTypeParser.TypeLiteralContext,
Expand All @@ -123,7 +147,6 @@ def covers(
return handle_parameter_cover(
covered, parameter_name, parameters, check_nullability
)

covering: SubstraitTypeParser.TypeDefContext = covering.typeDef()

any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType()
Expand All @@ -142,31 +165,99 @@ def covers(

parameterized_type = covering.parameterizedType()
if parameterized_type:
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
if covered.WhichOneof("kind") != "decimal":
kind = covered.WhichOneof("kind")
if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext):
if kind != "varchar":
return False
if hasattr(parameterized_type, "length") and violates_integer_option(
covered.varchar.length, parameterized_type.length, parameters
):
return False

nullability = (
Type.NULLABILITY_NULLABLE
if parameterized_type.isnull
else Type.NULLABILITY_REQUIRED
return _check_nullability(
check_nullability, parameterized_type, covered, kind
)

if (
check_nullability
and nullability
!= covered.__getattribute__(covered.WhichOneof("kind")).nullability
if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext):
if kind != "fixed_char":
return False
if hasattr(parameterized_type, "length") and violates_integer_option(
covered.fixed_char.length, parameterized_type.length, parameters
):
return False
return _check_nullability(
check_nullability, parameterized_type, covered, kind
)

if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext):
if kind != "fixed_binary":
return False
if hasattr(parameterized_type, "length") and violates_integer_option(
covered.fixed_binary.length, parameterized_type.length, parameters
):
return False
# return True
return _check_nullability(
check_nullability, parameterized_type, covered, kind
)
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
if kind != "decimal":
return False
if not _check_nullability(
check_nullability, parameterized_type, covered, kind
):
return False
# precision / scale are both optional – a missing value means “no limit”.
covered_scale = getattr(covered.decimal, "scale", 0)
param_scale = getattr(parameterized_type, "scale", 0)
covered_prec = getattr(covered.decimal, "precision", 0)
param_prec = getattr(parameterized_type, "precision", 0)
return not (
violates_integer_option(
covered.decimal.scale, parameterized_type.scale, parameters
)
or violates_integer_option(
covered.decimal.precision, parameterized_type.precision, parameters
)
violates_integer_option(covered_scale, param_scale, parameters)
or violates_integer_option(covered_prec, param_prec, parameters)
)
if isinstance(
parameterized_type, SubstraitTypeParser.PrecisionTimestampContext
):
if kind != "precision_timestamp":
return False
if not _check_nullability(
check_nullability, parameterized_type, covered, kind
):
return False
# return True
covered_prec = getattr(covered.precision_timestamp, "precision", 0)
param_prec = getattr(parameterized_type, "precision", 0)
return not violates_integer_option(covered_prec, param_prec, parameters)

if isinstance(
parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext
):
if kind != "precision_timestamp_tz":
return False
if not _check_nullability(
check_nullability, parameterized_type, covered, kind
):
return False
# return True
covered_prec = getattr(covered.precision_timestamp_tz, "precision", 0)
param_prec = getattr(parameterized_type, "precision", 0)
return not violates_integer_option(covered_prec, param_prec, parameters)

kind_mapping = {
SubstraitTypeParser.ListContext: "list",
SubstraitTypeParser.MapContext: "map",
SubstraitTypeParser.StructContext: "struct",
SubstraitTypeParser.UserDefinedContext: "user_defined",
SubstraitTypeParser.PrecisionIntervalDayContext: "interval_day",
}

for ctx_cls, expected_kind in kind_mapping.items():
if isinstance(parameterized_type, ctx_cls):
if kind != expected_kind:
return False
return _check_nullability(
check_nullability, parameterized_type, covered, kind
)
else:
raise Exception(f"Unhandled type {type(parameterized_type)}")

Expand Down
Loading