Skip to content

Commit b119213

Browse files
authored
fix(sql): dialect sql parser fixed (#1778)
1 parent a7324b5 commit b119213

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

pandasai/core/code_generation/code_cleaning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def _clean_sql_query(self, sql_query: str) -> str:
5555
Clean the SQL query by trimming semicolons and validating table names.
5656
"""
5757
sql_query = sql_query.rstrip(";")
58-
table_names = SQLParser.extract_table_names(sql_query)
58+
dialect = self.context.dfs[0].get_dialect()
59+
table_names = SQLParser.extract_table_names(sql_query, dialect)
5960
allowed_table_names = {
6061
df.schema.name: df.schema.name for df in self.context.dfs
6162
} | {f'"{df.schema.name}"': df.schema.name for df in self.context.dfs}

pandasai/dataframe/base.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,20 +132,23 @@ def rows_count(self) -> int:
132132
def columns_count(self) -> int:
133133
return len(self.columns)
134134

135+
def get_dialect(self):
136+
source = self.schema.source or None
137+
if source:
138+
dialect = "duckdb" if source.type in LOCAL_SOURCE_TYPES else source.type
139+
else:
140+
dialect = "postgres"
141+
142+
return dialect
143+
135144
def serialize_dataframe(self) -> str:
136145
"""
137146
Serialize DataFrame to string representation.
138147
139148
Returns:
140149
str: Serialized string representation of the DataFrame
141150
"""
142-
source = self.schema.source or None
143-
144-
if source:
145-
dialect = "duckdb" if source.type in LOCAL_SOURCE_TYPES else source.type
146-
else:
147-
dialect = "postgres"
148-
151+
dialect = self.get_dialect()
149152
return DataframeSerializer.serialize(self, dialect)
150153

151154
def get_head(self):

tests/unit_tests/core/code_generation/test_code_cleaning.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_clean_sql_query(self):
7171
mock_dataframe.schema = MagicMock()
7272
mock_dataframe.schema.name = "my_table"
7373
self.cleaner.context.dfs = [mock_dataframe]
74+
mock_dataframe.get_dialect = MagicMock(return_value="duckdb")
7475
result = self.cleaner._clean_sql_query(sql_query)
7576
self.assertEqual(result, "SELECT * FROM my_table")
7677

@@ -84,6 +85,7 @@ def test_validate_and_make_table_name_case_sensitive(self):
8485
self.cleaner.context.dfs = [mock_dataframe]
8586
mock_dataframe.schema = MagicMock()
8687
mock_dataframe.schema.name = "my_table"
88+
mock_dataframe.get_dialect = MagicMock(return_value="duckdb")
8789
updated_node = self.cleaner._validate_and_make_table_name_case_sensitive(node)
8890
self.assertEqual(updated_node.value.value, "SELECT * FROM my_table")
8991

0 commit comments

Comments
 (0)