|
27 | 27 | from sqlglot.trie import new_trie |
28 | 28 | from sqlglot.typing import EXPRESSION_METADATA |
29 | 29 |
|
| 30 | +from importlib.metadata import entry_points |
| 31 | + |
30 | 32 | DATE_ADD_OR_DIFF = t.Union[ |
31 | 33 | exp.DateAdd, |
32 | 34 | exp.DateDiff, |
|
66 | 68 | "\\\\": "\\", |
67 | 69 | } |
68 | 70 |
|
| 71 | +PLUGIN_GROUP_NAME = "sqlglot.dialects" |
| 72 | + |
69 | 73 |
|
70 | 74 | class Dialects(str, Enum): |
71 | 75 | """Dialects supported by SQLGLot.""" |
@@ -153,12 +157,54 @@ def _try_load(cls, key: str | Dialects) -> None: |
153 | 157 | if isinstance(key, Dialects): |
154 | 158 | key = key.value |
155 | 159 |
|
156 | | - # This import will lead to a new dialect being loaded, and hence, registered. |
157 | | - # We check that the key is an actual sqlglot module to avoid blindly importing |
158 | | - # files. Custom user dialects need to be imported at the top-level package, in |
159 | | - # order for them to be registered as soon as possible. |
| 160 | + # 1. Try standard sqlglot modules first |
160 | 161 | if key in DIALECT_MODULE_NAMES: |
| 162 | + module = importlib.import_module(f"sqlglot.dialects.{key}") |
| 163 | + # If module was already imported, the class may not be in _classes |
| 164 | + # Find and register the dialect class from the module |
| 165 | + if key not in cls._classes: |
| 166 | + for attr_name in dir(module): |
| 167 | + attr = getattr(module, attr_name, None) |
| 168 | + if ( |
| 169 | + isinstance(attr, type) |
| 170 | + and issubclass(attr, Dialect) |
| 171 | + and attr.__name__.lower() == key |
| 172 | + ): |
| 173 | + cls._classes[key] = attr |
| 174 | + break |
| 175 | + return |
| 176 | + |
| 177 | + # 2. Try entry points (for plugins) |
| 178 | + try: |
| 179 | + all_eps = entry_points() |
| 180 | + # Python 3.10+ has select() method, older versions use dict-like access |
| 181 | + if hasattr(all_eps, "select"): |
| 182 | + eps = all_eps.select(group=PLUGIN_GROUP_NAME, name=key) |
| 183 | + else: |
| 184 | + # For older Python versions, entry_points() returns a dict-like object |
| 185 | + group_eps = all_eps.get(PLUGIN_GROUP_NAME, []) # type: ignore |
| 186 | + eps = [ep for ep in group_eps if ep.name == key] # type: ignore |
| 187 | + |
| 188 | + for entry_point in eps: |
| 189 | + dialect_class = entry_point.load() |
| 190 | + # Verify it's a Dialect subclass |
| 191 | + # issubclass() returns False if not a subclass, TypeError only if not a class at all |
| 192 | + if isinstance(dialect_class, type) and issubclass(dialect_class, Dialect): |
| 193 | + # Register the dialect using the entry point name (key) |
| 194 | + # The metaclass may have registered it by class name, but we need it by entry point name |
| 195 | + if key not in cls._classes: |
| 196 | + cls._classes[key] = dialect_class |
| 197 | + return |
| 198 | + except ImportError: |
| 199 | + # entry_point.load() failed (bad plugin - module/class doesn't exist) |
| 200 | + pass |
| 201 | + |
| 202 | + # 3. Try direct import (for backward compatibility) |
| 203 | + # This allows namespace packages or explicit imports to work |
| 204 | + try: |
161 | 205 | importlib.import_module(f"sqlglot.dialects.{key}") |
| 206 | + except ImportError: |
| 207 | + pass |
162 | 208 |
|
163 | 209 | @classmethod |
164 | 210 | def __getitem__(cls, key: str) -> t.Type[Dialect]: |
@@ -947,7 +993,9 @@ def get_or_raise(cls, dialect: DialectType) -> Dialect: |
947 | 993 |
|
948 | 994 | result = cls.get(dialect_name.strip()) |
949 | 995 | if not result: |
950 | | - suggest_closest_match_and_fail("dialect", dialect_name, list(DIALECT_MODULE_NAMES)) |
| 996 | + # Include both built-in dialects and any loaded dialects for better error messages |
| 997 | + all_dialects = set(DIALECT_MODULE_NAMES) | set(cls._classes.keys()) |
| 998 | + suggest_closest_match_and_fail("dialect", dialect_name, all_dialects) |
951 | 999 |
|
952 | 1000 | assert result is not None |
953 | 1001 | return result(**kwargs) |
|
0 commit comments