Skip to content

Commit b1c9bff

Browse files
SNOW-2442822: Fix np.asarray(datetime_with_timezone_index) (#3963)
The current implementation of `Index.__array__` passes `self.dtype` to `self.to_pandas().__array__()`, but we should instead propagate the `dtype` argument from the `__array__` call. Signed-off-by: sfc-gh-mvashishtha <[email protected]>
1 parent 7412bf4 commit b1c9bff

File tree

4 files changed

+32
-16
lines changed

4 files changed

+32
-16
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#### Bug Fixes
4242

4343
- Fixed a bug in `DataFrameGroupBy.agg` where func is a list of tuples used to set the names of the output columns.
44+
- Fixed a bug where converting a modin datetime index with a timezone to a numpy array with `np.asarray` would cause a `TypeError`.
4445

4546
#### Improvements
4647

src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def get_valid_col_pos_list_from_columns(
962962

963963
# convert float like keys to integers
964964
elif not is_integer_dtype(pos_array.dtype):
965-
assert is_float_dtype(
965+
assert pos_array.size == 0 or is_float_dtype(
966966
pos_array.dtype
967967
), "list-like key must be list of int or float"
968968
pos_list = pos_array.astype(int)

src/snowflake/snowpark/modin/plugin/extensions/index.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -917,10 +917,6 @@ def _summary(self, name: Any = None) -> str:
917917

918918
@materialization_warning
919919
def __array__(self, dtype: Any = None) -> np.ndarray:
920-
# Ensure that the existing index dtype is preserved in the returned array
921-
# if no other dtype is given.
922-
if dtype is None:
923-
dtype = self.dtype
924920
return self.to_pandas().__array__(dtype=dtype)
925921

926922
def __repr__(self) -> str:

tests/integ/modin/test_to_numpy.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import modin.pandas as pd
99
import numpy as np
1010
import pandas as native_pd
11+
from pandas.api.types import is_datetime64tz_dtype
1112
import pytest
1213
from numpy.testing import assert_array_equal
14+
from pytest import param
1315

1416
import snowflake.snowpark.modin.plugin # noqa: F401
1517
from snowflake.snowpark._internal.utils import (
@@ -21,6 +23,10 @@
2123
from tests.utils import Utils
2224

2325

26+
def values_property_getter(object) -> np.ndarray:
27+
return object.values
28+
29+
2430
@pytest.mark.parametrize(
2531
"data",
2632
[
@@ -40,11 +46,19 @@
4046
[datetime.time(1, 2, 3, 1), datetime.time(0, 0, 0), None],
4147
[datetime.datetime(2023, 1, 1), datetime.datetime(2023, 1, 1, 1, 2, 3)],
4248
[datetime.datetime(2023, 1, 1), datetime.datetime(2023, 1, 1, 1, 2, 3), None],
49+
[pd.Timestamp(1, tz="UTC")],
4350
],
4451
)
52+
@pytest.mark.parametrize(
53+
"to_numpy",
54+
(
55+
param(lambda index: np.asarray(index), id="asarray"),
56+
param(lambda index: index.to_numpy(), id="to_numpy"),
57+
param(values_property_getter, id="values"),
58+
),
59+
)
4560
@pytest.mark.parametrize("pandas_obj", ["DataFrame", "Series", "Index"])
46-
@pytest.mark.parametrize("func", ["to_numpy", "values"])
47-
def test_to_numpy_basic(data, pandas_obj, func):
61+
def test_to_numpy_basic(data, pandas_obj, to_numpy, request):
4862
if pandas_obj == "Series":
4963
df = pd.Series(data)
5064
native_df = native_pd.Series(data)
@@ -54,16 +68,21 @@ def test_to_numpy_basic(data, pandas_obj, func):
5468
else:
5569
df = pd.DataFrame([data, data])
5670
native_df = native_pd.DataFrame([data, data])
57-
with SqlCounter(query_count=1):
58-
if func == "to_numpy":
59-
assert_array_equal(df.to_numpy(), native_df.to_numpy())
60-
else:
61-
assert_array_equal(df.values, native_df.values)
71+
72+
with SqlCounter(
73+
# modin_datetime_series_with_timezone.values internally calls .dtype,
74+
# which triggers an extra groupby query to determine the dtype.
75+
query_count=2
76+
if to_numpy is values_property_getter
77+
and pandas_obj == "Series"
78+
and is_datetime64tz_dtype(native_df)
79+
else 1
80+
):
81+
snow_result = to_numpy(df)
82+
native_result = to_numpy(native_df)
83+
assert_array_equal(snow_result, native_result)
6284
if pandas_obj == "Series":
63-
with SqlCounter(query_count=1):
64-
res = df.to_numpy()
65-
expected_res = native_df.to_numpy()
66-
for r1, r2 in zip(res, expected_res):
85+
for r1, r2 in zip(snow_result, native_result):
6786
# native pandas series returns a list of pandas Timestamp,
6887
# but Snowpark pandas returns a list of integers in ms.
6988
# Their values are equal

0 commit comments

Comments
 (0)