Skip to content

Commit df5ecdb

Browse files
authored
Feat: Include token references in the meta of identifier expressions (#5022)
1 parent c1c892c commit df5ecdb

File tree

5 files changed

+162
-12
lines changed

5 files changed

+162
-12
lines changed

sqlglot/dialects/bigquery.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -645,14 +645,16 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
645645

646646
table_name += self._find_sql(start, self._prev)
647647

648-
this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
648+
this = exp.Identifier(
649+
this=table_name, quoted=this.args.get("quoted")
650+
).update_positions(this)
649651
elif isinstance(this, exp.Literal):
650652
table_name = this.name
651653

652654
if self._is_connected() and self._parse_var(any_token=True):
653655
table_name += self._prev.text
654656

655-
this = exp.Identifier(this=table_name, quoted=True)
657+
this = exp.Identifier(this=table_name, quoted=True).update_positions(this)
656658

657659
return this
658660

@@ -666,15 +668,23 @@ def _parse_table_parts(
666668
# proj-1.db.tbl -- `1.` is tokenized as a float so we need to unravel it here
667669
if not table.catalog:
668670
if table.db:
671+
previous_db = table.args["db"]
669672
parts = table.db.split(".")
670673
if len(parts) == 2 and not table.args["db"].quoted:
671-
table.set("catalog", exp.Identifier(this=parts[0]))
672-
table.set("db", exp.Identifier(this=parts[1]))
674+
table.set(
675+
"catalog", exp.Identifier(this=parts[0]).update_positions(previous_db)
676+
)
677+
table.set("db", exp.Identifier(this=parts[1]).update_positions(previous_db))
673678
else:
679+
previous_this = table.this
674680
parts = table.name.split(".")
675681
if len(parts) == 2 and not table.this.quoted:
676-
table.set("db", exp.Identifier(this=parts[0]))
677-
table.set("this", exp.Identifier(this=parts[1]))
682+
table.set(
683+
"db", exp.Identifier(this=parts[0]).update_positions(previous_this)
684+
)
685+
table.set(
686+
"this", exp.Identifier(this=parts[1]).update_positions(previous_this)
687+
)
678688

679689
if isinstance(table.this, exp.Identifier) and any("." in p.name for p in table.parts):
680690
alias = table.this
@@ -683,6 +693,10 @@ def _parse_table_parts(
683693
for p in split_num_words(".".join(p.name for p in table.parts), ".", 3)
684694
)
685695

696+
for part in (catalog, db, this):
697+
if part:
698+
part.update_positions(table.this)
699+
686700
if rest and this:
687701
this = exp.Dot.build([this, *rest]) # type: ignore
688702

@@ -717,7 +731,13 @@ def _parse_table_parts(
717731
)
718732

719733
info_schema_view = f"{table_parts[-2].name}.{table_parts[-1].name}"
720-
table.set("this", exp.Identifier(this=info_schema_view, quoted=True))
734+
new_this = exp.Identifier(this=info_schema_view, quoted=True).update_positions(
735+
line=table_parts[-2].meta.get("line"),
736+
col=table_parts[-1].meta.get("col"),
737+
start=table_parts[-2].meta.get("start"),
738+
end=table_parts[-1].meta.get("end"),
739+
)
740+
table.set("this", new_this)
721741
table.set("db", seq_get(table_parts, -3))
722742
table.set("catalog", seq_get(table_parts, -4))
723743

sqlglot/expressions.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __new__(cls, clsname, bases, attrs):
6464
SQLGLOT_ANONYMOUS = "sqlglot.anonymous"
6565
TABLE_PARTS = ("this", "db", "catalog")
6666
COLUMN_PARTS = ("this", "table", "db", "catalog")
67+
POSITION_META_KEYS = ("line", "col", "start", "end")
6768

6869

6970
class Expression(metaclass=_Expression):
@@ -846,6 +847,32 @@ def not_(self, copy: bool = True):
846847
"""
847848
return not_(self, copy=copy)
848849

850+
def update_positions(
    self: E, other: t.Optional[Token | Expression] = None, **kwargs: t.Any
) -> E:
    """
    Update this expression with positions from a token or other expression.

    Args:
        other: a token or expression to update this expression with.
        kwargs: explicit position values. Only the keys listed in `POSITION_META_KEYS`
            are used; they are applied last, so they override values taken from `other`.

    Returns:
        The updated expression.
    """
    if isinstance(other, Expression):
        # Copy only the position entries; any other meta on `other` is left alone.
        self.meta.update({k: v for k, v in other.meta.items() if k in POSITION_META_KEYS})
    elif other is not None:
        # Anything that is not an expression is read as a token, which exposes the
        # position values as attributes named after the meta keys.
        self.meta.update({k: getattr(other, k) for k in POSITION_META_KEYS})
    self.meta.update({k: v for k, v in kwargs.items() if k in POSITION_META_KEYS})
    return self
875+
849876
def as_(
850877
self,
851878
alias: str | Identifier,

sqlglot/parser.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5618,6 +5618,7 @@ def _parse_function_call(
56185618
return None
56195619

56205620
comments = self._curr.comments
5621+
token = self._curr
56215622
token_type = self._curr.token_type
56225623
this = self._curr.text
56235624
upper = this.upper()
@@ -5692,7 +5693,7 @@ def _parse_function_call(
56925693
this = func
56935694
else:
56945695
if token_type == TokenType.IDENTIFIER:
5695-
this = exp.Identifier(this=this, quoted=True)
5696+
this = exp.Identifier(this=this, quoted=True).update_positions(token)
56965697
this = self.expression(exp.Anonymous, this=this, expressions=args)
56975698

56985699
if isinstance(this, exp.Expression):
@@ -5751,7 +5752,7 @@ def _parse_introducer(self, token: Token) -> exp.Introducer | exp.Identifier:
57515752
if literal:
57525753
return self.expression(exp.Introducer, this=token.text, expression=literal)
57535754

5754-
return self.expression(exp.Identifier, this=token.text)
5755+
return self._identifier_expression(token)
57555756

57565757
def _parse_session_parameter(self) -> exp.SessionParameter:
57575758
kind = None
@@ -6981,7 +6982,7 @@ def _parse_id_var(
69816982
(any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS)
69826983
):
69836984
quoted = self._prev.token_type == TokenType.STRING
6984-
expression = self.expression(exp.Identifier, this=self._prev.text, quoted=quoted)
6985+
expression = self._identifier_expression(quoted=quoted)
69856986

69866987
return expression
69876988

@@ -6991,7 +6992,10 @@ def _parse_string(self) -> t.Optional[exp.Expression]:
69916992
return self._parse_placeholder()
69926993

69936994
def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]:
    """
    Try to match a STRING token and convert its text into a quoted identifier.

    When an identifier is produced, the positions of `self._prev` are recorded on it.

    Returns:
        The result of `exp.to_identifier`, with positions attached if it is truthy.
    """
    text = self._match(TokenType.STRING) and self._prev.text
    identifier = exp.to_identifier(text, quoted=True)
    if not identifier:
        return identifier

    identifier.update_positions(self._prev)
    return identifier
69956999

69967000
def _parse_number(self) -> t.Optional[exp.Expression]:
69977001
if self._match_set(self.NUMERIC_PARSERS):
@@ -7000,7 +7004,7 @@ def _parse_number(self) -> t.Optional[exp.Expression]:
70007004

70017005
def _parse_identifier(self) -> t.Optional[exp.Expression]:
    """
    Parse a quoted identifier, or defer to placeholder parsing.

    Returns:
        A quoted identifier expression built from the matched IDENTIFIER token,
        otherwise whatever `self._parse_placeholder()` returns.
    """
    if not self._match(TokenType.IDENTIFIER):
        return self._parse_placeholder()

    return self._identifier_expression(quoted=True)
70057009

70067010
def _parse_var(
@@ -8246,3 +8250,11 @@ def _parse_max_min_by(self, expr_type: t.Type[exp.AggFunc]) -> exp.AggFunc:
82468250
return self.expression(
82478251
expr_type, this=seq_get(args, 0), expression=seq_get(args, 1), count=seq_get(args, 2)
82488252
)
8253+
8254+
def _identifier_expression(
    self, token: t.Optional[Token] = None, **kwargs: t.Any
) -> exp.Identifier:
    """
    Build an identifier expression from a token and record the token's positions on it.

    Args:
        token: the token supplying the identifier text and positions; `self._prev`
            is used when it is omitted.
        kwargs: extra arguments passed to `self.expression` alongside the identifier text.

    Returns:
        The new identifier expression.
    """
    source = token or self._prev
    identifier = self.expression(exp.Identifier, this=source.text, **kwargs)
    identifier.update_positions(source)
    return identifier

tests/dialects/test_bigquery.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2432,3 +2432,56 @@ def test_with_offset(self):
24322432
f"SELECT * FROM t1, UNNEST([1, 2]) AS hit WITH OFFSET {join_ops} JOIN foo",
24332433
f"SELECT * FROM t1, UNNEST([1, 2]) AS hit WITH OFFSET AS offset {join_ops} JOIN foo",
24342434
)
2435+
2436+
def test_identifier_meta(self):
    """Identifiers parsed with the BigQuery dialect carry their source positions in meta."""
    ast = parse_one(
        "SELECT a, b FROM test_schema.test_table_a UNION ALL SELECT c, d FROM test_catalog.test_schema.test_table_b",
        dialect="bigquery",
    )
    # Every identifier must expose exactly the four position keys.
    for identifier in ast.find_all(exp.Identifier):
        self.assertEqual(set(identifier.meta), {"line", "col", "start", "end"})

    self.assertEqual(
        ast.this.args["from"].this.args["this"].meta,
        {"line": 1, "col": 41, "start": 29, "end": 40},
    )
    self.assertEqual(
        ast.this.args["from"].this.args["db"].meta,
        {"line": 1, "col": 28, "start": 17, "end": 27},
    )
    self.assertEqual(
        ast.expression.args["from"].this.args["this"].meta,
        {"line": 1, "col": 106, "start": 94, "end": 105},
    )
    self.assertEqual(
        ast.expression.args["from"].this.args["db"].meta,
        {"line": 1, "col": 93, "start": 82, "end": 92},
    )
    self.assertEqual(
        ast.expression.args["from"].this.args["catalog"].meta,
        {"line": 1, "col": 81, "start": 69, "end": 80},
    )

    # The merged INFORMATION_SCHEMA identifier must span both of its source parts.
    information_schema_sql = "SELECT a, b FROM region.INFORMATION_SCHEMA.COLUMNS"
    ast = parse_one(information_schema_sql, dialect="bigquery")
    meta = ast.args["from"].this.this.meta
    self.assertEqual(meta, {"line": 1, "col": 50, "start": 24, "end": 49})
    # assertEqual (rather than a bare assert) reports both values on failure.
    self.assertEqual(
        information_schema_sql[meta["start"] : meta["end"] + 1], "INFORMATION_SCHEMA.COLUMNS"
    )
2472+
2473+
def test_quoted_identifier_meta(self):
    """Positions of backtick-quoted BigQuery identifiers cover the quotes as well."""
    sql = "SELECT `a` FROM `test_schema`.`test_table_a`"
    ast = parse_one(sql, dialect="bigquery")
    db_meta = ast.args["from"].this.args["db"].meta
    self.assertEqual(sql[db_meta["start"] : db_meta["end"] + 1], "`test_schema`")
    table_meta = ast.args["from"].this.this.meta
    self.assertEqual(sql[table_meta["start"] : table_meta["end"] + 1], "`test_table_a`")

    information_schema_sql = "SELECT a, b FROM `region.INFORMATION_SCHEMA.COLUMNS`"
    ast = parse_one(information_schema_sql, dialect="bigquery")
    table_meta = ast.args["from"].this.this.meta
    # assertEqual (rather than a bare assert) reports both values on failure.
    self.assertEqual(
        information_schema_sql[table_meta["start"] : table_meta["end"] + 1],
        "`region.INFORMATION_SCHEMA.COLUMNS`",
    )

tests/test_parser.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,3 +956,41 @@ def test_udf_meta(self):
956956
# Incomplete or incorrect anonymous meta comments are not registered
957957
ast = parse_one("YEAR(a) /* sqlglot.anon */")
958958
self.assertIsInstance(ast, exp.Year)
959+
960+
def test_identifier_meta(self):
    """Identifiers parsed with the default dialect carry their source positions in meta."""
    ast = parse_one(
        "SELECT a, b FROM test_schema.test_table_a UNION ALL SELECT c, d FROM test_catalog.test_schema.test_table_b"
    )

    position_keys = {"line", "col", "start", "end"}
    for identifier in ast.find_all(exp.Identifier):
        self.assertEqual(set(identifier.meta), position_keys)

    # (table part, expected meta) pairs, checked in the order listed.
    expectations = (
        (
            ast.this.args["from"].this.args["this"],
            {"line": 1, "col": 41, "start": 29, "end": 40},
        ),
        (
            ast.this.args["from"].this.args["db"],
            {"line": 1, "col": 28, "start": 17, "end": 27},
        ),
        (
            ast.expression.args["from"].this.args["this"],
            {"line": 1, "col": 106, "start": 94, "end": 105},
        ),
        (
            ast.expression.args["from"].this.args["db"],
            {"line": 1, "col": 93, "start": 82, "end": 92},
        ),
        (
            ast.expression.args["from"].this.args["catalog"],
            {"line": 1, "col": 81, "start": 69, "end": 80},
        ),
    )
    for part, expected_meta in expectations:
        self.assertEqual(part.meta, expected_meta)
987+
988+
def test_quoted_identifier_meta(self):
    """Positions of double-quoted identifiers cover the surrounding quotes as well."""
    sql = 'SELECT "a" FROM "test_schema"."test_table_a"'
    ast = parse_one(sql)

    def source_text(meta):
        # `end` is inclusive, hence the + 1 when slicing.
        return sql[meta["start"] : meta["end"] + 1]

    self.assertEqual(source_text(ast.args["from"].this.args["db"].meta), '"test_schema"')
    self.assertEqual(source_text(ast.args["from"].this.this.meta), '"test_table_a"')

0 commit comments

Comments
 (0)