Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions python/cudf/cudf/core/accessors/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from cudf.core.column.struct import StructColumn
from cudf.core.dtype.validators import is_dtype_obj_struct
from cudf.core.dtypes import StructDtype
from cudf.utils.dtypes import get_dtype_of_same_kind

if TYPE_CHECKING:
from cudf.core.dataframe import DataFrame
Expand Down Expand Up @@ -64,26 +63,18 @@ def field(self, key) -> Series | Index:
field_keys = list(struct_dtype_fields.keys())
if key in struct_dtype_fields:
pos = field_keys.index(key)
assert isinstance(self._column, StructColumn)
return self._return_or_inplace(
self._column._get_sliced_child(pos)._with_type_metadata(
get_dtype_of_same_kind(
self._column.dtype, struct_dtype_fields[key]
)
)
)
elif isinstance(key, int):
try:
assert isinstance(self._column, StructColumn)
return self._return_or_inplace(
self._column._get_sliced_child(key)
)
except IndexError as err:
raise IndexError(f"Index {key} out of range") from err
pos = key
else:
raise KeyError(
f"Field '{key}' is not found in the set of existing keys."
)
assert isinstance(self._column, StructColumn)
try:
result = self._column._get_sliced_child(pos)
except IndexError as err:
raise IndexError(f"Index {key} out of range") from err
return self._return_or_inplace(result)

def explode(self) -> DataFrame:
"""
Expand Down
30 changes: 24 additions & 6 deletions python/cudf/cudf/core/column/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

import pandas as pd
import pyarrow as pa
Expand All @@ -13,7 +13,10 @@
from cudf.core.column.column import ColumnBase
from cudf.core.dtype.validators import is_dtype_obj_struct
from cudf.core.dtypes import StructDtype
from cudf.utils.dtypes import dtype_from_pylibcudf_column
from cudf.utils.dtypes import (
dtype_from_pylibcudf_column,
get_dtype_of_same_kind,
)
from cudf.utils.scalar import (
maybe_nested_pa_scalar_to_py,
pa_scalar_to_plc_scalar,
Expand Down Expand Up @@ -75,16 +78,31 @@ def _validate_args( # type: ignore[override]
return plc_column, dtype

def _get_sliced_child(self, idx: int) -> ColumnBase:
"""Get a child column properly sliced to match the parent's view."""
"""
Get a child column properly sliced to match the parent's view.

Parameters
----------
idx : int
The positional index of the child column to get.

Returns
-------
ColumnBase
The child column at positional index `idx`.
"""
if idx < 0 or idx >= self.plc_column.num_children():
raise IndexError(
f"Index {idx} out of range for {self.plc_column.num_children()} children"
)

sliced_plc_col = self.plc_column.struct_view().get_sliced_child(idx)
dtype = cast(StructDtype, self.dtype)
sub_dtype = list(dtype.fields.values())[idx]
return ColumnBase.create(sliced_plc_col, sub_dtype)
sub_dtype = list(
StructDtype.from_struct_dtype(self.dtype).fields.values()
)[idx]
return ColumnBase.create(
sliced_plc_col, get_dtype_of_same_kind(self.dtype, sub_dtype)
)

def _prep_pandas_compat_repr(self) -> StringColumn | Self:
"""
Expand Down
5 changes: 1 addition & 4 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2092,10 +2092,7 @@ def _concat(
out.index.dtype, CategoricalDtype
):
out = out.set_index(out.index)
for name, col in out._column_labels_and_values:
out._data[name] = col._with_type_metadata(
tables[0]._data[name].dtype,
)
out = out._copy_type_metadata(tables[0])

# Reassign index and column names
if objs[0]._data.multiindex:
Expand Down
Loading