Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 80 additions & 67 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
if TYPE_CHECKING:

class ExtensionArraySupportsAnyAll("ExtensionArray"):

def any(self, *, skipna: bool = True) -> bool:
pass

Expand All @@ -98,7 +99,6 @@ def all(self, *, skipna: bool = True) -> bool:
NumpyValueArrayLike,
)


_extension_array_shared_docs: dict[str, str] = {}

ExtensionArrayT = TypeVar("ExtensionArrayT", bound="ExtensionArray")
Expand Down Expand Up @@ -242,7 +242,11 @@ class ExtensionArray:
# ------------------------------------------------------------------------

@classmethod
def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
def _from_sequence(cls,
scalars,
*,
dtype: Dtype | None = None,
copy=False):
"""
Construct a new ExtensionArray from a sequence of scalars.

Expand All @@ -264,9 +268,11 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
raise AbstractMethodError(cls)

@classmethod
def _from_sequence_of_strings(
cls, strings, *, dtype: Dtype | None = None, copy=False
):
def _from_sequence_of_strings(cls,
strings,
*,
dtype: Dtype | None = None,
copy=False):
"""
Construct a new ExtensionArray from a sequence of strings.

Expand Down Expand Up @@ -314,12 +320,12 @@ def __getitem__(self, item: ScalarIndexer) -> Any:
...

@overload
def __getitem__(self: ExtensionArrayT, item: SequenceIndexer) -> ExtensionArrayT:
def __getitem__(self: ExtensionArrayT,
item: SequenceIndexer) -> ExtensionArrayT:
...

def __getitem__(
self: ExtensionArrayT, item: PositionalIndexer
) -> ExtensionArrayT | Any:
def __getitem__(self: ExtensionArrayT,
item: PositionalIndexer) -> ExtensionArrayT | Any:
"""
Select a subset of self.

Expand Down Expand Up @@ -395,7 +401,8 @@ def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
# __init__ method coerces that value, then so should __setitem__
# Note, also, that Series/DataFrame.where internally use __setitem__
# on a copy of the data.
raise NotImplementedError(f"{type(self)} does not implement __setitem__.")
raise NotImplementedError(
f"{type(self)} does not implement __setitem__.")

def __len__(self) -> int:
"""
Expand Down Expand Up @@ -427,7 +434,8 @@ def __contains__(self, item: object) -> bool | np.bool_:
if is_scalar(item) and isna(item):
if not self._can_hold_na:
return False
elif item is self.dtype.na_value or isinstance(item, self.dtype.type):
elif item is self.dtype.na_value or isinstance(
item, self.dtype.type):
return self._hasna
else:
return False
Expand Down Expand Up @@ -510,7 +518,7 @@ def shape(self) -> Shape:
"""
Return a tuple of the array dimensions.
"""
return (len(self),)
return (len(self), )

@property
def size(self) -> int:
Expand Down Expand Up @@ -544,7 +552,9 @@ def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...

@overload
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
def astype(self,
dtype: ExtensionDtype,
copy: bool = ...) -> ExtensionArray:
...

