diff --git a/python/cudf/cudf/core/accessors/struct.py b/python/cudf/cudf/core/accessors/struct.py index 3a6f6b6468c..ea40781fe5b 100644 --- a/python/cudf/cudf/core/accessors/struct.py +++ b/python/cudf/cudf/core/accessors/struct.py @@ -9,7 +9,6 @@ from cudf.core.column.struct import StructColumn from cudf.core.dtype.validators import is_dtype_obj_struct from cudf.core.dtypes import StructDtype -from cudf.utils.dtypes import get_dtype_of_same_kind if TYPE_CHECKING: from cudf.core.dataframe import DataFrame @@ -64,26 +63,18 @@ def field(self, key) -> Series | Index: field_keys = list(struct_dtype_fields.keys()) if key in struct_dtype_fields: pos = field_keys.index(key) - assert isinstance(self._column, StructColumn) - return self._return_or_inplace( - self._column._get_sliced_child(pos)._with_type_metadata( - get_dtype_of_same_kind( - self._column.dtype, struct_dtype_fields[key] - ) - ) - ) elif isinstance(key, int): - try: - assert isinstance(self._column, StructColumn) - return self._return_or_inplace( - self._column._get_sliced_child(key) - ) - except IndexError as err: - raise IndexError(f"Index {key} out of range") from err + pos = key else: raise KeyError( f"Field '{key}' is not found in the set of existing keys." ) + assert isinstance(self._column, StructColumn) + try: + result = self._column._get_sliced_child(pos) + except IndexError as err: + raise IndexError(f"Index {key} out of range") from err + return self._return_or_inplace(result) def explode(self) -> DataFrame: """ diff --git a/python/cudf/cudf/core/column/struct.py b/python/cudf/cudf/core/column/struct.py index 7f6bc99a552..37ed6f75cc4 100644 --- a/python/cudf/cudf/core/column/struct.py +++ b/python/cudf/cudf/core/column/struct.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import pandas as pd import pyarrow as pa @@ -13,7 +13,10 @@ from cudf.core.column.column import ColumnBase from cudf.core.dtype.validators import is_dtype_obj_struct from cudf.core.dtypes import StructDtype -from cudf.utils.dtypes import dtype_from_pylibcudf_column +from cudf.utils.dtypes import ( + dtype_from_pylibcudf_column, + get_dtype_of_same_kind, +) from cudf.utils.scalar import ( maybe_nested_pa_scalar_to_py, pa_scalar_to_plc_scalar, @@ -75,16 +78,31 @@ def _validate_args( # type: ignore[override] return plc_column, dtype def _get_sliced_child(self, idx: int) -> ColumnBase: - """Get a child column properly sliced to match the parent's view.""" + """ + Get a child column properly sliced to match the parent's view. + + Parameters + ---------- + idx : int + The positional index of the child column to get. + + Returns + ------- + ColumnBase + The child column at positional index `idx`. + """ if idx < 0 or idx >= self.plc_column.num_children(): raise IndexError( f"Index {idx} out of range for {self.plc_column.num_children()} children" ) sliced_plc_col = self.plc_column.struct_view().get_sliced_child(idx) - dtype = cast(StructDtype, self.dtype) - sub_dtype = list(dtype.fields.values())[idx] - return ColumnBase.create(sliced_plc_col, sub_dtype) + sub_dtype = list( + StructDtype.from_struct_dtype(self.dtype).fields.values() + )[idx] + return ColumnBase.create( + sliced_plc_col, get_dtype_of_same_kind(self.dtype, sub_dtype) + ) def _prep_pandas_compat_repr(self) -> StringColumn | Self: """ diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index fe9e82bd056..3f9945019d3 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -2092,10 +2092,7 @@ def _concat( out.index.dtype, CategoricalDtype ): out = out.set_index(out.index) - for name, col in out._column_labels_and_values: - out._data[name] = col._with_type_metadata( - tables[0]._data[name].dtype, - ) + out = out._copy_type_metadata(tables[0]) # Reassign index and column names if objs[0]._data.multiindex: