Skip to content

Commit 5ffe816

Browse files
committed
Call to_xarray on the slow object
1 parent a32b8cf commit 5ffe816

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

python/cudf/cudf/pandas/_wrappers/pandas.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,16 @@ def _DataFrame_columns(self):
315315
return result
316316

317317

318+
def _to_xarray(self):
319+
# Call to_xarray directly on the slow object, not via _FastSlowAttribute.
320+
# This keeps the module accelerator enabled so pandas returns proxy
321+
# Index/ExtensionArray objects that inherit from ExtensionArray and pass
322+
# xarray's isinstance checks. If those checks fail, xarray falls back to
323+
# np.issubdtype(array.dtype, ...), which raises TypeError for unsupported
324+
# extension dtypes.
325+
return self._fsproxy_slow.to_xarray()
326+
327+
318328
DataFrame = make_final_proxy_type(
319329
"DataFrame",
320330
cudf.DataFrame,
@@ -347,6 +357,7 @@ def _DataFrame_columns(self):
347357
"flags": _FastSlowAttribute("flags", private=True),
348358
"memory_usage": _FastSlowAttribute("memory_usage"),
349359
"__sizeof__": _FastSlowAttribute("__sizeof__"),
360+
"to_xarray": _to_xarray,
350361
},
351362
)
352363

@@ -420,6 +431,7 @@ def _argsort(self, *args, **kwargs):
420431
"_accessors": set(),
421432
"dtype": property(_Series_dtype),
422433
"argsort": _argsort,
434+
"to_xarray": _to_xarray,
423435
"attrs": _FastSlowAttribute("attrs"),
424436
"_mgr": _FastSlowAttribute("_mgr", private=True),
425437
"array": _FastSlowAttribute("array", private=True),
@@ -576,6 +588,7 @@ def Index__setattr__(self, name, value):
576588
pd.Categorical,
577589
fast_to_slow=_Unusable(),
578590
slow_to_fast=_Unusable(),
591+
bases=(pd.api.extensions.ExtensionArray,),
579592
)
580593

581594
CategoricalDtype = make_final_proxy_type(
@@ -615,6 +628,7 @@ def Index__setattr__(self, name, value):
615628
pd.arrays.DatetimeArray,
616629
fast_to_slow=_Unusable(),
617630
slow_to_fast=_Unusable(),
631+
bases=(pd.api.extensions.ExtensionArray,),
618632
additional_attributes={
619633
"_data": _FastSlowAttribute("_data", private=True),
620634
"_mask": _FastSlowAttribute("_mask", private=True),
@@ -718,6 +732,7 @@ def Index__setattr__(self, name, value):
718732
pd.arrays.PeriodArray,
719733
fast_to_slow=_Unusable(),
720734
slow_to_fast=_Unusable(),
735+
bases=(pd.api.extensions.ExtensionArray,),
721736
additional_attributes={
722737
"_data": _FastSlowAttribute("_data", private=True),
723738
"_mask": _FastSlowAttribute("_mask", private=True),
@@ -800,6 +815,7 @@ def Index__setattr__(self, name, value):
800815
pd.arrays.StringArray,
801816
fast_to_slow=_Unusable(),
802817
slow_to_fast=_Unusable(),
818+
bases=(pd.api.extensions.ExtensionArray,),
803819
additional_attributes={
804820
"_data": _FastSlowAttribute("_data", private=True),
805821
"_mask": _FastSlowAttribute("_mask", private=True),
@@ -838,6 +854,7 @@ def Index__setattr__(self, name, value):
838854
pd.core.arrays.string_arrow.ArrowStringArray,
839855
fast_to_slow=_Unusable(),
840856
slow_to_fast=_Unusable(),
857+
bases=(pd.api.extensions.ExtensionArray,),
841858
additional_attributes={
842859
"_pa_array": _FastSlowAttribute("_pa_array", private=True),
843860
"__array__": _FastSlowAttribute("__array__", private=True),
@@ -877,6 +894,7 @@ def Index__setattr__(self, name, value):
877894
pd.arrays.BooleanArray,
878895
fast_to_slow=_Unusable(),
879896
slow_to_fast=_Unusable(),
897+
bases=(pd.api.extensions.ExtensionArray,),
880898
additional_attributes={
881899
"_data": _FastSlowAttribute("_data", private=True),
882900
"_mask": _FastSlowAttribute("_mask", private=True),
@@ -901,6 +919,7 @@ def Index__setattr__(self, name, value):
901919
pd.arrays.IntegerArray,
902920
fast_to_slow=_Unusable(),
903921
slow_to_fast=_Unusable(),
922+
bases=(pd.api.extensions.ExtensionArray,),
904923
additional_attributes={
905924
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
906925
"_data": _FastSlowAttribute("_data", private=True),
@@ -1021,6 +1040,7 @@ def Index__setattr__(self, name, value):
10211040
pd.arrays.IntervalArray,
10221041
fast_to_slow=_Unusable(),
10231042
slow_to_fast=_Unusable(),
1043+
bases=(pd.api.extensions.ExtensionArray,),
10241044
additional_attributes={
10251045
"_data": _FastSlowAttribute("_data", private=True),
10261046
"_mask": _FastSlowAttribute("_mask", private=True),
@@ -1055,6 +1075,7 @@ def Index__setattr__(self, name, value):
10551075
pd.arrays.FloatingArray,
10561076
fast_to_slow=_Unusable(),
10571077
slow_to_fast=_Unusable(),
1078+
bases=(pd.api.extensions.ExtensionArray,),
10581079
additional_attributes={
10591080
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
10601081
"_data": _FastSlowAttribute("_data", private=True),

python/cudf/cudf/pandas/scripts/conftest-patch.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3421,8 +3421,6 @@ def set_copy_on_write_option():
34213421
"tests/generic/test_to_xarray.py::TestDataFrameToXArray::test_to_xarray_index_types[uint32]",
34223422
"tests/generic/test_to_xarray.py::TestDataFrameToXArray::test_to_xarray_index_types[uint64]",
34233423
"tests/generic/test_to_xarray.py::TestDataFrameToXArray::test_to_xarray_index_types[uint8]",
3424-
"tests/generic/test_to_xarray.py::TestSeriesToXArray::test_to_xarray_index_types[string-pyarrow]",
3425-
"tests/generic/test_to_xarray.py::TestSeriesToXArray::test_to_xarray_index_types[string-python]",
34263424
"tests/groupby/aggregate/test_aggregate.py::test_agg_multiple_with_as_index_false_subset_to_a_single_column",
34273425
"tests/groupby/aggregate/test_aggregate.py::test_agg_str_with_kwarg_axis_1_raises[count]",
34283426
"tests/groupby/aggregate/test_aggregate.py::test_agg_str_with_kwarg_axis_1_raises[first]",

0 commit comments

Comments
 (0)