@overload
Expand Down Expand Up @@ -785,7 +795,9 @@ def fillna(
# error: Argument 2 to "check_value_size" has incompatible type
# "ExtensionArray"; expected "ndarray"
value = missing.check_value_size(
value, mask, len(self) # type: ignore[arg-type]
value,
mask,
len(self) # type: ignore[arg-type]
)

if mask.any():
Expand Down Expand Up @@ -813,7 +825,9 @@ def dropna(self: ExtensionArrayT) -> ExtensionArrayT:
# error: Unsupported operand type for ~ ("ExtensionArray")
return self[~self.isna()] # type: ignore[operator]

def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
def shift(self,
periods: int = 1,
fill_value: object = None) -> ExtensionArray:
"""
Shift values by desired number.

Expand Down Expand Up @@ -852,14 +866,14 @@ def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
if isna(fill_value):
fill_value = self.dtype.na_value

empty = self._from_sequence(
[fill_value] * min(abs(periods), len(self)), dtype=self.dtype
)
empty = self._from_sequence([fill_value] *
min(abs(periods), len(self)),
dtype=self.dtype)
if periods > 0:
a = empty
b = self[:-periods]
else:
a = self[abs(periods) :]
a = self[abs(periods):]
b = empty
return self._concat_same_type([a, b])

Expand Down Expand Up @@ -1002,7 +1016,8 @@ def _values_for_factorize(self) -> tuple[np.ndarray, Any]:
"""
return self.astype(object), np.nan

def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
def factorize(self,
na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
"""
Encode the extension array as an enumerated type.

Expand Down Expand Up @@ -1043,16 +1058,14 @@ def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
# Complete control over factorization.
arr, na_value = self._values_for_factorize()

codes, uniques = factorize_array(
arr, na_sentinel=na_sentinel, na_value=na_value
)
codes, uniques = factorize_array(arr,
na_sentinel=na_sentinel,
na_value=na_value)

uniques_ea = self._from_factorized(uniques, self)
return codes, uniques_ea

_extension_array_shared_docs[
"repeat"
] = """
_extension_array_shared_docs["repeat"] = """
Repeat elements of a %(klass)s.

Returns a new %(klass)s where each element of the current %(klass)s
Expand Down Expand Up @@ -1245,9 +1258,9 @@ def __repr__(self) -> str:
# the short repr has no trailing newline, while the truncated
# repr does. So we include a newline in our template, and strip
# any trailing newlines from format_object_summary
data = format_object_summary(
self, self._formatter(), indent_for_name=False
).rstrip(", \n")
data = format_object_summary(self,
self._formatter(),
indent_for_name=False).rstrip(", \n")
class_name = f"<{type(self).__name__}>\n"
return f"{class_name}{data}\nLength: {len(self)}, dtype: {self.dtype}"

Expand All @@ -1258,9 +1271,8 @@ def _repr_2d(self) -> str:
# repr does. So we include a newline in our template, and strip
# any trailing newlines from format_object_summary
lines = [
format_object_summary(x, self._formatter(), indent_for_name=False).rstrip(
", \n"
)
format_object_summary(x, self._formatter(),
indent_for_name=False).rstrip(", \n")
for x in self
]
data = ",\n".join(lines)
Expand Down Expand Up @@ -1312,7 +1324,9 @@ def transpose(self, *axes: int) -> ExtensionArray:
def T(self) -> ExtensionArray:
return self.transpose()

def ravel(self, order: Literal["C", "F", "A", "K"] | None = "C") -> ExtensionArray:
def ravel(
self,
order: Literal["C", "F", "A", "K"] | None = "C") -> ExtensionArray:
"""
Return a flattened view on this array.

Expand All @@ -1333,8 +1347,8 @@ def ravel(self, order: Literal["C", "F", "A", "K"] | None = "C") -> ExtensionArr

@classmethod
def _concat_same_type(
cls: type[ExtensionArrayT], to_concat: Sequence[ExtensionArrayT]
) -> ExtensionArrayT:
cls: type[ExtensionArrayT],
to_concat: Sequence[ExtensionArrayT]) -> ExtensionArrayT:
"""
Concatenate multiple array of this dtype.

Expand Down Expand Up @@ -1388,10 +1402,8 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
"""
meth = getattr(self, name, None)
if meth is None:
raise TypeError(
f"'{type(self).__name__}' with dtype {self.dtype} "
f"does not support reduction '{name}'"
)
raise TypeError(f"'{type(self).__name__}' with dtype {self.dtype} "
f"does not support reduction '{name}'")
return meth(skipna=skipna, **kwargs)

# https://github.com/python/typeshed/issues/2148#issuecomment-520783318
Expand Down Expand Up @@ -1419,7 +1431,8 @@ def tolist(self) -> list:
return [x.tolist() for x in self]
return list(self)

def delete(self: ExtensionArrayT, loc: PositionalIndexer) -> ExtensionArrayT:
def delete(self: ExtensionArrayT,
loc: PositionalIndexer) -> ExtensionArrayT:
indexer = np.delete(np.arange(len(self)), loc)
return self.take(indexer)

Expand Down Expand Up @@ -1478,9 +1491,8 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:

self[mask] = val

def _where(
self: ExtensionArrayT, mask: npt.NDArray[np.bool_], value
) -> ExtensionArrayT:
def _where(self: ExtensionArrayT, mask: npt.NDArray[np.bool_],
value) -> ExtensionArrayT:
"""
Analogue to np.where(mask, self, value)

Expand All @@ -1503,9 +1515,8 @@ def _where(
result[~mask] = val
return result

def _fill_mask_inplace(
self, method: str, limit, mask: npt.NDArray[np.bool_]
) -> None:
def _fill_mask_inplace(self, method: str, limit,
mask: npt.NDArray[np.bool_]) -> None:
"""
Replace values in locations specified by 'mask' using pad or backfill.

Expand Down Expand Up @@ -1571,9 +1582,8 @@ def _empty(cls, shape: Shape, dtype: ExtensionDtype):
)
return result

