Skip to content

Commit 92eaeb7

Browse files
committed
add tests
1 parent 2d88271 commit 92eaeb7

File tree

2 files changed

+203
-0
lines changed

2 files changed

+203
-0
lines changed

tests/integ/test_deepcopy.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import copy
66
from typing import Callable, List, Optional
7+
from unittest import mock
78

89
import pytest
910

@@ -415,3 +416,70 @@ def traverse_plan(plan, plan_id_map):
415416
traverse_plan(child, plan_id_map)
416417

417418
traverse_plan(copied_plan, {})
419+
420+
421+
def test_selectable_entity_deepcopy_attributes_with_flags_enabled(session):
422+
"""Verify _attributes is deepcopied in SelectableEntity when both flags are True."""
423+
temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
424+
session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).write.save_as_table(
425+
temp_table_name, table_type="temp"
426+
)
427+
428+
with mock.patch.object(
429+
session, "_reduce_describe_query_enabled", True
430+
), mock.patch.object(session, "_cte_optimization_enabled", True):
431+
# Apply a filter to get a SelectStatement with SelectableEntity as child
432+
df = session.table(temp_table_name).filter(col("a") == 1)
433+
# Access the SelectableEntity from the plan's children
434+
if session.sql_simplifier_enabled:
435+
assert len(df._plan.children_plan_nodes) == 1
436+
assert isinstance(df._plan.children_plan_nodes[0], SelectableEntity)
437+
selectable = df._plan.children_plan_nodes[0]
438+
# Set _attributes to simulate cached attributes
439+
selectable._attributes = selectable.snowflake_plan.attributes
440+
441+
copied_selectable = copy.deepcopy(selectable)
442+
443+
# Verify attributes were deepcopied
444+
assert copied_selectable._attributes is not None
445+
assert copied_selectable._attributes is not selectable._attributes
446+
for copied_attr, original_attr in zip(
447+
copied_selectable._attributes, selectable._attributes
448+
):
449+
assert copied_attr is not original_attr
450+
assert copied_attr.name == original_attr.name
451+
452+
453+
@pytest.mark.parametrize(
454+
"reduce_describe_enabled,cte_enabled",
455+
[
456+
(False, False),
457+
(True, False),
458+
(False, True),
459+
],
460+
)
461+
def test_selectable_entity_deepcopy_attributes_with_flags_disabled(
462+
session, reduce_describe_enabled, cte_enabled
463+
):
464+
"""Verify _attributes is NOT copied when either flag is False."""
465+
temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
466+
session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).write.save_as_table(
467+
temp_table_name, table_type="temp"
468+
)
469+
470+
with mock.patch.object(
471+
session, "_reduce_describe_query_enabled", reduce_describe_enabled
472+
), mock.patch.object(session, "_cte_optimization_enabled", cte_enabled):
473+
# Apply a filter to get a SelectStatement with SelectableEntity as child
474+
df = session.table(temp_table_name).filter(col("a") == 1)
475+
if session.sql_simplifier_enabled:
476+
assert len(df._plan.children_plan_nodes) == 1
477+
assert isinstance(df._plan.children_plan_nodes[0], SelectableEntity)
478+
selectable = df._plan.children_plan_nodes[0]
479+
# Set _attributes to simulate cached attributes
480+
selectable._attributes = selectable.snowflake_plan.attributes
481+
482+
copied_selectable = copy.deepcopy(selectable)
483+
484+
# Verify attributes were NOT copied
485+
assert copied_selectable._attributes is None

tests/unit/test_deepcopy.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
#
44
import copy
55
import uuid
6+
from typing import List
67
from unittest import mock
78

