File tree Expand file tree Collapse file tree 3 files changed +14
-8
lines changed
tests/unit_tests/core/code_generation Expand file tree Collapse file tree 3 files changed +14
-8
lines changed Original file line number Diff line number Diff line change @@ -55,7 +55,8 @@ def _clean_sql_query(self, sql_query: str) -> str:
5555 Clean the SQL query by trimming semicolons and validating table names.
5656 """
5757 sql_query = sql_query .rstrip (";" )
58- table_names = SQLParser .extract_table_names (sql_query )
58+ dialect = self .context .dfs [0 ].get_dialect ()
59+ table_names = SQLParser .extract_table_names (sql_query , dialect )
5960 allowed_table_names = {
6061 df .schema .name : df .schema .name for df in self .context .dfs
6162 } | {f'"{ df .schema .name } "' : df .schema .name for df in self .context .dfs }
Original file line number Diff line number Diff line change @@ -132,20 +132,23 @@ def rows_count(self) -> int:
132132 def columns_count (self ) -> int :
133133 return len (self .columns )
134134
135+ def get_dialect (self ):
136+ source = self .schema .source or None
137+ if source :
138+ dialect = "duckdb" if source .type in LOCAL_SOURCE_TYPES else source .type
139+ else :
140+ dialect = "postgres"
141+
142+ return dialect
143+
135144 def serialize_dataframe (self ) -> str :
136145 """
137146 Serialize DataFrame to string representation.
138147
139148 Returns:
140149 str: Serialized string representation of the DataFrame
141150 """
142- source = self .schema .source or None
143-
144- if source :
145- dialect = "duckdb" if source .type in LOCAL_SOURCE_TYPES else source .type
146- else :
147- dialect = "postgres"
148-
151+ dialect = self .get_dialect ()
149152 return DataframeSerializer .serialize (self , dialect )
150153
151154 def get_head (self ):
Original file line number Diff line number Diff line change @@ -71,6 +71,7 @@ def test_clean_sql_query(self):
7171 mock_dataframe .schema = MagicMock ()
7272 mock_dataframe .schema .name = "my_table"
7373 self .cleaner .context .dfs = [mock_dataframe ]
74+ mock_dataframe .get_dialect = MagicMock (return_value = "duckdb" )
7475 result = self .cleaner ._clean_sql_query (sql_query )
7576 self .assertEqual (result , "SELECT * FROM my_table" )
7677
@@ -84,6 +85,7 @@ def test_validate_and_make_table_name_case_sensitive(self):
8485 self .cleaner .context .dfs = [mock_dataframe ]
8586 mock_dataframe .schema = MagicMock ()
8687 mock_dataframe .schema .name = "my_table"
88+ mock_dataframe .get_dialect = MagicMock (return_value = "duckdb" )
8789 updated_node = self .cleaner ._validate_and_make_table_name_case_sensitive (node )
8890 self .assertEqual (updated_node .value .value , "SELECT * FROM my_table" )
8991
You can’t perform that action at this time.
0 commit comments