Skip to content

Commit de67d3c

Browse files
authored
Chore!: error on unsupported dialect settings (#5119)
* Chore: warn on unsupported dialect settings * PR feedback
1 parent fddd24a commit de67d3c

File tree

4 files changed

+59
-18
lines changed

4 files changed

+59
-18
lines changed

sqlglot/dialects/dialect.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@
1212
from sqlglot.dialects import DIALECT_MODULE_NAMES
1313
from sqlglot.errors import ParseError
1414
from sqlglot.generator import Generator, unsupported_args
15-
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses, to_bool
15+
from sqlglot.helper import (
16+
AutoName,
17+
flatten,
18+
is_int,
19+
seq_get,
20+
subclasses,
21+
suggest_closest_match_and_fail,
22+
to_bool,
23+
)
1624
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
1725
from sqlglot.parser import Parser
1826
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
@@ -794,6 +802,12 @@ class Dialect(metaclass=_Dialect):
794802
# Specifies what types a given type can be coerced into
795803
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
796804

805+
# Determines the supported Dialect instance settings
806+
SUPPORTED_SETTINGS = {
807+
"normalization_strategy",
808+
"version",
809+
}
810+
797811
@classmethod
798812
def get_or_raise(cls, dialect: DialectType) -> Dialect:
799813
"""
@@ -843,16 +857,9 @@ def get_or_raise(cls, dialect: DialectType) -> Dialect:
843857

844858
result = cls.get(dialect_name.strip())
845859
if not result:
846-
from difflib import get_close_matches
847-
848-
close_matches = get_close_matches(dialect_name, list(DIALECT_MODULE_NAMES), n=1)
849-
850-
similar = seq_get(close_matches, 0) or ""
851-
if similar:
852-
similar = f" Did you mean {similar}?"
853-
854-
raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
860+
suggest_closest_match_and_fail("dialect", dialect_name, list(DIALECT_MODULE_NAMES))
855861

862+
assert result is not None
856863
return result(**kwargs)
857864

858865
raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
@@ -874,15 +881,19 @@ def format_time(
874881
return expression
875882

876883
def __init__(self, **kwargs) -> None:
877-
normalization_strategy = kwargs.pop("normalization_strategy", None)
884+
self.version = Version(kwargs.pop("version", None))
878885

886+
normalization_strategy = kwargs.pop("normalization_strategy", None)
879887
if normalization_strategy is None:
880888
self.normalization_strategy = self.NORMALIZATION_STRATEGY
881889
else:
882890
self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
883891

884892
self.settings = kwargs
885893

894+
for unsupported_setting in kwargs.keys() - self.SUPPORTED_SETTINGS:
895+
suggest_closest_match_and_fail("setting", unsupported_setting, self.SUPPORTED_SETTINGS)
896+
886897
def __eq__(self, other: t.Any) -> bool:
887898
# Does not currently take dialect state into account
888899
return type(self) == other
@@ -1026,10 +1037,6 @@ def parser(self, **opts) -> Parser:
10261037
def generator(self, **opts) -> Generator:
10271038
return self.generator_class(dialect=self, **opts)
10281039

1029-
@property
1030-
def version(self) -> Version:
1031-
return Version(self.settings.get("version", None))
1032-
10331040
def generate_values_aliases(self, expression: exp.Values) -> t.List[exp.Identifier]:
10341041
return [
10351042
exp.to_identifier(f"_col_{i}")

sqlglot/dialects/presto.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,11 @@ class Presto(Dialect):
281281
else self._set_type(e, exp.DataType.Type.DOUBLE),
282282
}
283283

284+
SUPPORTED_SETTINGS = {
285+
*Dialect.SUPPORTED_SETTINGS,
286+
"variant_extract_is_json_extract",
287+
}
288+
284289
class Tokenizer(tokens.Tokenizer):
285290
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
286291
UNICODE_STRINGS = [

sqlglot/helper.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections.abc import Collection, Set
1010
from contextlib import contextmanager
1111
from copy import copy
12+
from difflib import get_close_matches
1213
from enum import Enum
1314
from itertools import count
1415

@@ -45,6 +46,20 @@ def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
4546
return classmethod(self.fget).__get__(None, owner)() # type: ignore
4647

4748

49+
def suggest_closest_match_and_fail(
50+
kind: str,
51+
word: str,
52+
possibilities: t.Iterable[str],
53+
) -> None:
54+
close_matches = get_close_matches(word, possibilities, n=1)
55+
56+
similar = seq_get(close_matches, 0) or ""
57+
if similar:
58+
similar = f" Did you mean {similar}?"
59+
60+
raise ValueError(f"Unknown {kind} '{word}'.{similar}")
61+
62+
4863
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
4964
"""Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
5065
try:

tests/dialects/test_dialect.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
parse_one,
1212
)
1313
from sqlglot.dialects import BigQuery, Hive, Snowflake
14+
from sqlglot.dialects.dialect import Version
1415
from sqlglot.parser import logger as parser_logger
1516

1617

@@ -134,14 +135,25 @@ def test_get_or_raise(self):
134135
"oracle, normalization_strategy = lowercase, version = 19.5"
135136
)
136137
self.assertEqual(oracle_with_settings.normalization_strategy.value, "LOWERCASE")
137-
self.assertEqual(oracle_with_settings.settings, {"version": "19.5"})
138+
self.assertEqual(oracle_with_settings.version, Version("19.5"))
138139

139-
bool_settings = Dialect.get_or_raise("oracle, s1=TruE, s2=1, s3=FaLse, s4=0, s5=nonbool")
140+
class MyDialect(Dialect):
141+
SUPPORTED_SETTINGS = {"s1", "s2", "s3", "s4", "s5"}
142+
143+
bool_settings = Dialect.get_or_raise("mydialect, s1=TruE, s2=1, s3=FaLse, s4=0, s5=nonbool")
140144
self.assertEqual(
141145
bool_settings.settings,
142146
{"s1": True, "s2": True, "s3": False, "s4": False, "s5": "nonbool"},
143147
)
144148

149+
with self.assertRaises(ValueError) as cm:
150+
Dialect.get_or_raise("tsql,normalisation_strategy=case_sensitive")
151+
152+
self.assertEqual(
153+
"Unknown setting 'normalisation_strategy'. Did you mean normalization_strategy?",
154+
str(cm.exception),
155+
)
156+
145157
def test_compare_dialects(self):
146158
bigquery_class = Dialect["bigquery"]
147159
bigquery_object = BigQuery()
@@ -170,7 +182,9 @@ def test_compare_dialects(self):
170182

171183
def test_compare_dialect_versions(self):
172184
ddb_v1 = Dialect.get_or_raise("duckdb, version=1.0")
173-
ddb_v1_2 = Dialect.get_or_raise("duckdb, foo=bar, version=1.0")
185+
ddb_v1_2 = Dialect.get_or_raise(
186+
"duckdb, normalization_strategy=case_sensitive, version=1.0"
187+
)
174188
ddb_v2 = Dialect.get_or_raise("duckdb, version=2.2.4")
175189
ddb_latest = Dialect.get_or_raise("duckdb")
176190

0 commit comments

Comments
 (0)