Skip to content

Commit 70a209a

Browse files
committed
update xfail set and pass more xarray tests
1 parent 5ffe816 commit 70a209a

File tree

2 files changed

+14
-217
lines changed

2 files changed

+14
-217
lines changed

python/cudf/cudf/pandas/_wrappers/pandas.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# with this module https://github.com/rapidsai/cudf/issues/14521#issue-2015198786
1919
import pyarrow.dataset as ds # noqa: F401
2020
from pandas._testing import at, getitem, iat, iloc, loc, setitem
21+
from pandas.compat._optional import import_optional_dependency
2122
from pandas.tseries.holiday import (
2223
AbstractHolidayCalendar as pd_AbstractHolidayCalendar,
2324
EasterMonday as pd_EasterMonday,
@@ -316,13 +317,17 @@ def _DataFrame_columns(self):
316317

317318

318319
def _to_xarray(self):
319-
# Call to_xarray directly on the slow object, not via _FastSlowAttribute.
320-
# This keeps the module accelerator enabled so pandas returns proxy
321-
# Index/ExtensionArray objects that inherit from ExtensionArray and pass
322-
# xarray's isinstance checks. If those checks fail, xarray falls back to
323-
# np.issubdtype(array.dtype, ...), which raises TypeError for unsupported
324-
# extension dtypes.
325-
return self._fsproxy_slow.to_xarray()
320+
# Call xarray conversion functions directly with self (the proxy object).
321+
# We must pass the proxy (self), not the slow pandas object, because xarray
322+
# does isinstance checks against pd.MultiIndex and pd.api.extensions.ExtensionArray.
323+
# After cudf.pandas.install(), these refer to proxy classes. The slow object
324+
# contains real pandas types that don't pass isinstance checks against the proxy
325+
# classes.
326+
xr = import_optional_dependency("xarray")
327+
if self.ndim == 1:
328+
return xr.DataArray.from_series(self)
329+
else:
330+
return xr.Dataset.from_dataframe(self)
326331

327332

328333
DataFrame = make_final_proxy_type(
@@ -821,6 +826,7 @@ def Index__setattr__(self, name, value):
821826
"_mask": _FastSlowAttribute("_mask", private=True),
822827
"__array__": _FastSlowAttribute("__array__"),
823828
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
829+
"__arrow_array__": _FastSlowAttribute("__arrow_array__"),
824830
},
825831
)
826832

@@ -864,6 +870,7 @@ def Index__setattr__(self, name, value):
864870
"__abs__": _FastSlowAttribute("__abs__"),
865871
"__contains__": _FastSlowAttribute("__contains__"),
866872
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
873+
"__arrow_array__": _FastSlowAttribute("__arrow_array__"),
867874
},
868875
)
869876

0 commit comments

Comments
 (0)