|
4 | 4 | # |
5 | 5 |
|
6 | 6 | import pytest |
| 7 | +import re |
7 | 8 |
|
8 | 9 | from snowflake.snowpark._internal.utils import get_plan_from_line_numbers |
9 | 10 | from snowflake.snowpark import functions as F |
@@ -53,17 +54,6 @@ def generate_test_data(session, sql_simplifier_enabled): |
53 | 54 | 10: 'SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (3 :: INT, \'C\' :: STRING, 300 :: INT), (4 :: INT, \'D\' :: STRING, 400 :: INT) )', |
54 | 55 | }, |
55 | 56 | ), |
56 | | - ( |
57 | | - lambda data: data["df_join1"].join( |
58 | | - data["df_join2"], data["df_join1"].id == data["df_join2"].id |
59 | | - ), |
60 | | - True, |
61 | | - { |
62 | | - 2: 'SELECT * FROM ( ( SELECT "ID" AS "l_0000_ID", "NAME" AS "NAME" FROM ( SELECT $1 AS "ID", $2 AS "NAME" FROM VALUES (1 :: INT, \'A\' :: STRING), (2 :: INT, \'B\' :: STRING) ) ) AS SNOWPARK_LEFT INNER JOIN ( SELECT "ID" AS "r_0001_ID", "VALUE" AS "VALUE" FROM ( SELECT $1 AS "ID", $2 AS "VALUE" FROM VALUES (1 :: INT, 10 :: INT), (2 :: INT, 20 :: INT) ) ) AS SNOWPARK_RIGHT ON ("l_0000_ID" = "r_0001_ID") )', |
63 | | - 7: "SELECT $1 AS \"ID\", $2 AS \"NAME\" FROM VALUES (1 :: INT, 'A' :: STRING), (2 :: INT, 'B' :: STRING)", |
64 | | - 14: 'SELECT "ID" AS "r_0001_ID", "VALUE" AS "VALUE" FROM ( SELECT $1 AS "ID", $2 AS "VALUE" FROM VALUES (1 :: INT, 10 :: INT), (2 :: INT, 20 :: INT) )', |
65 | | - }, |
66 | | - ), |
67 | 57 | ( |
68 | 58 | lambda data: data["df1"].filter(data["df1"].value > 150), |
69 | 59 | True, |
@@ -117,6 +107,48 @@ def test_get_plan_from_line_numbers_sql_content( |
117 | 107 | ), f"Line {line_num}: Expected SQL '{expected_sql}' not equal to plan sql:\n{plan_sql}" |
118 | 108 |
|
119 | 109 |
|
| 110 | +def test_get_plan_from_line_numbers_join_operations(session): |
| 111 | +    """ |
| 112 | +    Test get_plan_from_line_numbers for join operations using regex matching. The SQL is not |
| 113 | +    compared exactly because the numeric suffixes of the generated join-key aliases |
| 114 | +    (the ``l_<n>_ID`` / ``r_<n>_ID`` names) can vary between test environments. |
| 115 | +    """ |
| 116 | +    session.sql_simplifier_enabled = True |
| 117 | +    data = generate_test_data(session, True) |
| 118 | + |
| 119 | +    df = data["df_join1"].join( |
| 120 | +        data["df_join2"], data["df_join1"].id == data["df_join2"].id |
| 121 | +    ) |
| 122 | + |
| 123 | +    # Keys are the line numbers passed to get_plan_from_line_numbers; values are regexes |
| 124 | +    # that the normalized SQL of the returned plan must match in full. |
| 125 | +    line_to_expected_pattern = { |
| 126 | +        2: r'SELECT \* FROM \(\(SELECT "ID" AS "l_\d+_ID", "NAME" AS "NAME" FROM \(SELECT \$1 AS "ID", \$2 AS "NAME" FROM VALUES \(1 :: INT, \'A\' :: STRING\), \(2 :: INT, \'B\' :: STRING\)\)\) AS SNOWPARK_LEFT INNER JOIN \(SELECT "ID" AS "r_\d+_ID", "VALUE" AS "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)\) AS SNOWPARK_RIGHT ON \("l_\d+_ID" = "r_\d+_ID"\)\)', |
| 127 | +        7: r'SELECT \$1 AS "ID", \$2 AS "NAME" FROM VALUES \(1 :: INT, \'A\' :: STRING\), \(2 :: INT, \'B\' :: STRING\)', |
| 128 | +        14: r'SELECT "ID" AS "r_\d+_ID", "VALUE" AS "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)', |
| 129 | +    } |
| 130 | + |
| 131 | +    for line_num, expected_pattern in line_to_expected_pattern.items(): |
| 132 | +        plan = get_plan_from_line_numbers(df._plan, line_num) |
| 133 | +        assert ( |
| 134 | +            plan is not None |
| 135 | +        ), f"get_plan_from_line_numbers returned None for line {line_num}" |
| 136 | + |
| 137 | +        # Prefer the last query of the plan, falling back to its sql_query attribute. |
| 138 | +        queries = getattr(plan, "queries", None) |
| 139 | +        plan_sql = queries[-1].sql if queries else getattr(plan, "sql_query", None) |
| 140 | +        assert plan_sql, f"Could not extract SQL from plan for line {line_num}" |
| 141 | + |
| 142 | +        normalized_sql = Utils.normalize_sql(plan_sql) |
| 143 | +        # fullmatch, not match: re.match only anchors at the start, so unexpected |
| 144 | +        # trailing SQL would otherwise go unnoticed. |
| 145 | +        assert re.fullmatch(expected_pattern, normalized_sql), ( |
| 146 | +            f"Line {line_num}: SQL pattern does not match expected pattern.\n" |
| 147 | +            f"Expected pattern: {expected_pattern}\n" |
| 148 | +            f"Actual SQL: {normalized_sql}" |
| 149 | +        ) |
| 150 | + |
| 151 | + |
120 | 152 | @pytest.mark.parametrize( |
121 | 153 | "line_num,expected_error", |
122 | 154 | [ |
|
0 commit comments