Skip to content
Merged
1 change: 1 addition & 0 deletions src/substrait/builders/type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Iterable

import substrait.gen.proto.type_pb2 as stt


Expand Down
122 changes: 115 additions & 7 deletions src/substrait/derivation_expression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional
from antlr4 import InputStream, CommonTokenStream

from antlr4 import CommonTokenStream, InputStream

from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.proto.type_pb2 import Type
from substrait.gen.proto.type_pb2 import NamedStruct, Type


def _evaluate(x, values: dict):
Expand Down Expand Up @@ -65,22 +67,128 @@ def _evaluate(x, values: dict):
return Type(fp64=Type.FP64(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext):
return Type(bool=Type.Boolean(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.StringContext):
return Type(string=Type.String(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext):
return Type(timestamp=Type.Timestamp(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.DateContext):
return Type(date=Type.Date(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext):
return Type(interval_year=Type.IntervalYear(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.UuidContext):
return Type(uuid=Type.UUID(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext):
return Type(binary=Type.Binary(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimeContext):
return Type(time=Type.Time(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext):
return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability))
else:
raise Exception(f"Unknown scalar type {type(scalar_type)}")
elif parametrized_type:
nullability = (
Type.NULLABILITY_NULLABLE
if parametrized_type.isnull
else Type.NULLABILITY_REQUIRED
)
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
precision = _evaluate(parametrized_type.precision, values)
scale = _evaluate(parametrized_type.scale, values)
nullability = (
Type.NULLABILITY_NULLABLE
if parametrized_type.isnull
else Type.NULLABILITY_REQUIRED
)
return Type(
decimal=Type.Decimal(
precision=precision, scale=scale, nullability=nullability
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext):
length = _evaluate(parametrized_type.length, values)
return Type(
varchar=Type.VarChar(
length=length,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext):
length = _evaluate(parametrized_type.length, values)
return Type(
fixed_char=Type.FixedChar(
length=length,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext):
length = _evaluate(parametrized_type.length, values)
return Type(
fixed_binary=Type.FixedBinary(
length=length,
nullability=nullability,
)
)
elif isinstance(
parametrized_type, SubstraitTypeParser.PrecisionTimestampContext
):
precision = _evaluate(parametrized_type.precision, values)
return Type(
precision_timestamp=Type.PrecisionTimestamp(
precision=precision,
nullability=nullability,
)
)
elif isinstance(
parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext
):
precision = _evaluate(parametrized_type.precision, values)
return Type(
precision_timestamp_tz=Type.PrecisionTimestampTZ(
precision=precision,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext):
return Type(
interval_year=Type.IntervalYear(
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.StructContext):
types = list(
map(lambda x: _evaluate(x, values), parametrized_type.expr())
)
return Type(
struct=Type.Struct(
types=types,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.ListContext):
list_type = _evaluate(parametrized_type.expr(), values)
return Type(
list=Type.List(
type=list_type,
nullability=nullability,
)
)

elif isinstance(parametrized_type, SubstraitTypeParser.MapContext):
return Type(
map=Type.Map(
key=_evaluate(parametrized_type.key, values),
value=_evaluate(parametrized_type.value, values),
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext):
names = list(map(lambda k: k.getText(), parametrized_type.Identifier()))
struct = Type.Struct(
types=list(
map(lambda k: _evaluate(k, values), parametrized_type.expr())
),
nullability=nullability,
)
return NamedStruct(
names=names,
struct=struct,
)

raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
elif any_type:
any_var = any_type.AnyVar()
Expand Down
186 changes: 151 additions & 35 deletions src/substrait/extension_registry.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import yaml
import itertools
import re
from substrait.gen.proto.type_pb2 import Type
from importlib.resources import files as importlib_files
from collections import defaultdict
from importlib.resources import files as importlib_files
from pathlib import Path
from typing import Optional, Union
from .derivation_expression import evaluate, _evaluate, _parse

import yaml

from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.json import simple_extensions as se
from substrait.gen.proto.type_pb2 import Type
from substrait.simple_extension_utils import build_simple_extensions
from .bimap import UriUrnBiDiMap

from .bimap import UriUrnBiDiMap
from .derivation_expression import _evaluate, _parse, evaluate

DEFAULT_URN_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions"

Expand Down Expand Up @@ -69,20 +71,20 @@ def normalize_substrait_type_names(typ: str) -> str:


def violates_integer_option(actual: int, option, parameters: dict):
option_numeric = None
if isinstance(option, SubstraitTypeParser.NumericLiteralContext):
return actual != int(str(option.Number()))
option_numeric = int(str(option.Number()))
elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext):
parameter_name = str(option.Identifier())
if parameter_name in parameters and parameters[parameter_name] != actual:
return True
else:

if parameter_name not in parameters:
parameters[parameter_name] = actual
option_numeric = parameters[parameter_name]
else:
raise Exception(
f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead"
)

return False
return actual != option_numeric


def types_equal(type1: Type, type2: Type, check_nullability=False):
Expand Down Expand Up @@ -112,6 +114,27 @@ def handle_parameter_cover(
return True


def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool:
if not check_nullability:
return True
# The ANTLR context stores a Token called ``isnull`` – it is
# present when the type is declared as nullable.
nullability = (
Type.Nullability.NULLABILITY_NULLABLE
if getattr(parameterized_type, "isnull", None) is not None
else Type.Nullability.NULLABILITY_REQUIRED
)
# if nullability == Type.Nullability.NULLABILITY_NULLABLE:
# return True # is still true even if the covered is required
# The protobuf message stores its own enum – we compare the two.
covered_nullability = getattr(
getattr(covered, kind), # e.g. covered.varchar
"nullability",
None,
)
return nullability == covered_nullability


def covers(
covered: Type,
covering: SubstraitTypeParser.TypeLiteralContext,
Expand All @@ -123,7 +146,6 @@ def covers(
return handle_parameter_cover(
covered, parameter_name, parameters, check_nullability
)

covering: SubstraitTypeParser.TypeDefContext = covering.typeDef()

any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType()
Expand All @@ -142,33 +164,127 @@ def covers(

parameterized_type = covering.parameterizedType()
if parameterized_type:
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
if covered.WhichOneof("kind") != "decimal":
return False
return _cover_parametrize_type(
covered, parameterized_type, parameters, check_nullability
)

nullability = (
Type.NULLABILITY_NULLABLE
if parameterized_type.isnull
else Type.NULLABILITY_REQUIRED
)

if (
check_nullability
and nullability
!= covered.__getattribute__(covered.WhichOneof("kind")).nullability
):
def check_violates_integer_option_parameters(
covered, parameterized_type, attributes, parameters
):
for attr in attributes:
if not hasattr(covered, attr) and not hasattr(parameterized_type, attr):
return False
covered_attr = getattr(covered, attr)
param_attr = getattr(parameterized_type, attr)
if violates_integer_option(covered_attr, param_attr, parameters):
return True
return False


def _cover_parametrize_type(
covered: Type,
parameterized_type: SubstraitTypeParser.ParameterizedTypeContext,
parameters: dict,
check_nullability=False,
):
kind = covered.WhichOneof("kind")

if not _check_nullability(check_nullability, parameterized_type, covered, kind):
return False

if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext):
if kind != "varchar":
return False
if hasattr(
parameterized_type, "length"
) and check_violates_integer_option_parameters(
covered.varchar, parameterized_type, ["length"], parameters
):
return False
elif isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext):
if kind != "fixed_char":
return False
if hasattr(
parameterized_type, "length"
) and check_violates_integer_option_parameters(
covered.fixed_char, parameterized_type, ["length"], parameters
):
return False

elif isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext):
if kind != "fixed_binary":
return False
if hasattr(
parameterized_type, "length"
) and check_violates_integer_option_parameters(
covered.fixed_binary, parameterized_type, ["length"], parameters
):
return False
elif isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
if kind != "decimal":
return False
if not _check_nullability(check_nullability, parameterized_type, covered, kind):
return False
return not check_violates_integer_option_parameters(
covered.decimal, parameterized_type, ["scale", "precision"], parameters
)
elif isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext):
if kind != "precision_timestamp":
return False
return not check_violates_integer_option_parameters(
covered.precision_timestamp, parameterized_type, ["precision"], parameters
)
elif isinstance(
parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext
):
if kind != "precision_timestamp_tz":
return False
return not check_violates_integer_option_parameters(
covered.precision_timestamp_tz,
parameterized_type,
["precision"],
parameters,
)

elif isinstance(parameterized_type, SubstraitTypeParser.ListContext):
if kind != "list":
return False
covered_element_type = covered.list.type
param_element_ctx = parameterized_type.expr()
return covers(
covered_element_type, param_element_ctx, parameters, check_nullability
)

elif isinstance(parameterized_type, SubstraitTypeParser.MapContext):
if kind != "map":
return False
covered_key_type = covered.map.key
covered_value_type = covered.map.value
param_key_ctx = parameterized_type.key
param_value_ctx = parameterized_type.value
return covers(
covered_key_type, param_key_ctx, parameters, check_nullability
) and covers(covered_value_type, param_value_ctx, parameters, check_nullability)

elif isinstance(parameterized_type, SubstraitTypeParser.StructContext):
if kind != "struct":
return False
covered_types = covered.struct.types
param_types = parameterized_type.expr() or []
if not isinstance(param_types, list):
param_types = [param_types]
if len(covered_types) != len(param_types):
return False
for covered_field, param_field_ctx in zip(covered_types, param_types):
if not covers(
covered_field, param_field_ctx, parameters, check_nullability
): # type: ignore
return False
else:
raise Exception(f"Unhandled type {type(parameterized_type)}")

return not (
violates_integer_option(
covered.decimal.scale, parameterized_type.scale, parameters
)
or violates_integer_option(
covered.decimal.precision, parameterized_type.precision, parameters
)
)
else:
raise Exception(f"Unhandled type {type(parameterized_type)}")
return True


class FunctionEntry:
Expand Down
Loading