Skip to content

Commit b404a0f

Browse files
fix: added the handling for all the substrait type in cover (#139)
Signed-off-by: Niels Pardon <[email protected]> Co-authored-by: Niels Pardon <[email protected]>
1 parent 8aa2a78 commit b404a0f

File tree

9 files changed

+1282
-795
lines changed

9 files changed

+1282
-795
lines changed

src/substrait/extension_registry/signature_checker_helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,14 @@ def _handle_parameterized_type(
235235
covered.decimal, parameterized_type, ["scale", "precision"], parameters
236236
)
237237

238+
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimeContext):
239+
return kind == "precision_time" and check_integer_type_parameters(
240+
covered.precision_time,
241+
parameterized_type,
242+
["precision"],
243+
parameters,
244+
)
245+
238246
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext):
239247
return kind == "precision_timestamp" and check_integer_type_parameters(
240248
covered.precision_timestamp,
@@ -251,6 +259,14 @@ def _handle_parameterized_type(
251259
parameters,
252260
)
253261

262+
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionIntervalDayContext):
263+
return kind == "interval_day" and check_integer_type_parameters(
264+
covered.interval_day,
265+
parameterized_type,
266+
["precision"],
267+
parameters,
268+
)
269+
254270
if isinstance(parameterized_type, SubstraitTypeParser.ListContext):
255271
return kind == "list" and covers(
256272
covered.list.type,

tests/extension_registry/__init__.py

Whitespace-only changes.
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Tests for parsing of a registry yaml and basic registry operations (lookup, registration)."""
2+
3+
import pytest
4+
import yaml
5+
6+
from substrait.extension_registry import ExtensionRegistry
7+
8+
# Common test YAML content for testing basic functions
9+
CONTENT = """%YAML 1.2
10+
---
11+
urn: extension:test:functions
12+
scalar_functions:
13+
- name: "test_fn"
14+
description: ""
15+
impls:
16+
- args:
17+
- value: i8
18+
variadic:
19+
min: 2
20+
return: i8
21+
- name: "test_fn_variadic_any"
22+
description: ""
23+
impls:
24+
- args:
25+
- value: any1
26+
variadic:
27+
min: 2
28+
return: any1
29+
- name: "add"
30+
description: "Add two values."
31+
impls:
32+
- args:
33+
- name: x
34+
value: i8
35+
- name: y
36+
value: i8
37+
options:
38+
overflow:
39+
values: [ SILENT, SATURATE, ERROR ]
40+
return: i8
41+
- args:
42+
- name: x
43+
value: i8
44+
- name: y
45+
value: i8
46+
- name: z
47+
value: any
48+
options:
49+
overflow:
50+
values: [ SILENT, SATURATE, ERROR ]
51+
return: i16
52+
- args:
53+
- name: x
54+
value: any1
55+
- name: y
56+
value: any1
57+
- name: z
58+
value: any2
59+
options:
60+
overflow:
61+
values: [ SILENT, SATURATE, ERROR ]
62+
return: any2
63+
- name: "test_decimal"
64+
impls:
65+
- args:
66+
- name: x
67+
value: decimal<P1,S1>
68+
- name: y
69+
value: decimal<S1,S2>
70+
return: decimal<P1 + 1,S2 + 1>
71+
- name: "test_enum"
72+
impls:
73+
- args:
74+
- name: op
75+
options: [ INTACT, FLIP ]
76+
- name: x
77+
value: i8
78+
return: i8
79+
- name: "add_declared"
80+
description: "Add two values."
81+
impls:
82+
- args:
83+
- name: x
84+
value: i8
85+
- name: y
86+
value: i8
87+
nullability: DECLARED_OUTPUT
88+
return: i8?
89+
- name: "add_discrete"
90+
description: "Add two values."
91+
impls:
92+
- args:
93+
- name: x
94+
value: i8?
95+
- name: y
96+
value: i8
97+
nullability: DISCRETE
98+
return: i8?
99+
- name: "test_decimal_discrete"
100+
impls:
101+
- args:
102+
- name: x
103+
value: decimal?<P1,S1>
104+
- name: y
105+
value: decimal<S1,S2>
106+
nullability: DISCRETE
107+
return: decimal?<P1 + 1,S2 + 1>
108+
- name: "equal_test"
109+
impls:
110+
- args:
111+
- name: x
112+
value: any
113+
- name: y
114+
value: any
115+
nullability: DISCRETE
116+
return: any
117+
"""
118+
119+
120+
@pytest.fixture(scope="session")
121+
def registry():
122+
"""Create a registry with test functions loaded."""
123+
reg = ExtensionRegistry(load_default_extensions=True)
124+
reg.register_extension_dict(
125+
yaml.safe_load(CONTENT),
126+
uri="https://test.example.com/extension_test_functions.yaml",
127+
)
128+
return reg
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""Tests for function types (scalar, aggregate, window)."""
2+
3+
import textwrap
4+
5+
import pytest
6+
import yaml
7+
8+
from substrait.builders.type import i8
9+
from substrait.extension_registry import ExtensionRegistry
10+
11+
12+
@pytest.mark.parametrize(
13+
"test_case",
14+
[
15+
# Scalar functions
16+
pytest.param(
17+
{
18+
"yaml_content": textwrap.dedent("""\
19+
%YAML 1.2
20+
---
21+
urn: extension:test:scalar_funcs
22+
scalar_functions:
23+
- name: "add"
24+
description: "Add two numbers"
25+
impls:
26+
- args:
27+
- value: i8
28+
- value: i8
29+
return: i8
30+
"""),
31+
"urn": "extension:test:scalar_funcs",
32+
"func_name": "add",
33+
"signature": [i8(nullable=False), i8(nullable=False)],
34+
"expected_type": "scalar",
35+
},
36+
id="scalar-add",
37+
),
38+
pytest.param(
39+
{
40+
"yaml_content": textwrap.dedent("""\
41+
%YAML 1.2
42+
---
43+
urn: extension:test:scalar_funcs
44+
scalar_functions:
45+
- name: "test_fn"
46+
description: ""
47+
impls:
48+
- args:
49+
- value: i8
50+
variadic:
51+
min: 2
52+
return: i8
53+
"""),
54+
"urn": "extension:test:scalar_funcs",
55+
"func_name": "test_fn",
56+
"signature": [i8(nullable=False), i8(nullable=False)],
57+
"expected_type": "scalar",
58+
},
59+
id="scalar-test_fn",
60+
),
61+
# Aggregate functions
62+
pytest.param(
63+
{
64+
"yaml_content": textwrap.dedent("""\
65+
%YAML 1.2
66+
---
67+
urn: extension:test:agg_funcs
68+
aggregate_functions:
69+
- name: "count"
70+
description: "Count non-null values"
71+
impls:
72+
- args:
73+
- value: i8
74+
return: i64
75+
"""),
76+
"urn": "extension:test:agg_funcs",
77+
"func_name": "count",
78+
"signature": [i8(nullable=False)],
79+
"expected_type": "aggregate",
80+
},
81+
id="aggregate-count",
82+
),
83+
pytest.param(
84+
{
85+
"yaml_content": textwrap.dedent("""\
86+
%YAML 1.2
87+
---
88+
urn: extension:test:agg_funcs
89+
aggregate_functions:
90+
- name: "sum"
91+
description: "Sum values"
92+
impls:
93+
- args:
94+
- value: i8
95+
return: i64
96+
"""),
97+
"urn": "extension:test:agg_funcs",
98+
"func_name": "sum",
99+
"signature": [i8(nullable=False)],
100+
"expected_type": "aggregate",
101+
},
102+
id="aggregate-sum",
103+
),
104+
# Window functions
105+
pytest.param(
106+
{
107+
"yaml_content": textwrap.dedent("""\
108+
%YAML 1.2
109+
---
110+
urn: extension:test:window_funcs
111+
window_functions:
112+
- name: "row_number"
113+
description: "Assign row numbers"
114+
impls:
115+
- args: []
116+
return: i64
117+
"""),
118+
"urn": "extension:test:window_funcs",
119+
"func_name": "row_number",
120+
"signature": [],
121+
"expected_type": "window",
122+
},
123+
id="window-row_number",
124+
),
125+
pytest.param(
126+
{
127+
"yaml_content": textwrap.dedent("""\
128+
%YAML 1.2
129+
---
130+
urn: extension:test:window_funcs
131+
window_functions:
132+
- name: "rank"
133+
description: "Assign ranks"
134+
impls:
135+
- args: []
136+
return: i64
137+
"""),
138+
"urn": "extension:test:window_funcs",
139+
"func_name": "rank",
140+
"signature": [],
141+
"expected_type": "window",
142+
},
143+
id="window-rank",
144+
),
145+
],
146+
)
147+
def test_all_function_types_from_yaml(test_case):
148+
"""Test that all functions in YAML are registered with correct function_type.value."""
149+
test_registry = ExtensionRegistry(load_default_extensions=False)
150+
test_registry.register_extension_dict(
151+
yaml.safe_load(test_case["yaml_content"]),
152+
uri=f"https://test.example.com/{test_case['urn'].replace(':', '_')}.yaml",
153+
)
154+
155+
result = test_registry.lookup_function(
156+
urn=test_case["urn"],
157+
function_name=test_case["func_name"],
158+
signature=test_case["signature"],
159+
)
160+
assert result is not None, (
161+
f"Failed to lookup {test_case['func_name']} in {test_case['urn']}"
162+
)
163+
entry, _ = result
164+
assert hasattr(entry, "function_type"), (
165+
f"Entry for {test_case['func_name']} missing function_type attribute"
166+
)
167+
assert entry.function_type is not None, (
168+
f"function_type is None for {test_case['func_name']}"
169+
)
170+
assert isinstance(entry.function_type.value, str), (
171+
f"function_type.value is not a string for {test_case['func_name']}"
172+
)
173+
assert entry.function_type.value == test_case["expected_type"], (
174+
f"Expected function_type.value '{test_case['expected_type']}' "
175+
f"for {test_case['func_name']}, got '{entry.function_type.value}'"
176+
)

0 commit comments

Comments
 (0)