Skip to content

Commit 9c497ac

Browse files
committed
fix query counts and test failures
1 parent a7cf24c commit 9c497ac

File tree

5 files changed

+19
-26
lines changed

5 files changed

+19
-26
lines changed

src/snowflake/snowpark/modin/plugin/_internal/frame.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,8 @@ def update_snowflake_quoted_identifiers_with_expressions(
11441144
self,
11451145
quoted_identifier_to_column_map: dict[str, SnowparkColumn],
11461146
snowpark_pandas_types: Optional[list[Optional[SnowparkPandasType]]] = None,
1147+
*,
1148+
new_index_column_pandas_labels: Optional[list[Hashable]] = None,
11471149
) -> UpdatedInternalFrameResult:
11481150
"""
11491151
Points Snowflake quoted identifiers to column expression given by `quoted_identifier_to_column_map`.
@@ -1171,6 +1173,8 @@ def update_snowflake_quoted_identifiers_with_expressions(
11711173
must be index columns and data columns in the original internal frame.
11721174
data_column_snowpark_pandas_types: The optional Snowpark pandas types for the new
11731175
expressions, in the order of the keys of quoted_identifier_to_column_map.
1176+
new_index_column_pandas_labels: The optional list of labels to be used as
1177+
index_column_pandas_labels for the result.
11741178
11751179
Returns:
11761180
UpdatedInternalFrameResult: A tuple containing the new InternalFrame with updated column references, and a mapping
@@ -1252,7 +1256,9 @@ def update_snowflake_quoted_identifiers_with_expressions(
12521256
data_column_pandas_labels=self.data_column_pandas_labels,
12531257
data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers,
12541258
data_column_pandas_index_names=self.data_column_pandas_index_names,
1255-
index_column_pandas_labels=self.index_column_pandas_labels,
1259+
index_column_pandas_labels=self.index_column_pandas_labels
1260+
if new_index_column_pandas_labels is None
1261+
else new_index_column_pandas_labels,
12561262
index_column_snowflake_quoted_identifiers=new_index_column_snowflake_quoted_identifiers,
12571263
data_column_types=[
12581264
new_type_mapping[k]

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11888,7 +11888,9 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler":
1188811888
frame.index_column_snowflake_quoted_identifiers[0]: pandas_lit(
1188911889
frame.data_column_pandas_labels[0]
1189011890
),
11891-
}
11891+
},
11892+
# Swap the name of the index/columns objects
11893+
new_index_column_pandas_labels=frame.data_column_pandas_index_names,
1189211894
)[0]
1189311895
).set_columns([None])
1189411896

tests/integ/modin/frame/test_aggregate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def test_string_sum_with_nulls():
193193
with pytest.raises(TypeError):
194194
pandas_df.sum(numeric_only=False)
195195
snow_result = snow_df.sum(numeric_only=False)
196-
assert_series_equal(snow_result.to_pandas(), native_pd.Series(["ab"]))
196+
assert_series_equal(
197+
snow_result.to_pandas(), native_pd.Series(["ab"]), check_index_type=False
198+
)
197199

198200

199201
class TestTimedelta:

tests/integ/modin/frame/test_iloc.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,21 +1077,9 @@ def iloc_helper(df):
10771077
else:
10781078
return native_pd.Series([]) if axis == "row" else df.iloc[:, []]
10791079

1080-
def determine_query_count():
1081-
# Multiple queries because of squeeze() - in range is 2, out-of-bounds is 1.
1082-
if axis == "col":
1083-
num_queries = 1
1084-
else:
1085-
if not -8 < key < 7: # key is out of bound
1086-
num_queries = 2
1087-
else:
1088-
num_queries = 1
1089-
return num_queries
1090-
1091-
query_count = determine_query_count()
10921080
# test df with default index
10931081
num_cols = 7
1094-
with SqlCounter(query_count=query_count):
1082+
with SqlCounter(query_count=1):
10951083
eval_snowpark_pandas_result(
10961084
default_index_snowpark_pandas_df,
10971085
default_index_native_df,
@@ -1101,21 +1089,20 @@ def determine_query_count():
11011089

11021090
# test df with non-default index
11031091
num_cols = 6 # set_index() makes the number of columns 6
1104-
with SqlCounter(query_count=query_count):
1092+
with SqlCounter(query_count=1):
11051093
eval_snowpark_pandas_result(
11061094
default_index_snowpark_pandas_df.set_index("D"),
11071095
default_index_native_df.set_index("D"),
11081096
iloc_helper,
11091097
test_attrs=False,
11101098
)
11111099

1112-
query_count = determine_query_count()
11131100
# test df with MultiIndex
11141101
# Index dtype is different between Snowpark and native pandas if key produces empty df.
11151102
num_cols = 7
11161103
native_df = default_index_native_df.set_index(multiindex_native)
11171104
snowpark_df = pd.DataFrame(native_df)
1118-
with SqlCounter(query_count=query_count):
1105+
with SqlCounter(query_count=1):
11191106
eval_snowpark_pandas_result(
11201107
snowpark_df,
11211108
native_df,
@@ -1129,7 +1116,7 @@ def determine_query_count():
11291116
native_df_with_multiindex_columns
11301117
)
11311118
in_range = True if (-8 < key < 7) else False
1132-
with SqlCounter(query_count=query_count):
1119+
with SqlCounter(query_count=1):
11331120
if axis == "row" or in_range: # series result
11341121
eval_snowpark_pandas_result(
11351122
snowpark_df_with_multiindex_columns,
@@ -1151,7 +1138,7 @@ def determine_query_count():
11511138
# test df with MultiIndex on both index and columns
11521139
native_df = native_df_with_multiindex_columns.set_index(multiindex_native)
11531140
snowpark_df = pd.DataFrame(native_df)
1154-
with SqlCounter(query_count=query_count):
1141+
with SqlCounter(query_count=1):
11551142
if axis == "row" or in_range: # series result
11561143
eval_snowpark_pandas_result(
11571144
snowpark_df,

tests/integ/modin/frame/test_squeeze.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,7 @@ def test_n_by_1(axis, dtype):
3131

3232
@pytest.mark.parametrize("dtype", ["int", "timedelta64[ns]"])
3333
def test_1_by_n(axis, dtype):
34-
if axis is None:
35-
expected_query_count = 2
36-
else:
37-
expected_query_count = 1
38-
with SqlCounter(query_count=expected_query_count):
34+
with SqlCounter(query_count=1):
3935
eval_snowpark_pandas_result(
4036
*create_test_dfs({"a": [1], "b": [2], "c": [3]}, dtype=dtype),
4137
lambda df: df.squeeze(axis=axis),

0 commit comments

Comments
 (0)