9+
import pytest
10+
811
from snowflake.snowpark import Session, functions as F, types as T
912
from snowflake.snowpark._internal.analyzer.analyzer_utils import UNION
1013
from snowflake.snowpark._internal.analyzer.select_statement import (
@@ -232,3 +235,135 @@ def test_set_statement(mock_session, mock_analyzer):
232235
verify_copied_selectable(
233236
copied_operand, original_operand, expect_plan_copied=False
234237
)
238+
239+
240+
def init_attributes(node: Selectable) -> List[Attribute]:
241+
"""Initialize _attributes on a Selectable with test data."""
242+
attrs = [Attribute("A", IntegerType()), Attribute("B", StringType())]
243+
node._attributes = attrs
244+
return attrs
245+
246+
247+
def verify_attributes_deepcopied(
248+
copied_selectable: Selectable,
249+
original_selectable: Selectable,
250+
) -> None:
251+
"""Verify that _attributes was deepcopied (new list with new Attribute objects)."""
252+
assert copied_selectable._attributes is not None
253+
assert original_selectable._attributes is not None
254+
# List should be different object (deepcopy)
255+
assert copied_selectable._attributes is not original_selectable._attributes
256+
# Each Attribute object should be different (deepcopy)
257+
for copied_attr, original_attr in zip(
258+
copied_selectable._attributes, original_selectable._attributes
259+
):
260+
assert copied_attr is not original_attr
261+
assert copied_attr.name == original_attr.name
262+
assert copied_attr.datatype == original_attr.datatype
263+
assert copied_attr.nullable == original_attr.nullable
264+
265+
266+
def verify_attributes_shallow_copied(
267+
copied_selectable: Selectable,
268+
original_selectable: Selectable,
269+
) -> None:
270+
"""Verify that _attributes was shallow-copied (new list but same Attribute objects)."""
271+
assert copied_selectable._attributes is not None
272+
assert original_selectable._attributes is not None
273+
# List should be different object (shallow copy creates new list)
274+
assert copied_selectable._attributes is not original_selectable._attributes
275+
# Each Attribute object should be the same (shallow copy)
276+
for copied_attr, original_attr in zip(
277+
copied_selectable._attributes, original_selectable._attributes
278+
):
279+
assert copied_attr is original_attr
280+
281+
282+
def verify_attributes_not_copied(
283+
copied_selectable: Selectable,
284+
) -> None:
285+
"""Verify that _attributes was NOT copied (should be None or default)."""
286+
assert copied_selectable._attributes is None
287+
288+
289+
def _create_selectable_entity(mock_session, mock_analyzer):
290+
"""Helper to create a SelectableEntity for testing."""
291+
return SelectableEntity(
292+
SnowflakeTable("TEST_TABLE", session=mock_session), analyzer=mock_analyzer
293+
)
294+
295+
296+
def _create_select_statement(mock_session, mock_analyzer):
297+
"""Helper to create a SelectStatement for testing."""
298+
from_ = SelectableEntity(
299+
SnowflakeTable("TEST_TABLE", session=mock_session), analyzer=mock_analyzer
300+
)
301+
return SelectStatement(from_=from_, analyzer=mock_analyzer)
302+
303+
304+
@pytest.mark.parametrize(
305+
"selectable_factory,copy_func,reduce_describe_enabled,cte_enabled",
306+
[
307+
# SelectableEntity with deepcopy - flags enabled (should copy)
308+
(_create_selectable_entity, copy.deepcopy, True, True),
309+
# SelectableEntity with deepcopy - flags disabled (should NOT copy)
310+
(_create_selectable_entity, copy.deepcopy, False, False),
311+
(_create_selectable_entity, copy.deepcopy, True, False),
312+
(_create_selectable_entity, copy.deepcopy, False, True),
313+
# SelectStatement with copy (shallow) - flags enabled (should copy)
314+
(_create_select_statement, copy.copy, True, True),
315+
# SelectStatement with copy (shallow) - flags disabled (should NOT copy)
316+
(_create_select_statement, copy.copy, False, False),
317+
(_create_select_statement, copy.copy, True, False),
318+
(_create_select_statement, copy.copy, False, True),
319+
# SelectStatement with deepcopy - flags enabled (should copy)
320+
(_create_select_statement, copy.deepcopy, True, True),
321+
# SelectStatement with deepcopy - flags disabled (should NOT copy)
322+
(_create_select_statement, copy.deepcopy, False, False),
323+
(_create_select_statement, copy.deepcopy, True, False),
324+
(_create_select_statement, copy.deepcopy, False, True),
325+
],
326+
)
327+
def test_attributes_copy_with_session_flags(
328+
mock_session,
329+
mock_analyzer,
330+
selectable_factory,
331+
copy_func,
332+
reduce_describe_enabled,
333+
cte_enabled,
334+
):
335+
"""Test _attributes copy behavior based on session flags and copy method.
336+
337+
When both reduce_describe_query_enabled and cte_optimization_enabled are True,
338+
_attributes should be copied. Otherwise, _attributes should NOT be copied.
339+
340+
Note: SelectableEntity only has __deepcopy__ (no __copy__), so only deepcopy is tested.
341+
SelectStatement has both __copy__ and __deepcopy__.
342+
"""
343+
# Set session flags
344+
mock_session.reduce_describe_query_enabled = reduce_describe_enabled
345+
mock_session.cte_optimization_enabled = cte_enabled
346+
347+
# Create the selectable and initialize fields
348+
selectable = selectable_factory(mock_session, mock_analyzer)
349+
init_selectable_fields(selectable, init_plan=False)
350+
init_attributes(selectable)
351+
352+
# Perform the copy
353+
copied_selectable = copy_func(selectable)
354+
355+
# Verify based on flags
356+
both_flags_enabled = reduce_describe_enabled and cte_enabled
357+
if both_flags_enabled:
358+
# _attributes should be copied
359+
assert copied_selectable._attributes is not None
360+
assert copied_selectable._attributes is not selectable._attributes
361+
if copy_func == copy.deepcopy:
362+
# Deep copy: new Attribute objects
363+
verify_attributes_deepcopied(copied_selectable, selectable)
364+
else:
365+
# Shallow copy: same Attribute objects
366+
verify_attributes_shallow_copied(copied_selectable, selectable)
367+
else:
368+
# _attributes should NOT be copied
369+
verify_attributes_not_copied(copied_selectable)

0 commit comments

Comments
 (0)