diff --git a/deepnote_toolkit/__init__.py b/deepnote_toolkit/__init__.py index 3968878..495a782 100644 --- a/deepnote_toolkit/__init__.py +++ b/deepnote_toolkit/__init__.py @@ -34,6 +34,7 @@ (".set_notebook_path", "set_notebook_path"), (".sql.sql_execution", "execute_sql"), (".sql.sql_execution", "execute_sql_with_connection_json"), + (".sql.sql_query_chaining", "register_sql_query"), (".variable_explorer", "deepnote_export_df"), (".variable_explorer", "deepnote_get_data_preview_json"), (".variable_explorer", "get_var_list"), diff --git a/deepnote_toolkit/sql/sql_query_chaining.py b/deepnote_toolkit/sql/sql_query_chaining.py index ec4a20e..27d000c 100644 --- a/deepnote_toolkit/sql/sql_query_chaining.py +++ b/deepnote_toolkit/sql/sql_query_chaining.py @@ -5,6 +5,30 @@ from deepnote_toolkit.sql.query_preview import DeepnoteQueryPreview from deepnote_toolkit.sql.sql_utils import is_single_select_query +# Maps a SQL block's result variable name to the source query that produced it. +# deepnote-internal calls register_sql_query after a dataframe-mode SQL block +# runs, so that subsequent SQL blocks can reference the variable as a table and +# have it expanded into a CTE. This is the same chaining DeepnoteQueryPreview +# enables for query-preview-mode blocks, but without changing the result type +# (which would show a preview banner and limit the row count). +_sql_query_registry: dict = {} + + +def register_sql_query(variable_name, query): + """Register a SQL block's source query so downstream blocks can chain on it. + + Only single SELECT statements can be expanded into a CTE, so anything else + (or an empty query) clears any previous entry for the variable rather than + storing it. 
+ """ + if variable_name is None: + return + + if query is not None and is_single_select_query(query): + _sql_query_registry[variable_name] = query + else: + _sql_query_registry.pop(variable_name, None) + def add_limit_clause(query: str, limit: int = 100): class ExecuteSqlError(Exception): @@ -137,16 +161,33 @@ def find_query_preview_references( # Check each table reference for table_reference in table_references: - # Check if the reference exists in the main module + variable_name = table_reference + + # Skip references we've already resolved (dedupe + cycle guard by name) + if variable_name in query_preview_references: + continue + + # 1) Dataframe-mode chaining. deepnote-internal registers a block's + # source query keyed by its result variable name. We only trust the + # entry while the name is still bound in __main__, so a deleted variable + # doesn't produce a stale CTE expansion (rebinding is not detected here). + if variable_name in _sql_query_registry and hasattr(__main__, variable_name): + registered_source = _sql_query_registry[variable_name] + if registered_source: + query_preview_references[variable_name] = registered_source + # Recursively resolve the registered query's own references + find_query_preview_references( + registered_source, + query_preview_references, + processed_queries, + ) + continue + + # 2) Query-preview-mode chaining. The variable itself is a + # DeepnoteQueryPreview that carries its source query. 
if hasattr(__main__, table_reference): - variable_name = table_reference variable = getattr(__main__, table_reference) - # If it's a QueryPreview object and not already in our list - # Use any() with a generator expression to check if the variable is already in the list - # This avoids using the pandas object in a boolean context - if isinstance(variable, DeepnoteQueryPreview) and not any( - id(variable) == id(ref) for ref in query_preview_references - ): + if isinstance(variable, DeepnoteQueryPreview): # Add it to our list query_preview_source = variable._deepnote_query query_preview_references[variable_name] = query_preview_source diff --git a/tests/unit/test_sql_query_chaining.py b/tests/unit/test_sql_query_chaining.py index 59f30d9..321efdc 100644 --- a/tests/unit/test_sql_query_chaining.py +++ b/tests/unit/test_sql_query_chaining.py @@ -1,13 +1,16 @@ from unittest import TestCase +import __main__ import sqlparse from sqlparse.tokens import Keyword +from deepnote_toolkit.sql import sql_query_chaining from deepnote_toolkit.sql.sql_query_chaining import ( add_limit_clause, extract_table_reference_from_token, extract_table_references, find_query_preview_references, + register_sql_query, unchain_sql_query, ) @@ -714,3 +717,80 @@ def test_non_single_select_query_raises_exception(self): "Invalid query type: Query Preview supports only a single SELECT statement", str(context.exception), ) + + +class TestRegistryBasedChaining(TestCase): + """Tests for dataframe-mode chaining via register_sql_query / the registry. + + deepnote-internal registers a SQL block's source query keyed by its result + variable name; downstream blocks then reference the variable as a table and + it is expanded into a CTE - without the result becoming a DeepnoteQueryPreview. 
+ """ + + def setUp(self): + sql_query_chaining._sql_query_registry.clear() + self._original_vars = vars(__main__).copy() + + def tearDown(self): + sql_query_chaining._sql_query_registry.clear() + for key in list(vars(__main__).keys()): + if key not in self._original_vars: + delattr(__main__, key) + + def _bind(self, name, value="bound"): + """Bind a variable in __main__ (simulates the block having executed).""" + setattr(__main__, name, value) + + def test_register_stores_single_select(self): + register_sql_query("df_1", "SELECT * FROM users") + self.assertEqual( + sql_query_chaining._sql_query_registry, {"df_1": "SELECT * FROM users"} + ) + + def test_register_non_select_clears_entry(self): + register_sql_query("df_1", "SELECT * FROM users") + register_sql_query("df_1", "INSERT INTO users VALUES (1)") + self.assertNotIn("df_1", sql_query_chaining._sql_query_registry) + + def test_register_none_query_clears_entry(self): + register_sql_query("df_1", "SELECT * FROM users") + register_sql_query("df_1", None) + self.assertNotIn("df_1", sql_query_chaining._sql_query_registry) + + def test_find_references_via_registry(self): + self._bind("df_1") + register_sql_query("df_1", "SELECT * FROM users") + refs = find_query_preview_references("SELECT * FROM df_1") + self.assertEqual(refs, {"df_1": "SELECT * FROM users"}) + + def test_unchain_via_registry(self): + self._bind("df_1") + register_sql_query("df_1", "SELECT * FROM users") + result = unchain_sql_query("SELECT * FROM df_1") + self.assertIn("WITH", result) + self.assertIn("df_1 AS", result) + self.assertIn("SELECT * FROM users", result) + + def test_multi_level_registry_chaining(self): + self._bind("df_1") + self._bind("df_2") + register_sql_query("df_1", "SELECT * FROM users") + register_sql_query("df_2", "SELECT * FROM df_1 WHERE active") + result = unchain_sql_query("SELECT * FROM df_2") + # Both levels expanded, with the dependency (df_1) defined before df_2 + self.assertIn("df_1 AS", result) + self.assertIn("df_2 
AS", result) + self.assertIn("SELECT * FROM users", result) + self.assertLess(result.find("df_1 AS"), result.find("df_2 AS")) + + def test_stale_entry_ignored_when_variable_not_bound(self): + # Registered but the variable is no longer bound in __main__ (e.g. deleted + # cell or kernel restart) -> the entry must not produce a CTE. + register_sql_query("df_1", "SELECT * FROM users") + refs = find_query_preview_references("SELECT * FROM df_1") + self.assertEqual(refs, {}) + + def test_empty_registry_is_safe(self): + refs = find_query_preview_references("SELECT * FROM df_1") + self.assertEqual(refs, {}) + self.assertEqual(unchain_sql_query("SELECT * FROM df_1"), "SELECT * FROM df_1")