def _quantile(
self: ExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
) -> ExtensionArrayT:
def _quantile(self: ExtensionArrayT, qs: npt.NDArray[np.float64],
interpolation: str) -> ExtensionArrayT:
"""
Compute the quantiles of self for each quantile in `qs`.

Expand All @@ -1593,7 +1603,8 @@ def _quantile(
arr = np.atleast_2d(np.asarray(self))
fill_value = np.nan

res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
res_values = quantile_with_mask(arr, mask, fill_value, qs,
interpolation)

if self.ndim == 2:
# i.e. DatetimeArray
Expand Down Expand Up @@ -1628,29 +1639,27 @@ def _mode(self: ExtensionArrayT, dropna: bool = True) -> ExtensionArrayT:

def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
if any(
isinstance(other, (ABCSeries, ABCIndex, ABCDataFrame)) for other in inputs
):
isinstance(other, (ABCSeries, ABCIndex, ABCDataFrame))
for other in inputs):
return NotImplemented

result = arraylike.maybe_dispatch_ufunc_to_dunder_op(
self, ufunc, method, *inputs, **kwargs
)
self, ufunc, method, *inputs, **kwargs)
if result is not NotImplemented:
return result

if "out" in kwargs:
return arraylike.dispatch_ufunc_with_out(
self, ufunc, method, *inputs, **kwargs
)
return arraylike.dispatch_ufunc_with_out(self, ufunc, method,
*inputs, **kwargs)

if method == "reduce":
result = arraylike.dispatch_reduction_ufunc(
self, ufunc, method, *inputs, **kwargs
)
self, ufunc, method, *inputs, **kwargs)
if result is not NotImplemented:
return result

return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)
return arraylike.default_array_ufunc(self, ufunc, method, *inputs,
**kwargs)


class ExtensionOpsMixin:
Expand Down Expand Up @@ -1680,14 +1689,17 @@ def _add_arithmetic_ops(cls):
setattr(cls, "__rpow__", cls._create_arithmetic_method(roperator.rpow))
setattr(cls, "__mod__", cls._create_arithmetic_method(operator.mod))
setattr(cls, "__rmod__", cls._create_arithmetic_method(roperator.rmod))
setattr(cls, "__floordiv__", cls._create_arithmetic_method(operator.floordiv))
setattr(
cls, "__rfloordiv__", cls._create_arithmetic_method(roperator.rfloordiv)
)
setattr(cls, "__truediv__", cls._create_arithmetic_method(operator.truediv))
setattr(cls, "__rtruediv__", cls._create_arithmetic_method(roperator.rtruediv))
setattr(cls, "__floordiv__",
cls._create_arithmetic_method(operator.floordiv))
setattr(cls, "__rfloordiv__",
cls._create_arithmetic_method(roperator.rfloordiv))
setattr(cls, "__truediv__",
cls._create_arithmetic_method(operator.truediv))
setattr(cls, "__rtruediv__",
cls._create_arithmetic_method(roperator.rtruediv))
setattr(cls, "__divmod__", cls._create_arithmetic_method(divmod))
setattr(cls, "__rdivmod__", cls._create_arithmetic_method(roperator.rdivmod))
setattr(cls, "__rdivmod__",
cls._create_arithmetic_method(roperator.rdivmod))

@classmethod
def _create_comparison_method(cls, op):
Expand Down Expand Up @@ -1783,6 +1795,7 @@ def _create_method(cls, op, coerce_to_dtype=True, result_dtype=None):
"""

def _binop(self, other):

def convert_values(param):
if isinstance(param, ExtensionArray) or is_list_like(param):
ovalues = param
Expand Down
Loading