Skip to content

Commit 455dfdb

Browse files
authored
fix(view_loader): avoid lower casing the column name (#1703)
* fix(view_loader): avoid lower casing the column name * fix: remove print statements
1 parent ccdf77f commit 455dfdb

File tree

4 files changed

+159
-149
lines changed

4 files changed

+159
-149
lines changed

pandasai/helpers/sql_sanitizer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,18 @@
22
import re
33

44
import sqlglot
5+
from sqlglot import parse_one
6+
from sqlglot.optimizer.qualify_columns import quote_identifiers
57

68

79
def sanitize_view_column_name(relation_name: str) -> str:
8-
return ".".join(list(map(sanitize_sql_table_name, relation_name.split("."))))
10+
return (
11+
parse_one(
12+
".".join(list(map(sanitize_sql_table_name, relation_name.split("."))))
13+
)
14+
.transform(quote_identifiers)
15+
.sql()
16+
)
917

1018

1119
def sanitize_sql_table_name(table_name: str) -> str:

pandasai/query_builders/view_query_builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,12 @@ def __init__(
2424

2525
@staticmethod
2626
def normalize_view_column_name(name: str) -> str:
27-
return normalize_identifiers(parse_one(sanitize_view_column_name(name))).sql()
27+
return sanitize_view_column_name(name)
2828

2929
@staticmethod
3030
def normalize_view_column_alias(name: str) -> str:
31-
return normalize_identifiers(
32-
sanitize_view_column_name(name).replace(".", "_")
33-
).sql()
31+
col_name = name.replace(".", "_")
32+
return sanitize_view_column_name(col_name)
3433

3534
def _get_group_by_columns(self) -> list[str]:
3635
"""Get the group by columns with proper view column aliasing."""

tests/unit_tests/helpers/test_sql_sanitizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_sanitize_file_name_long_name(self):
2525

2626
def test_sanitize_relation_name_valid(self):
2727
relation = "dataset-name.column"
28-
expected = "dataset_name.column"
28+
expected = '"dataset_name"."column"'
2929
assert sanitize_view_column_name(relation) == expected
3030

3131
def test_safe_select_query(self):

0 commit comments

Comments
 (0)