Skip to content

Commit bb4de46

Browse files
authored
feat: bindings for simple extensions (#81)
Adds python bindings for json schema of simple extensions. Refactors extension registry to use new binding objects.
1 parent b0fa37f commit bb4de46

File tree

7 files changed

+730
-28
lines changed

7 files changed

+730
-28
lines changed

.devcontainer/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ RUN cd ~ && curl -LO https://github.com/protocolbuffers/protobuf/releases/downlo
99
unzip protoc-25.1-linux-x86_64.zip -d ~/.local && \
1010
rm protoc-25.1-linux-x86_64.zip
1111
RUN curl -sSL "https://github.com/bufbuild/buf/releases/download/v1.50.0/buf-$(uname -s)-$(uname -m)" -o ~/.local/bin/buf && chmod +x ~/.local/bin/buf
12+
RUN curl -LsSf https://astral.sh/uv/0.7.11/install.sh | sh
1213
USER root

Makefile

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,16 @@ antlr:
33
&& java -jar ${ANTLR_JAR} -o ../../../src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4 \
44
&& rm ../../../src/substrait/gen/antlr/*.tokens \
55
&& rm ../../../src/substrait/gen/antlr/*.interp
6+
7+
codegen-extensions:
8+
uv run --with datamodel-code-generator datamodel-codegen \
9+
--input-file-type jsonschema \
10+
--input third_party/substrait/text/simple_extensions_schema.yaml \
11+
--output src/substrait/gen/json/simple_extensions.py \
12+
--output-model-type dataclasses.dataclass
13+
14+
lint:
15+
uvx ruff@0.11.11 check
16+
17+
format:
18+
uvx ruff@0.11.11 format

src/substrait/extension_registry.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
import yaml
2+
import itertools
13
from substrait.gen.proto.type_pb2 import Type
24
from importlib.resources import files as importlib_files
3-
import itertools
45
from collections import defaultdict
5-
from collections.abc import Mapping
66
from pathlib import Path
7-
from typing import Any, Optional, Union
7+
from typing import Optional, Union
88
from .derivation_expression import evaluate, _evaluate, _parse
9-
10-
import yaml
119
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
10+
from substrait.gen.json import simple_extensions as se
11+
from substrait.simple_extension_utils import build_simple_extensions
12+
1213

1314
DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions"
1415

@@ -166,31 +167,35 @@ def covers(
166167

167168
class FunctionEntry:
168169
def __init__(
169-
self, uri: str, name: str, impl: Mapping[str, Any], anchor: int
170+
self, uri: str, name: str, impl: Union[se.Impl, se.Impl1, se.Impl2], anchor: int
170171
) -> None:
171172
self.name = name
173+
self.impl = impl
172174
self.normalized_inputs: list = []
173175
self.uri: str = uri
174176
self.anchor = anchor
175177
self.arguments = []
176-
self.rtn = impl["return"]
177-
self.nullability = impl.get("nullability", "MIRROR")
178-
self.variadic = impl.get("variadic", False)
179-
if input_args := impl.get("args", []):
180-
for val in input_args:
181-
if typ := val.get("value"):
182-
self.arguments.append(_parse(typ))
183-
self.normalized_inputs.append(normalize_substrait_type_names(typ))
184-
elif _ := val.get("name", None):
185-
self.arguments.append(val.get("options"))
178+
self.nullability = (
179+
impl.nullability if impl.nullability else se.NullabilityHandling.MIRROR
180+
)
181+
182+
if impl.args:
183+
for arg in impl.args:
184+
if isinstance(arg, se.ValueArg):
185+
self.arguments.append(_parse(arg.value))
186+
self.normalized_inputs.append(
187+
normalize_substrait_type_names(arg.value)
188+
)
189+
elif isinstance(arg, se.EnumerationArg):
190+
self.arguments.append(arg.options)
186191
self.normalized_inputs.append("req")
187192

188193
def __repr__(self) -> str:
189194
return f"{self.name}:{'_'.join(self.normalized_inputs)}"
190195

191196
def satisfies_signature(self, signature: tuple) -> Optional[str]:
192-
if self.variadic:
193-
min_args_allowed = self.variadic.get("min", 0)
197+
if self.impl.variadic:
198+
min_args_allowed = self.impl.variadic.min or 0
194199
if len(signature) < min_args_allowed:
195200
return None
196201
inputs = [self.arguments[0]] * len(signature)
@@ -209,13 +214,17 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:
209214
return None
210215
else:
211216
if not covers(
212-
y, x, parameters, check_nullability=self.nullability == "DISCRETE"
217+
y,
218+
x,
219+
parameters,
220+
check_nullability=self.nullability
221+
== se.NullabilityHandling.DISCRETE,
213222
):
214223
return None
215224

216-
output_type = evaluate(self.rtn, parameters)
225+
output_type = evaluate(self.impl.return_, parameters)
217226

218-
if self.nullability == "MIRROR":
227+
if self.nullability == se.NullabilityHandling.MIRROR:
219228
sig_contains_nullable = any(
220229
[
221230
p.__getattribute__(p.WhichOneof("kind")).nullability
@@ -265,19 +274,27 @@ def register_extension_yaml(
265274
def register_extension_dict(self, definitions: dict, uri: str) -> None:
266275
self._uri_mapping[uri] = next(self._uri_id_generator)
267276

268-
for named_functions in definitions.values():
269-
for function in named_functions:
270-
for impl in function.get("impls", []):
277+
simple_extensions = build_simple_extensions(definitions)
278+
279+
functions = (
280+
(simple_extensions.scalar_functions or [])
281+
+ (simple_extensions.aggregate_functions or [])
282+
+ (simple_extensions.window_functions or [])
283+
)
284+
285+
if functions:
286+
for function in functions:
287+
for impl in function.impls:
271288
func = FunctionEntry(
272-
uri, function["name"], impl, next(self._id_generator)
289+
uri, function.name, impl, next(self._id_generator)
273290
)
274291
if (
275292
func.uri in self._function_mapping
276-
and function["name"] in self._function_mapping[func.uri]
293+
and function.name in self._function_mapping[func.uri]
277294
):
278-
self._function_mapping[func.uri][function["name"]].append(func)
295+
self._function_mapping[func.uri][function.name].append(func)
279296
else:
280-
self._function_mapping[func.uri][function["name"]] = [func]
297+
self._function_mapping[func.uri][function.name] = [func]
281298

282299
# TODO add an optional return type check
283300
def lookup_function(

src/substrait/gen/json/simple_extensions.py

Lines changed: 218 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)