@@ -152,7 +152,6 @@ def foo(row) -> str:
152152@sql_count_checker (query_count = 5 , join_count = 2 , udtf_count = 1 )
153153def test_apply_axis_1_index_preservation (index ):
154154 """Test that apply(axis=1) preserves index values correctly."""
155- # Test with default RangeIndex
156155 native_df = native_pd .DataFrame ([[1 , 2 ], [3 , 4 ]], index = index )
157156 snow_df = pd .DataFrame (native_df )
158157
@@ -161,6 +160,21 @@ def test_apply_axis_1_index_preservation(index):
161160 )
162161
163162
163+ @sql_count_checker (query_count = 5 , join_count = 2 , udtf_count = 1 )
164+ def test_apply_axis_1_index_from_col ():
165+ """Test that apply(axis=1) preserves an index when set from a column"""
166+ native_df = native_pd .DataFrame (
167+ [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], columns = ["a" , "b" , "c" ]
168+ )
169+ snow_df = pd .DataFrame (native_df )
170+ snow_df = snow_df .set_index ("a" )
171+ native_df = native_df .set_index ("a" )
172+
173+ eval_snowpark_pandas_result (
174+ snow_df , native_df , lambda x : x .apply (lambda row : row .name , axis = 1 )
175+ )
176+
177+
164178@sql_count_checker (query_count = 5 , join_count = 2 , udtf_count = 1 )
165179def test_apply_axis_1_multiindex_preservation ():
166180 """Test that apply(axis=1) preserves MultiIndex values correctly."""
0 commit comments