Skip to content

Commit c17878a

Browse files
authored
Feat: add ability to create dialect plugins (#6627)
1 parent b75a3e3 commit c17878a

File tree

3 files changed

+143
-5
lines changed

3 files changed

+143
-5
lines changed

README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,3 +588,29 @@ x + interval '1' month
588588
**Official Dialects** are maintained by the core SQLGlot team with higher priority for bug fixes and feature additions.
589589

590590
**Community Dialects** are developed and maintained primarily through community contributions. These are fully functional but may receive lower priority for issue resolution compared to officially supported dialects. We welcome and encourage community contributions to improve these dialects.
591+
592+
### Creating a Dialect Plugin
593+
594+
If your database isn't supported, you can create a plugin package that registers a custom dialect through Python entry points. Put your dialect class in the package and register it under the `sqlglot.dialects` entry point group in `setup.py`:
595+
596+
```python
597+
from setuptools import setup
598+
599+
setup(
600+
name="mydb-sqlglot-dialect",
601+
entry_points={
602+
"sqlglot.dialects": [
603+
"mydb = my_package.dialect:MyDB",
604+
],
605+
},
606+
)
607+
```
608+
609+
Once the package is installed, the dialect is discovered automatically by its entry point name (`mydb` in the example above) and can be used like any built-in dialect:
610+
611+
```python
612+
from sqlglot import transpile
613+
transpile("SELECT * FROM t", read="mydb", write="postgres")
614+
```
615+
616+
See the [Custom Dialects](#custom-dialects) section for implementation details.

sqlglot/dialects/dialect.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from sqlglot.trie import new_trie
2828
from sqlglot.typing import EXPRESSION_METADATA
2929

30+
from importlib.metadata import entry_points
31+
3032
DATE_ADD_OR_DIFF = t.Union[
3133
exp.DateAdd,
3234
exp.DateDiff,
@@ -66,6 +68,8 @@
6668
"\\\\": "\\",
6769
}
6870

71+
PLUGIN_GROUP_NAME = "sqlglot.dialects"
72+
6973

7074
class Dialects(str, Enum):
7175
"""Dialects supported by SQLGlot."""
@@ -153,12 +157,54 @@ def _try_load(cls, key: str | Dialects) -> None:
153157
if isinstance(key, Dialects):
154158
key = key.value
155159

156-
# This import will lead to a new dialect being loaded, and hence, registered.
157-
# We check that the key is an actual sqlglot module to avoid blindly importing
158-
# files. Custom user dialects need to be imported at the top-level package, in
159-
# order for them to be registered as soon as possible.
160+
# 1. Try standard sqlglot modules first
160161
if key in DIALECT_MODULE_NAMES:
162+
module = importlib.import_module(f"sqlglot.dialects.{key}")
163+
# If module was already imported, the class may not be in _classes
164+
# Find and register the dialect class from the module
165+
if key not in cls._classes:
166+
for attr_name in dir(module):
167+
attr = getattr(module, attr_name, None)
168+
if (
169+
isinstance(attr, type)
170+
and issubclass(attr, Dialect)
171+
and attr.__name__.lower() == key
172+
):
173+
cls._classes[key] = attr
174+
break
175+
return
176+
177+
# 2. Try entry points (for plugins)
178+
try:
179+
all_eps = entry_points()
180+
# Python 3.10+ has select() method, older versions use dict-like access
181+
if hasattr(all_eps, "select"):
182+
eps = all_eps.select(group=PLUGIN_GROUP_NAME, name=key)
183+
else:
184+
# For older Python versions, entry_points() returns a dict-like object
185+
group_eps = all_eps.get(PLUGIN_GROUP_NAME, []) # type: ignore
186+
eps = [ep for ep in group_eps if ep.name == key] # type: ignore
187+
188+
for entry_point in eps:
189+
dialect_class = entry_point.load()
190+
# Verify it's a Dialect subclass
191+
# issubclass() returns False if not a subclass, TypeError only if not a class at all
192+
if isinstance(dialect_class, type) and issubclass(dialect_class, Dialect):
193+
# Register the dialect using the entry point name (key)
194+
# The metaclass may have registered it by class name, but we need it by entry point name
195+
if key not in cls._classes:
196+
cls._classes[key] = dialect_class
197+
return
198+
except ImportError:
199+
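# entry_point.load() could not import the plugin's module. Note that a module which imports
# fine but lacks the named class raises AttributeError, which is NOT caught here.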
200+
pass
201+
202+
# 3. Try direct import (for backward compatibility)
203+
# This allows namespace packages or explicit imports to work
204+
try:
161205
importlib.import_module(f"sqlglot.dialects.{key}")
206+
except ImportError:
207+
pass
162208

163209
@classmethod
164210
def __getitem__(cls, key: str) -> t.Type[Dialect]:
@@ -947,7 +993,9 @@ def get_or_raise(cls, dialect: DialectType) -> Dialect:
947993

948994
result = cls.get(dialect_name.strip())
949995
if not result:
950-
suggest_closest_match_and_fail("dialect", dialect_name, list(DIALECT_MODULE_NAMES))
996+
# Include both built-in dialects and any loaded dialects for better error messages
997+
all_dialects = set(DIALECT_MODULE_NAMES) | set(cls._classes.keys())
998+
suggest_closest_match_and_fail("dialect", dialect_name, all_dialects)
951999

9521000
assert result is not None
9531001
return result(**kwargs)

tests/test_dialect_entry_points.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import unittest
2+
from unittest.mock import Mock, patch
3+
4+
from sqlglot import Dialect
5+
from sqlglot.dialects.dialect import Dialect as DialectBase
6+
7+
8+
class FakeDialect(DialectBase):
    """Bare dialect subclass that stands in for a plugin-provided dialect in these tests."""
10+
11+
12+
class TestDialectEntryPoints(unittest.TestCase):
    """Check that dialects published through entry points are discovered by name."""

    def setUp(self):
        # `Dialect._classes` is the registry shared by the whole test process. Snapshot
        # it before emptying it so that tearDown can put back every dialect that was
        # registered before this test ran, instead of leaving the registry wiped.
        self._saved_classes = dict(Dialect._classes)
        Dialect._classes.clear()

    def tearDown(self):
        # Drop whatever the test registered, then restore the pre-test registry.
        Dialect._classes.clear()
        Dialect._classes.update(self._saved_classes)

    @staticmethod
    def _make_entry_point(name="fakedb"):
        """Return a mock entry point called `name` whose `load()` yields `FakeDialect`."""
        fake_entry_point = Mock()
        # `name` must be assigned after construction: Mock(name=...) names the mock itself.
        fake_entry_point.name = name
        fake_entry_point.load.return_value = FakeDialect
        return fake_entry_point

    def test_entry_point_plugin_discovery_modern_api(self):
        fake_entry_point = self._make_entry_point()

        # A plain Mock exposes `select`, so the code under test takes the select() branch.
        mock_selectable = Mock()
        mock_selectable.select.return_value = [fake_entry_point]

        mock_entry_points = Mock(return_value=mock_selectable)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("fakedb")

        self.assertIsNotNone(dialect)
        self.assertEqual(dialect, FakeDialect)
        fake_entry_point.load.assert_called_once()
        mock_selectable.select.assert_called_once_with(group="sqlglot.dialects", name="fakedb")

    def test_entry_point_plugin_discovery_legacy_api(self):
        fake_entry_point = self._make_entry_point()

        # spec=["get"] makes hasattr(..., "select") false, forcing the dict-like branch.
        mock_dict = Mock(spec=["get"])
        mock_dict.get.return_value = [fake_entry_point]

        mock_entry_points = Mock(return_value=mock_dict)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("fakedb")

        self.assertIsNotNone(dialect)
        self.assertEqual(dialect, FakeDialect)
        fake_entry_point.load.assert_called_once()
        mock_dict.get.assert_called_once_with("sqlglot.dialects", [])

    def test_entry_point_plugin_not_found(self):
        # No entry point matches the requested name, so nothing can be loaded.
        mock_selectable = Mock()
        mock_selectable.select.return_value = []

        mock_entry_points = Mock(return_value=mock_selectable)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("nonexistent")

        self.assertIsNone(dialect)

0 commit comments

Comments
 (0)