Skip to content

Commit 3945acc

Browse files
authored
Feat: allow tables to be preserved in replace_table (#4468)
* Feat: allow tables to be preserved in replace_table * PR feedback * PR feedback
1 parent 63d8f41 commit 3945acc

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

sqlglot/dialects/dialect.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlglot import exp
99
from sqlglot.errors import ParseError
1010
from sqlglot.generator import Generator, unsupported_args
11-
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
11+
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses, to_bool
1212
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
1313
from sqlglot.parser import Parser
1414
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
@@ -770,14 +770,7 @@ def get_or_raise(cls, dialect: DialectType) -> Dialect:
770770
elif len(pair) == 2:
771771
value = pair[1].strip()
772772

773-
# Coerce the value to boolean if it matches to the truthy/falsy values below
774-
value_lower = value.lower()
775-
if value_lower in ("true", "1"):
776-
value = True
777-
elif value_lower in ("false", "0"):
778-
value = False
779-
780-
kwargs[key] = value
773+
kwargs[key] = to_bool(value)
781774

782775
except ValueError:
783776
raise ValueError(

sqlglot/expressions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
ensure_list,
3232
seq_get,
3333
subclasses,
34+
to_bool,
3435
)
3536
from sqlglot.tokens import Token, TokenError
3637

@@ -312,7 +313,7 @@ def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool =
312313
for kv in "".join(meta).split(","):
313314
k, *v = kv.split("=")
314315
value = v[0].strip() if v else True
315-
self.meta[k.strip()] = value
316+
self.meta[k.strip()] = to_bool(value)
316317

317318
if not prepend:
318319
self.comments.append(comment)
@@ -8231,7 +8232,7 @@ def replace_tables(
82318232
mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()}
82328233

82338234
def _replace_tables(node: Expression) -> Expression:
8234-
if isinstance(node, Table):
8235+
if isinstance(node, Table) and node.meta.get("replace") is not False:
82358236
original = normalize_table_name(node, dialect=dialect)
82368237
new_name = mapping.get(original)
82378238

sqlglot/helper.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,20 @@ def first(it: t.Iterable[T]) -> T:
456456
return next(i for i in it)
457457

458458

459+
def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
460+
if isinstance(value, bool) or value is None:
461+
return value
462+
463+
# Coerce the value to boolean if it matches to the truthy/falsy values below
464+
value_lower = value.lower()
465+
if value_lower in ("true", "1"):
466+
return True
467+
if value_lower in ("false", "0"):
468+
return False
469+
470+
return value
471+
472+
459473
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
460474
"""
461475
Merges a sequence of ranges, represented as tuples (low, high) whose values

tests/test_expressions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,14 @@ def test_replace_tables(self):
258258
'SELECT * FROM "my-project"."example"."table" /* example.table */',
259259
)
260260

261+
self.assertEqual(
262+
exp.replace_tables(
263+
parse_one("select * from example.table /* sqlglot.meta replace=false */"),
264+
{"example.table": "a.b"},
265+
).sql(),
266+
"SELECT * FROM example.table /* sqlglot.meta replace=false */",
267+
)
268+
261269
def test_expand(self):
262270
self.assertEqual(
263271
exp.expand(
@@ -1168,7 +1176,7 @@ def test_is_type(self):
11681176

11691177
def test_set_meta(self):
11701178
query = parse_one("SELECT * FROM foo /* sqlglot.meta x = 1, y = a, z */")
1171-
self.assertEqual(query.find(exp.Table).meta, {"x": "1", "y": "a", "z": True})
1179+
self.assertEqual(query.find(exp.Table).meta, {"x": True, "y": "a", "z": True})
11721180
self.assertEqual(query.sql(), "SELECT * FROM foo /* sqlglot.meta x = 1, y = a, z */")
11731181

11741182
def test_assert_is(self):

0 commit comments

Comments
 (0)