diff --git a/sqlanalyzer/column_parser.py b/sqlanalyzer/column_parser.py index 602c73f..81174c0 100644 --- a/sqlanalyzer/column_parser.py +++ b/sqlanalyzer/column_parser.py @@ -99,7 +99,7 @@ def get_table_names(self, line_query): for line in line_query: - table_line = re.findall(r"(FROM|JOIN).(\w+.*)", line) + table_line = re.findall(r"(?i)(FROM|JOIN).(\w*`?.*)", line) if table_line != []: table_name_line = table_line[0][1].split(' ') diff --git a/sqlanalyzer/tests/test_column_parser.py b/sqlanalyzer/tests/test_column_parser.py index 56418a3..33df7a4 100644 --- a/sqlanalyzer/tests/test_column_parser.py +++ b/sqlanalyzer/tests/test_column_parser.py @@ -12,6 +12,17 @@ def sample_query(): return query +@pytest.fixture +def sample_query_diff_dbs(): + query = """ + SELECT * + from `some_database.schema.table` + INNER JOIN some_schema.some_table + WHERE column IS NULL + """ + return query + + @pytest.fixture def formatter(sample_query): formatter = column_parser.Parser(sample_query) @@ -36,3 +47,13 @@ def test_get_table_names(sample_query, formatter): assert table_name_mapping == {'sfdc_accounts': 'sfdc.accounts', 'opportunity_to_name': 'opportunity_to_name'} + + +def test_get_table_names_diff_dbs(sample_query_diff_dbs): + formatter = column_parser.Parser(sample_query_diff_dbs) + formatted_query = formatter.format_query(sample_query_diff_dbs) + table_name_mapping = formatter.get_table_names(formatted_query.split('\n')) + + assert table_name_mapping == {'`some_database.schema.table`': '`some_database.schema.table`', + 'some_schema.some_table': 'some_schema.some_table'} +