Skip to content

Commit 7be8c36

Browse files
committed
feat: graceful migration from URI -> URN
This PR introduces handling for migrating from the usage of URI to URN for extension references. As an intermediate step, both URI and URN are emitted from produced plans. Closes #95
1 parent 8fb11d5 commit 7be8c36

File tree

13 files changed

+1168
-57
lines changed

13 files changed

+1168
-57
lines changed

src/substrait/bimap.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
Bidirectional map for URI <-> URN conversion during the migration period.
3+
4+
This module provides a UriUrnBiDiMap class that maintains a bidirectional mapping
5+
between URIs and URNs, ensuring consistency and detecting conflicts.
6+
7+
NOTE: This file is temporary and can be removed once the URI -> URN migration
8+
is complete across all Substrait implementations. At that point, only URN-based
9+
extension references will be used.
10+
"""
11+
12+
from typing import Optional
13+
14+
15+
class UriUrnBiDiMap:
    """Bidirectional map for URI <-> URN mappings.

    Two internal dictionaries are kept in lockstep (one per direction) so that
    lookups are plain dict accesses either way. Each URI maps to exactly one
    URN and vice versa; attempts to break that one-to-one relation raise.
    """

    def __init__(self) -> None:
        # Both dicts are only ever written together in ``put``, so they always
        # describe the same set of pairs.
        self._uri_to_urn: dict[str, str] = {}
        self._urn_to_uri: dict[str, str] = {}

    def put(self, uri: str, urn: str) -> None:
        """Add a bidirectional URI <-> URN mapping.

        Re-adding an existing pair is a no-op. Nothing is modified when a
        conflict is detected.

        Args:
            uri: The extension URI (e.g., "https://github.com/.../functions_arithmetic.yaml")
            urn: The extension URN (e.g., "extension:io.substrait:functions_arithmetic")

        Raises:
            ValueError: If the URI or URN already exists with a different mapping
        """
        if uri in self._uri_to_urn:
            existing_urn = self._uri_to_urn[uri]
            if existing_urn != urn:
                raise ValueError(
                    f"URI '{uri}' is already mapped to URN '{existing_urn}', "
                    f"cannot remap to '{urn}'"
                )
            # Already have this exact mapping, nothing to do
            return

        # The URI is new at this point. Because the two dicts are always
        # updated as a pair, a URN that is already present must belong to a
        # different URI, so this is necessarily a conflict.
        if urn in self._urn_to_uri:
            existing_uri = self._urn_to_uri[urn]
            raise ValueError(
                f"URN '{urn}' is already mapped to URI '{existing_uri}', "
                f"cannot remap to '{uri}'"
            )

        # Record the pair in both directions
        self._uri_to_urn[uri] = urn
        self._urn_to_uri[urn] = uri

    def get_urn(self, uri: str) -> Optional[str]:
        """Convert a URI to its corresponding URN.

        Args:
            uri: The extension URI to look up

        Returns:
            The corresponding URN, or None if the URI is not in the map
        """
        return self._uri_to_urn.get(uri)

    def get_uri(self, urn: str) -> Optional[str]:
        """Convert a URN to its corresponding URI.

        Args:
            urn: The extension URN to look up

        Returns:
            The corresponding URI, or None if the URN is not in the map
        """
        return self._urn_to_uri.get(urn)

    def contains_uri(self, uri: str) -> bool:
        """Check if a URI exists in the map.

        Args:
            uri: The URI to check

        Returns:
            True if the URI is in the map, False otherwise
        """
        return uri in self._uri_to_urn

    def contains_urn(self, urn: str) -> bool:
        """Check if a URN exists in the map.

        Args:
            urn: The URN to check

        Returns:
            True if the URN is in the map, False otherwise
        """
        return urn in self._urn_to_uri

    def __len__(self) -> int:
        """Return the number of mappings in the bimap."""
        return len(self._uri_to_urn)

    def __repr__(self) -> str:
        """Return a string representation of the bimap."""
        return f"UriUrnBiDiMap({len(self)} mappings)"

src/substrait/builders/extended_expression.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from substrait.utils import (
99
type_num_names,
1010
merge_extension_urns,
11+
merge_extension_uris,
1112
merge_extension_declarations,
1213
)
1314
from substrait.type_inference import infer_extended_expression_schema
@@ -229,26 +230,44 @@ def resolve(
229230
if not func:
230231
raise Exception(f"Unknown function {function} for {signature}")
231232

233+
# Create URN extension
232234
func_extension_urns = [
233235
ste.SimpleExtensionURN(
234236
extension_urn_anchor=registry.lookup_urn(urn), urn=urn
235237
)
236238
]
237239

240+
# Create URI extension (convert URN to URI via bimap)
241+
uri = registry.urn_to_uri(urn)
242+
func_extension_uris = []
243+
if uri:
244+
uri_anchor = registry.lookup_uri_anchor(uri)
245+
if uri_anchor:
246+
func_extension_uris = [
247+
ste.SimpleExtensionURI(extension_uri_anchor=uri_anchor, uri=uri)
248+
]
249+
250+
# Create extension function declaration with both URI and URN references
238251
func_extensions = [
239252
ste.SimpleExtensionDeclaration(
240253
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
241254
extension_urn_reference=registry.lookup_urn(urn),
255+
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
242256
function_anchor=func[0].anchor,
243257
name=str(func[0]),
244258
)
245259
)
246260
]
247261

262+
# Merge extensions from all expressions
248263
extension_urns = merge_extension_urns(
249264
func_extension_urns, *[b.extension_urns for b in bound_expressions]
250265
)
251266

267+
extension_uris = merge_extension_uris(
268+
func_extension_uris, *[b.extension_uris for b in bound_expressions]
269+
)
270+
252271
extensions = merge_extension_declarations(
253272
func_extensions, *[b.extensions for b in bound_expressions]
254273
)
@@ -277,6 +296,7 @@ def resolve(
277296
],
278297
base_schema=base_schema,
279298
extension_urns=extension_urns,
299+
extension_uris=extension_uris,
280300
extensions=extensions,
281301
)
282302

@@ -309,26 +329,44 @@ def resolve(
309329
if not func:
310330
raise Exception(f"Unknown function {function} for {signature}")
311331

332+
# Create URN extension
312333
func_extension_urns = [
313334
ste.SimpleExtensionURN(
314335
extension_urn_anchor=registry.lookup_urn(urn), urn=urn
315336
)
316337
]
317338

339+
# Create URI extension (convert URN to URI via bimap)
340+
uri = registry.urn_to_uri(urn)
341+
func_extension_uris = []
342+
if uri:
343+
uri_anchor = registry.lookup_uri_anchor(uri)
344+
if uri_anchor:
345+
func_extension_uris = [
346+
ste.SimpleExtensionURI(extension_uri_anchor=uri_anchor, uri=uri)
347+
]
348+
349+
# Create extension function declaration with both URI and URN references
318350
func_extensions = [
319351
ste.SimpleExtensionDeclaration(
320352
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
321353
extension_urn_reference=registry.lookup_urn(urn),
354+
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
322355
function_anchor=func[0].anchor,
323356
name=str(func[0]),
324357
)
325358
)
326359
]
327360

361+
# Merge extensions from all expressions
328362
extension_urns = merge_extension_urns(
329363
func_extension_urns, *[b.extension_urns for b in bound_expressions]
330364
)
331365

366+
extension_uris = merge_extension_uris(
367+
func_extension_uris, *[b.extension_uris for b in bound_expressions]
368+
)
369+
332370
extensions = merge_extension_declarations(
333371
func_extensions, *[b.extensions for b in bound_expressions]
334372
)
@@ -353,6 +391,7 @@ def resolve(
353391
],
354392
base_schema=base_schema,
355393
extension_urns=extension_urns,
394+
extension_uris=extension_uris,
356395
extensions=extensions,
357396
)
358397

@@ -391,28 +430,48 @@ def resolve(
391430
if not func:
392431
raise Exception(f"Unknown function {function} for {signature}")
393432

433+
# Create URN extension
394434
func_extension_urns = [
395435
ste.SimpleExtensionURN(
396436
extension_urn_anchor=registry.lookup_urn(urn), urn=urn
397437
)
398438
]
399439

440+
# Create URI extension (convert URN to URI via bimap)
441+
uri = registry.urn_to_uri(urn)
442+
func_extension_uris = []
443+
if uri:
444+
uri_anchor = registry.lookup_uri_anchor(uri)
445+
if uri_anchor:
446+
func_extension_uris = [
447+
ste.SimpleExtensionURI(extension_uri_anchor=uri_anchor, uri=uri)
448+
]
449+
450+
# Create extension function declaration with both URI and URN references
400451
func_extensions = [
401452
ste.SimpleExtensionDeclaration(
402453
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
403454
extension_urn_reference=registry.lookup_urn(urn),
455+
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
404456
function_anchor=func[0].anchor,
405457
name=str(func[0]),
406458
)
407459
)
408460
]
409461

462+
# Merge extensions from all expressions
410463
extension_urns = merge_extension_urns(
411464
func_extension_urns,
412465
*[b.extension_urns for b in bound_expressions],
413466
*[b.extension_urns for b in bound_partitions],
414467
)
415468

469+
extension_uris = merge_extension_uris(
470+
func_extension_uris,
471+
*[b.extension_uris for b in bound_expressions],
472+
*[b.extension_uris for b in bound_partitions],
473+
)
474+
416475
extensions = merge_extension_declarations(
417476
func_extensions,
418477
*[b.extensions for b in bound_expressions],
@@ -446,6 +505,7 @@ def resolve(
446505
],
447506
base_schema=base_schema,
448507
extension_urns=extension_urns,
508+
extension_uris=extension_uris,
449509
extensions=extensions,
450510
)
451511

src/substrait/builders/plan.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,21 @@
1818
resolve_expression,
1919
)
2020
from substrait.type_inference import infer_plan_schema
21-
from substrait.utils import merge_extension_declarations, merge_extension_urns
21+
from substrait.utils import merge_extension_declarations, merge_extension_urns, merge_extension_uris
2222

2323
UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]
2424

2525
PlanOrUnbound = Union[stp.Plan, UnboundPlan]
2626

2727

2828
def _merge_extensions(*objs):
    """Combine the extension metadata carried by several plan/expression objects.

    Falsy entries in ``objs`` are skipped. Both URI and URN references are
    merged, since both are maintained for compatibility during the URI/URN
    migration period.

    Returns:
        A dict with the keys "extension_uris", "extension_urns" and
        "extensions", each holding the result of the matching merge helper.
    """
    present = [obj for obj in objs if obj]

    merged_uris = merge_extension_uris(*(obj.extension_uris for obj in present))
    merged_urns = merge_extension_urns(*(obj.extension_urns for obj in present))
    merged_declarations = merge_extension_declarations(
        *(obj.extensions for obj in present)
    )

    return {
        "extension_uris": merged_uris,
        "extension_urns": merged_urns,
        "extensions": merged_declarations,
    }

0 commit comments

Comments
 (0)