Skip to content

Commit 6a491f3

Browse files
committed
Preserve ArrowDtype in struct.fields
1 parent 18a6483 commit 6a491f3

File tree

3 files changed

+32
-26
lines changed

3 files changed

+32
-26
lines changed

python/cudf/cudf/core/accessors/struct.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from cudf.core.column.struct import StructColumn
1010
from cudf.core.dtype.validators import is_dtype_obj_struct
1111
from cudf.core.dtypes import StructDtype
12-
from cudf.utils.dtypes import get_dtype_of_same_kind
1312

1413
if TYPE_CHECKING:
1514
from cudf.core.dataframe import DataFrame
@@ -64,26 +63,18 @@ def field(self, key) -> Series | Index:
6463
field_keys = list(struct_dtype_fields.keys())
6564
if key in struct_dtype_fields:
6665
pos = field_keys.index(key)
67-
assert isinstance(self._column, StructColumn)
68-
return self._return_or_inplace(
69-
self._column._get_sliced_child(pos)._with_type_metadata(
70-
get_dtype_of_same_kind(
71-
self._column.dtype, struct_dtype_fields[key]
72-
)
73-
)
74-
)
7566
elif isinstance(key, int):
76-
try:
77-
assert isinstance(self._column, StructColumn)
78-
return self._return_or_inplace(
79-
self._column._get_sliced_child(key)
80-
)
81-
except IndexError as err:
82-
raise IndexError(f"Index {key} out of range") from err
67+
pos = key
8368
else:
8469
raise KeyError(
8570
f"Field '{key}' is not found in the set of existing keys."
8671
)
72+
assert isinstance(self._column, StructColumn)
73+
try:
74+
result = self._column._get_sliced_child(pos)
75+
except IndexError as err:
76+
raise IndexError(f"Index {key} out of range") from err
77+
return self._return_or_inplace(result)
8778

8879
def explode(self) -> DataFrame:
8980
"""

python/cudf/cudf/core/column/struct.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Any, cast
5+
from typing import TYPE_CHECKING, Any
66

77
import pandas as pd
88
import pyarrow as pa
@@ -13,7 +13,10 @@
1313
from cudf.core.column.column import ColumnBase
1414
from cudf.core.dtype.validators import is_dtype_obj_struct
1515
from cudf.core.dtypes import StructDtype
16-
from cudf.utils.dtypes import dtype_from_pylibcudf_column
16+
from cudf.utils.dtypes import (
17+
dtype_from_pylibcudf_column,
18+
get_dtype_of_same_kind,
19+
)
1720
from cudf.utils.scalar import (
1821
maybe_nested_pa_scalar_to_py,
1922
pa_scalar_to_plc_scalar,
@@ -75,16 +78,31 @@ def _validate_args( # type: ignore[override]
7578
return plc_column, dtype
7679

7780
def _get_sliced_child(self, idx: int) -> ColumnBase:
78-
"""Get a child column properly sliced to match the parent's view."""
81+
"""
82+
Get a child column properly sliced to match the parent's view.
83+
84+
Parameters
85+
----------
86+
idx : int
87+
The positional index of the child column to get.
88+
89+
Returns
90+
-------
91+
ColumnBase
92+
The child column at positional index `idx`.
93+
"""
7994
if idx < 0 or idx >= self.plc_column.num_children():
8095
raise IndexError(
8196
f"Index {idx} out of range for {self.plc_column.num_children()} children"
8297
)
8398

8499
sliced_plc_col = self.plc_column.struct_view().get_sliced_child(idx)
85-
dtype = cast(StructDtype, self.dtype)
86-
sub_dtype = list(dtype.fields.values())[idx]
87-
return ColumnBase.create(sliced_plc_col, sub_dtype)
100+
sub_dtype = list(
101+
StructDtype.from_struct_dtype(self.dtype).fields.values()
102+
)[idx]
103+
return ColumnBase.create(
104+
sliced_plc_col, get_dtype_of_same_kind(self.dtype, sub_dtype)
105+
)
88106

89107
def _prep_pandas_compat_repr(self) -> StringColumn | Self:
90108
"""

python/cudf/cudf/core/dataframe.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2092,10 +2092,7 @@ def _concat(
20922092
out.index.dtype, CategoricalDtype
20932093
):
20942094
out = out.set_index(out.index)
2095-
for name, col in out._column_labels_and_values:
2096-
out._data[name] = col._with_type_metadata(
2097-
tables[0]._data[name].dtype,
2098-
)
2095+
out = out._copy_type_metadata(tables[0])
20992096

21002097
# Reassign index and column names
21012098
if objs[0]._data.multiindex:

0 commit comments

Comments
 (0)