Skip to content

Commit 39ee3f5

Browse files
committed
add test + use single pass
1 parent 1609108 commit 39ee3f5

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

src/snowflake/snowpark/_internal/compiler/plan_compiler.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ def replace_temp_obj_placeholders(
206206
placeholders = {}
207207
# Final execution queries
208208
execution_queries = {}
209-
for query_list in queries.values():
209+
for query_type, query_list in queries.items():
210+
execution_queries[query_type] = []
210211
for query in query_list:
211212
# If the query contains a temp object name placeholder, we generate a random
212213
# name for the temp object and add it to the placeholders dictionary.
@@ -215,13 +216,10 @@ def replace_temp_obj_placeholders(
215216
placeholder_name,
216217
temp_obj_type,
217218
) = query.temp_obj_name_placeholder
218-
placeholders[placeholder_name] = random_name_for_temp_object(
219-
temp_obj_type
220-
)
221-
# This loop must be done in a separate pass to ensure a CREATE/DROP pair actually refer to the same object.
222-
for query_type, query_list in queries.items():
223-
execution_queries[query_type] = []
224-
for query in query_list:
219+
if placeholder_name not in placeholders:
220+
placeholders[placeholder_name] = random_name_for_temp_object(
221+
temp_obj_type
222+
)
225223
copied_query = copy.copy(query)
226224
for placeholder_name, target_temp_name in placeholders.items():
227225
# Copy the original query and replace all the placeholder names with the

tests/integ/test_multithreading.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,7 @@ def process_data(df_, thread_id):
848848
len(unique_create_table_queries) == expected_num_queries
849849
), queries_sent
850850
assert len(unique_drop_table_queries) == expected_num_queries, queries_sent
851+
assert unique_create_table_queries == unique_drop_table_queries
851852

852853
finally:
853854
analyzer.ARRAY_BIND_THRESHOLD = original_value
@@ -913,6 +914,7 @@ def process_data(df_, thread_id):
913914

914915
assert len(unique_create_file_format_queries) == 10
915916
assert len(unique_drop_file_format_queries) == 10
917+
assert unique_create_file_format_queries == unique_drop_file_format_queries
916918

917919

918920
@pytest.mark.xfail(

0 commit comments

Comments
 (0)