Skip to content

Commit 47ebc94

Browse files
committed
lint
1 parent 751e50c commit 47ebc94

File tree

3 files changed

+27
-24
lines changed

3 files changed

+27
-24
lines changed

src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,13 +464,13 @@ class ApplyFunc:
464464
def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover
465465
# First column is row position, extract it for later use
466466
row_positions = df.iloc[:, 0]
467-
467+
468468
# If we have index columns, set them as the index
469469
if num_index_columns > 0:
470470
# Columns after row position are index columns, then data columns
471-
index_cols = df.iloc[:, 1:1+num_index_columns]
472-
data_cols = df.iloc[:, 1+num_index_columns:]
473-
471+
index_cols = df.iloc[:, 1 : 1 + num_index_columns]
472+
data_cols = df.iloc[:, 1 + num_index_columns :]
473+
474474
# Set the index using the index columns
475475
if num_index_columns == 1:
476476
index = index_cols.iloc[:, 0]
@@ -480,15 +480,17 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover
480480
# Multi-index case
481481
index = native_pd.MultiIndex.from_arrays(
482482
[index_cols.iloc[:, i] for i in range(num_index_columns)],
483-
names=index_column_pandas_labels if index_column_pandas_labels else None
483+
names=index_column_pandas_labels
484+
if index_column_pandas_labels
485+
else None,
484486
)
485487
data_cols.index = index
486488
df = data_cols
487489
else:
488490
# No index columns, use row position as index (original behavior)
489491
df = df.iloc[:, 1:]
490492
df.index = row_positions
491-
493+
492494
df.columns = column_index
493495
df = df.apply(
494496
func, axis=1, raw=raw, result_type=result_type, args=args, **kwargs
@@ -523,10 +525,10 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover
523525
# - VALUE contains the result at this position.
524526
if isinstance(df, native_pd.DataFrame):
525527
result = []
526-
for idx, (row_position_index, series) in enumerate(df.iterrows()):
528+
for idx, (_row_position_index, series) in enumerate(df.iterrows()):
527529
# Use the actual row position from row_positions, not the index value
528530
actual_row_position = row_positions.iloc[idx]
529-
531+
530532
for i, (label, value) in enumerate(series.items()):
531533
# If this is a tuple then we store each component with a 0-based
532534
# lookup. For example, (a,b,c) is stored as (0:a, 1:b, 2:c).

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9602,7 +9602,9 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1(
96029602
# The apply function is encapsulated in a UDTF and run as a stored procedure on the pandas dataframe.
96039603
# Determine if we should pass index columns to the UDTF
96049604
# We pass index columns when the index is not the row position itself
9605-
index_columns_for_udtf = new_internal_df.index_column_snowflake_quoted_identifiers
9605+
index_columns_for_udtf = (
9606+
new_internal_df.index_column_snowflake_quoted_identifiers
9607+
)
96069608
if row_position_snowflake_quoted_identifier in index_columns_for_udtf:
96079609
# The row position IS the index (e.g., RangeIndex), don't pass index columns
96089610
index_columns_for_udtf = []
@@ -9611,8 +9613,10 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1(
96119613
else:
96129614
# Pass the actual index columns to the UDTF
96139615
num_index_columns = len(index_columns_for_udtf)
9614-
index_column_pandas_labels_for_udtf = new_internal_df.index_column_pandas_labels
9615-
9616+
index_column_pandas_labels_for_udtf = (
9617+
new_internal_df.index_column_pandas_labels
9618+
)
9619+
96169620
func_udtf = create_udtf_for_apply_axis_1(
96179621
row_position_snowflake_quoted_identifier,
96189622
func,
@@ -10386,7 +10390,7 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no
1038610390
column_index = try_convert_index_to_native(
1038710391
self._modin_frame.data_columns_index
1038810392
)
10389-
10393+
1039010394
# get input types of index and data columns from the dataframe
1039110395
data_input_types = self._modin_frame.get_snowflake_type(
1039210396
self._modin_frame.data_column_snowflake_quoted_identifiers

tests/integ/modin/frame/test_apply.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,38 +147,35 @@ def foo(row) -> str:
147147
with SqlCounter(query_count=4, join_count=0, udtf_count=0):
148148
eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=1))
149149

150-
@pytest.mark.parametrize("index", [
151-
None,
152-
['a', 'b'],
153-
[100, 200]
154-
])
150+
151+
@pytest.mark.parametrize("index", [None, ["a", "b"], [100, 200]])
155152
@sql_count_checker(query_count=5, join_count=2, udtf_count=1)
156153
def test_apply_axis_1_index_preservation(index):
157154
"""Test that apply(axis=1) preserves index values correctly."""
158155
# Test with default RangeIndex
159156
native_df = native_pd.DataFrame([[1, 2], [3, 4]], index=index)
160157
snow_df = pd.DataFrame(native_df)
161-
158+
162159
eval_snowpark_pandas_result(
163-
snow_df, native_df, lambda x: x.apply(lambda row : row.name, axis=1)
160+
snow_df, native_df, lambda x: x.apply(lambda row: row.name, axis=1)
164161
)
165162

166163

167164
@sql_count_checker(query_count=5, join_count=2, udtf_count=1)
168165
def test_apply_axis_1_multiindex_preservation():
169166
"""Test that apply(axis=1) preserves MultiIndex values correctly."""
170167
# Test with MultiIndex
171-
multi_index = pd.MultiIndex.from_tuples([('A', 1), ('B', 2), ('C', 3)], names=['letter', 'number'])
168+
multi_index = pd.MultiIndex.from_tuples(
169+
[("A", 1), ("B", 2), ("C", 3)], names=["letter", "number"]
170+
)
172171
native_df = native_pd.DataFrame([[1, 2], [3, 4], [5, 6]], index=multi_index)
173172
snow_df = pd.DataFrame(native_df)
174-
175-
173+
176174
eval_snowpark_pandas_result(
177-
snow_df, native_df, lambda x: x.apply(lambda row : row.name, axis=1)
175+
snow_df, native_df, lambda x: x.apply(lambda row: row.name, axis=1)
178176
)
179177

180178

181-
182179
@pytest.mark.xfail(strict=True, raises=NotImplementedError)
183180
@sql_count_checker(query_count=0)
184181
def test_frame_with_timedelta_index():

0 commit comments

Comments
 (0)